From f5e4d88e6894fe587c05b28ecd61b772d4575dab Mon Sep 17 00:00:00 2001 From: Amish Garg Date: Fri, 16 Jul 2021 03:37:08 +0530 Subject: [PATCH] add TensorAudio and recording support --- android/build.gradle | 2 - android/src/main/AndroidManifest.xml | 1 + .../TfliteFlutterHelperPlugin.kt | 293 ++++++++++++++++-- lib/src/audio/recorder_stream.dart | 70 +++++ lib/src/audio/sound_stream.dart | 40 +++ lib/src/audio/tensor_audio.dart | 173 +++++++++++ lib/tflite_flutter_helper.dart | 3 + pubspec.lock | 2 +- 8 files changed, 556 insertions(+), 28 deletions(-) create mode 100644 lib/src/audio/recorder_stream.dart create mode 100644 lib/src/audio/sound_stream.dart create mode 100644 lib/src/audio/tensor_audio.dart diff --git a/android/build.gradle b/android/build.gradle index ae39ca5..165dd0b 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -1,5 +1,3 @@ -package android - group 'com.tfliteflutter.tflite_flutter_helper' version '1.0-SNAPSHOT' diff --git a/android/src/main/AndroidManifest.xml b/android/src/main/AndroidManifest.xml index fd29b5d..90f3d5b 100644 --- a/android/src/main/AndroidManifest.xml +++ b/android/src/main/AndroidManifest.xml @@ -1,3 +1,4 @@ + diff --git a/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt b/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt index 62ae201..2e751d5 100644 --- a/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt +++ b/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt @@ -1,35 +1,278 @@ package com.tfliteflutter.tflite_flutter_helper -import androidx.annotation.NonNull - import io.flutter.embedding.engine.plugins.FlutterPlugin import io.flutter.plugin.common.MethodCall import io.flutter.plugin.common.MethodChannel import io.flutter.plugin.common.MethodChannel.MethodCallHandler import io.flutter.plugin.common.MethodChannel.Result +import android.Manifest +import android.app.Activity +import android.content.Context +import android.content.pm.PackageManager +import android.media.* +import android.media.AudioRecord.OnRecordPositionUpdateListener +import android.util.Log +import androidx.annotation.NonNull +import androidx.core.app.ActivityCompat +import androidx.core.content.ContextCompat +import io.flutter.embedding.engine.plugins.activity.ActivityAware +import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding +import io.flutter.plugin.common.PluginRegistry +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.ShortBuffer + +enum class SoundStreamErrors { + FailedToRecord, + FailedToPlay, + FailedToStop, + FailedToWriteBuffer, + Unknown, +} + +enum class SoundStreamStatus { + Unset, + Initialized, + Playing, + Stopped, +} + +const val methodChannelName = "com.tfliteflutter.tflite_flutter_helper:methods" /** TfliteFlutterHelperPlugin */ -class TfliteFlutterHelperPlugin: FlutterPlugin, MethodCallHandler { - /// The MethodChannel that will the communication between Flutter and native Android - /// - /// This local reference serves to register the plugin with the Flutter Engine and unregister it - /// when the Flutter Engine is detached from the Activity - private lateinit var channel : MethodChannel - - override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { - channel = MethodChannel(flutterPluginBinding.binaryMessenger, "tflite_flutter_helper") - channel.setMethodCallHandler(this) - } - - override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) { - if (call.method == "getPlatformVersion") { - result.success("Android ${android.os.Build.VERSION.RELEASE}") - } else { - result.notImplemented() - } - } - - override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) { - channel.setMethodCallHandler(null) - } +class TfliteFlutterHelperPlugin : FlutterPlugin, + MethodCallHandler, + PluginRegistry.RequestPermissionsResultListener, + ActivityAware { + + private val logTag = "TfLiteFlutterHelperPlugin" + private val audioRecordPermissionCode = 14887 + + private lateinit var methodChannel: MethodChannel + private var currentActivity: Activity? = null + private var pluginContext: Context? = null + private var permissionToRecordAudio: Boolean = false + private var activeResult: Result? = null + private var debugLogging: Boolean = false + + //========= Recorder's vars + private val mRecordFormat = AudioFormat.ENCODING_PCM_16BIT + private var mRecordSampleRate = 16000 // 16Khz + private var mRecorderBufferSize = 8192 + private var mPeriodFrames = 8192 + private var audioData: ShortArray? = null + private var mRecorder: AudioRecord? = null + private var mListener: OnRecordPositionUpdateListener? = null + + + override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) { + pluginContext = flutterPluginBinding.applicationContext + methodChannel = MethodChannel(flutterPluginBinding.binaryMessenger, methodChannelName) + methodChannel.setMethodCallHandler(this) + } + + override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) { + try { + when (call.method) { + "hasPermission" -> hasPermission(result) + "initializeRecorder" -> initializeRecorder(call, result) + "startRecording" -> startRecording(result) + "stopRecording" -> stopRecording(result) + else -> result.notImplemented() + } + } catch (e: Exception) { + Log.e(logTag, "Unexpected exception", e) + // TODO: implement result.error + } + } + + override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) { + methodChannel.setMethodCallHandler(null) + mListener?.onMarkerReached(null) + mListener?.onPeriodicNotification(null) + mListener = null + mRecorder?.stop() + mRecorder?.release() + mRecorder = null + } + + override fun onDetachedFromActivity() { +// currentActivity + } + + override fun onReattachedToActivityForConfigChanges(binding: ActivityPluginBinding) { + currentActivity = binding.activity + binding.addRequestPermissionsResultListener(this) + } + + override fun onAttachedToActivity(binding: ActivityPluginBinding) { + currentActivity = binding.activity + binding.addRequestPermissionsResultListener(this) + } + + override fun onDetachedFromActivityForConfigChanges() { +// currentActivity = null + } + + /** ======== Plugin methods ======== **/ + + private fun hasRecordPermission(): Boolean { + if (permissionToRecordAudio) return true + + val localContext = pluginContext + permissionToRecordAudio = localContext != null && ContextCompat.checkSelfPermission(localContext, + Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED + return permissionToRecordAudio + + } + + private fun hasPermission(result: Result) { + result.success(hasRecordPermission()) + } + + private fun requestRecordPermission() { + val localActivity = currentActivity + if (!hasRecordPermission() && localActivity != null) { + debugLog("requesting RECORD_AUDIO permission") + ActivityCompat.requestPermissions(localActivity, + arrayOf(Manifest.permission.RECORD_AUDIO), audioRecordPermissionCode) + } + } + + override fun onRequestPermissionsResult(requestCode: Int, permissions: Array?, + grantResults: IntArray?): Boolean { + when (requestCode) { + audioRecordPermissionCode -> { + if (grantResults != null) { + permissionToRecordAudio = grantResults.isNotEmpty() && + grantResults[0] == PackageManager.PERMISSION_GRANTED + } + completeInitializeRecorder() + return true + } + } + return false + } + + private fun initializeRecorder(@NonNull call: MethodCall, @NonNull result: Result) { + mRecordSampleRate = call.argument("sampleRate") ?: mRecordSampleRate + debugLogging = call.argument("showLogs") ?: false + mPeriodFrames = AudioRecord.getMinBufferSize(mRecordSampleRate, AudioFormat.CHANNEL_IN_MONO, mRecordFormat) + mRecorderBufferSize = mPeriodFrames * 2 + audioData = ShortArray(mPeriodFrames) + activeResult = result + + val localContext = pluginContext + if (null == localContext) { + completeInitializeRecorder() + return + } + permissionToRecordAudio = ContextCompat.checkSelfPermission(localContext, + Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED + if (!permissionToRecordAudio) { + requestRecordPermission() + } else { + debugLog("has permission, completing") + completeInitializeRecorder() + } + debugLog("leaving initializeIfPermitted") + } + + private fun initRecorder() { + if (mRecorder?.state == AudioRecord.STATE_INITIALIZED) { + return + } + mRecorder = AudioRecord(MediaRecorder.AudioSource.MIC, mRecordSampleRate, AudioFormat.CHANNEL_IN_MONO, mRecordFormat, mRecorderBufferSize) + if (mRecorder != null) { + mListener = createRecordListener() + mRecorder?.positionNotificationPeriod = mPeriodFrames + mRecorder?.setRecordPositionUpdateListener(mListener) + } + } + + private fun completeInitializeRecorder() { + + debugLog("completeInitialize") + val initResult: HashMap = HashMap() + + if (permissionToRecordAudio) { + mRecorder?.release() + initRecorder() + initResult["isMeteringEnabled"] = true + sendRecorderStatus(SoundStreamStatus.Initialized) + } + + initResult["success"] = permissionToRecordAudio + debugLog("sending result") + activeResult?.success(initResult) + debugLog("leaving complete") + activeResult = null + } + + private fun sendEventMethod(name: String, data: Any) { + val eventData: HashMap = HashMap() + eventData["name"] = name + eventData["data"] = data + methodChannel.invokeMethod("platformEvent", eventData) + } + + private fun debugLog(msg: String) { + if (debugLogging) { + Log.d(logTag, msg) + } + } + + private fun startRecording(result: Result) { + try { + if (mRecorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) { + result.success(true) + return + } + initRecorder() + mRecorder!!.startRecording() + sendRecorderStatus(SoundStreamStatus.Playing) + result.success(true) + } catch (e: IllegalStateException) { + debugLog("record() failed") + result.error(SoundStreamErrors.FailedToRecord.name, "Failed to start recording", e.localizedMessage) + } + } + + private fun stopRecording(result: Result) { + try { + if (mRecorder!!.recordingState == AudioRecord.RECORDSTATE_STOPPED) { + result.success(true) + return + } + mRecorder!!.stop() + sendRecorderStatus(SoundStreamStatus.Stopped) + result.success(true) + } catch (e: IllegalStateException) { + debugLog("record() failed") + result.error(SoundStreamErrors.FailedToRecord.name, "Failed to start recording", e.localizedMessage) + } + } + + private fun sendRecorderStatus(status: SoundStreamStatus) { + sendEventMethod("recorderStatus", status.name) + } + + private fun createRecordListener(): OnRecordPositionUpdateListener? { + return object : OnRecordPositionUpdateListener { + override fun onMarkerReached(recorder: AudioRecord) { + recorder.read(audioData!!, 0, mRecorderBufferSize) + } + + override fun onPeriodicNotification(recorder: AudioRecord) { + val data = audioData!! + val shortOut = recorder.read(data, 0, mPeriodFrames) + // https://flutter.io/platform-channels/#codec + // convert short to int because of platform-channel's limitation + val byteBuffer = ByteBuffer.allocate(shortOut * 2) + byteBuffer.order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().put(data) + + sendEventMethod("dataPeriod", byteBuffer.array()) + } + } + } } diff --git a/lib/src/audio/recorder_stream.dart b/lib/src/audio/recorder_stream.dart new file mode 100644 index 0000000..1a9088d --- /dev/null +++ b/lib/src/audio/recorder_stream.dart @@ -0,0 +1,70 @@ +import 'dart:async'; +import 'dart:typed_data'; +import 'sound_stream.dart'; + +class RecorderStream { + static final RecorderStream _instance = RecorderStream._internal(); + factory RecorderStream() => _instance; + + final _audioStreamController = StreamController.broadcast(); + + final _recorderStatusController = + StreamController.broadcast(); + + RecorderStream._internal() { + SoundStream(); + eventsStreamController.stream.listen(_eventListener); + _recorderStatusController.add(SoundStreamStatus.Unset); + _audioStreamController.add(Uint8List(0)); + } + + /// Initialize Recorder with specified [sampleRate] + Future initialize({int sampleRate = 16000, bool showLogs = false}) => + methodChannel.invokeMethod("initializeRecorder", { + "sampleRate": sampleRate, + "showLogs": showLogs, + }); + + /// Start recording. Recorder will start pushing audio chunks (PCM 16bit data) + /// to audiostream as Uint8List + Future start() => + methodChannel.invokeMethod("startRecording"); + + /// Recorder will stop recording and sending audio chunks to the [audioStream]. + Future stop() => + methodChannel.invokeMethod("stopRecording"); + + /// Current status of the [RecorderStream] + Stream get status => _recorderStatusController.stream; + + /// Stream of PCM 16bit data from Microphone + Stream get audioStream => _audioStreamController.stream; + + void _eventListener(dynamic event) { + if (event == null) return; + final String eventName = event["name"] ?? ""; + switch (eventName) { + case "dataPeriod": + final Uint8List audioData = + Uint8List.fromList(event["data"] ?? []); + if (audioData.isNotEmpty) _audioStreamController.add(audioData); + break; + case "recorderStatus": + final String status = event["data"] ?? "Unset"; + _recorderStatusController.add(SoundStreamStatus.values.firstWhere( + (value) => enumToString(value) == status, + orElse: () => SoundStreamStatus.Unset, + )); + break; + } + } + + /// Stop and close all streams. This cannot be undone + /// Only call this method if you don't want to use this anymore + void dispose() { + stop(); + eventsStreamController.close(); + _recorderStatusController.close(); + _audioStreamController.close(); + } +} diff --git a/lib/src/audio/sound_stream.dart b/lib/src/audio/sound_stream.dart new file mode 100644 index 0000000..1b870ed --- /dev/null +++ b/lib/src/audio/sound_stream.dart @@ -0,0 +1,40 @@ +import 'dart:async'; + +import 'package:flutter/services.dart'; +import 'recorder_stream.dart'; + +const methodChannelName = 'com.tfliteflutter.tflite_flutter_helper:methods'; + +const MethodChannel methodChannel = MethodChannel(methodChannelName); + +final eventsStreamController = StreamController.broadcast(); + +enum SoundStreamStatus { + Unset, + Initialized, + Playing, + Stopped, +} + +class SoundStream { + static final SoundStream _instance = SoundStream._internal(); + factory SoundStream() => _instance; + SoundStream._internal() { + methodChannel.setMethodCallHandler(_onMethodCall); + } + + /// Return [RecorderStream] instance (Singleton). + RecorderStream get recorder => RecorderStream(); + + + Future _onMethodCall(MethodCall call) async { + switch (call.method) { + case "platformEvent": + eventsStreamController.add(call.arguments); + break; + } + return null; + } +} + +String enumToString(Object o) => o.toString().split('.').last; diff --git a/lib/src/audio/tensor_audio.dart b/lib/src/audio/tensor_audio.dart new file mode 100644 index 0000000..f197f2f --- /dev/null +++ b/lib/src/audio/tensor_audio.dart @@ -0,0 +1,173 @@ +import 'dart:math'; +import 'dart:typed_data'; + +import 'package:quiver/check.dart'; +import 'package:tflite_flutter/tflite_flutter.dart'; +import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; + +class TensorAudio { + final String TAG = "TensorAudioDart"; + late final FloatRingBuffer buffer; + late final TensorAudioFormat format; + + TensorAudio._(this.format, int sampleCount){ + this.buffer = FloatRingBuffer._(sampleCount * format.channelCount); + } + + static TensorAudio create(TensorAudioFormat format, int sampleCount) { + return TensorAudio._(format, sampleCount); + } + + void loadDoubleList(List src) { + loadDoubleListOffset(src, 0, src.length); + } + + void loadDoubleListOffset(List src, int offsetInFloat, + int sizeInFloat) { + checkArgument( + sizeInFloat % format.channelCount == 0, + message: + "Size ($sizeInFloat) needs to be a multiplier of the number of channels (${format + .channelCount})", + ); + buffer.loadOffset(src, offsetInFloat, sizeInFloat); + } + + void loadShortBytes(Uint8List shortBytes) { + ByteData byteData = ByteData.sublistView(shortBytes); + List shortList = []; + for (int i = 0; i < byteData.lengthInBytes; i += 2) { + shortList.add(byteData.getInt16(i, Endian.little)); + } + loadList(shortList); + } + + void loadFloatBytes(Uint8List floatBytes) { + ByteData byteData = ByteData.sublistView(floatBytes); + List doubleList = []; + for (int i = 0; i < byteData.lengthInBytes; i += 4) { + doubleList.add(byteData.getFloat32(i, Endian.little)); + } + loadDoubleList(doubleList); + } + + void loadList(List src) { + loadListOffset(src, 0, src.length); + } + + void loadListOffset(List src, int offsetInShort, int sizeInShort) { + checkArgument( + offsetInShort + sizeInShort <= src.length, + message: "Index out of range. offset ($offsetInShort) + size ($sizeInShort) should <= newData.length (${src + .length})"); + List floatData = List.filled(sizeInShort, 0.0); + for (int i = offsetInShort; i < sizeInShort; i++) { + // Convert the data to PCM Float encoding i.e. values between -1 and 1 + floatData[i] = src[i] / (pow(2, 15) - 1); + } + loadDoubleList(floatData); + } + + + /// Returns a float {@link TensorBuffer} holding all the available audio samples in {@link + /// android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1]. + TensorBuffer get tensorBuffer { + ByteBuffer byteBuffer = buffer.buffer; + TensorBuffer tensorBuffer = + // TODO: Confirm Shape + TensorBuffer.createFixedSize( + [1, byteBuffer + .asFloat32List() + .length + ], + TfLiteType.float32); + tensorBuffer.loadBuffer(byteBuffer); + return tensorBuffer; + } + + /// Returns the {@link TensorAudioFormat} associated with the tensor. + // TODO: Rename + TensorAudioFormat get gformat { + return format; + } +} + +// TODO: Update documentation according to flutter +/// Wraps a few constants describing the format of the incoming audio samples, namely number of +/// channels and the sample rate. By default, channels is set to 1. +class TensorAudioFormat { + static const int DEFAULT_CHANNELS = 1; + late final int _channelCount; + late final int _sampleRate; + + TensorAudioFormat._(this._channelCount, this._sampleRate); + + static TensorAudioFormat create(int channelCount, int sampleRate) { + checkArgument(channelCount > 0, + message: "Number of channels should be greater than 0"); + checkArgument( + sampleRate > 0, message: "Sample rate should be greater than 0"); + return TensorAudioFormat._(channelCount, sampleRate); + } + + int get channelCount => _channelCount; + + int get sampleRate => _sampleRate; +} + +/// Actual implementation of the ring buffer. */ +class FloatRingBuffer { + late final List _buffer; + int _nextIndex = 0; + + FloatRingBuffer._(int flatSize) { + _buffer = List.filled(flatSize, 0.0); + } + + /// Loads the entire float array to the ring buffer. If the float array is longer than ring + /// buffer's capacity, samples with lower indicies in the array will be ignored. + void load(List newData) { + loadOffset(newData, 0, newData.length); + } + + /// Loads a slice of the float array to the ring buffer. If the float array is longer than ring + /// buffer's capacity, samples with lower indicies in the array will be ignored. + void loadOffset(List newData, int offset, int size) { + checkArgument( + offset + size <= newData.length, + message: + "Index out of range. offset ($offset) + size ($size) should <= newData.length (${newData + .length})", + ); + // If buffer can't hold all the data, only keep the most recent data of size buffer.length + if (size > _buffer.length) { + offset = size - _buffer.length; + size = _buffer.length; + } + if (_nextIndex + size < _buffer.length) { + // No need to wrap nextIndex, just copy newData[offset:offset + size] + // to buffer[nextIndex:nextIndex+size] + List.copyRange(_buffer, _nextIndex, newData, offset, offset + size); + } else { + // Need to wrap nextIndex, perform copy in two chunks. + int firstChunkSize = _buffer.length - _nextIndex; + // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length] + List.copyRange( + _buffer, _nextIndex, newData, offset, offset + firstChunkSize); + // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize] + List.copyRange( + _buffer, 0, newData, offset + firstChunkSize, offset + size); + } + + _nextIndex = (_nextIndex + size) % _buffer.length; + } + + ByteBuffer get buffer { + // TODO: Make sure there is no endianness issue + return Float32List + .fromList(_buffer) + .buffer; + } + + int get capacity => _buffer.length; +} diff --git a/lib/tflite_flutter_helper.dart b/lib/tflite_flutter_helper.dart index c97f7b6..f4a5300 100644 --- a/lib/tflite_flutter_helper.dart +++ b/lib/tflite_flutter_helper.dart @@ -36,3 +36,6 @@ export 'src/label/tensor_label.dart'; export 'src/tensorbuffer/tensorbuffer.dart'; export 'src/tensorbuffer/tensorbufferfloat.dart'; export 'src/tensorbuffer/tensorbufferuint8.dart'; +export 'src/audio/recorder_stream.dart'; +export 'src/audio/sound_stream.dart'; +export 'src/audio/tensor_audio.dart'; diff --git a/pubspec.lock b/pubspec.lock index abb57da..8b748ce 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -239,7 +239,7 @@ packages: name: tflite_flutter url: "https://pub.dartlang.org" source: hosted - version: "0.8.0" + version: "0.9.0" tuple: dependency: "direct main" description: