diff --git a/android/build.gradle b/android/build.gradle
index ae39ca5..165dd0b 100644
--- a/android/build.gradle
+++ b/android/build.gradle
@@ -1,5 +1,3 @@
-package android
-
group 'com.tfliteflutter.tflite_flutter_helper'
version '1.0-SNAPSHOT'
diff --git a/android/src/main/AndroidManifest.xml b/android/src/main/AndroidManifest.xml
index fd29b5d..90f3d5b 100644
--- a/android/src/main/AndroidManifest.xml
+++ b/android/src/main/AndroidManifest.xml
@@ -1,3 +1,4 @@
+
diff --git a/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt b/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt
index 62ae201..2e751d5 100644
--- a/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt
+++ b/android/src/main/kotlin/com/tfliteflutter/tflite_flutter_helper/TfliteFlutterHelperPlugin.kt
@@ -1,35 +1,278 @@
package com.tfliteflutter.tflite_flutter_helper
-import androidx.annotation.NonNull
-
import io.flutter.embedding.engine.plugins.FlutterPlugin
import io.flutter.plugin.common.MethodCall
import io.flutter.plugin.common.MethodChannel
import io.flutter.plugin.common.MethodChannel.MethodCallHandler
import io.flutter.plugin.common.MethodChannel.Result
+import android.Manifest
+import android.app.Activity
+import android.content.Context
+import android.content.pm.PackageManager
+import android.media.*
+import android.media.AudioRecord.OnRecordPositionUpdateListener
+import android.util.Log
+import androidx.annotation.NonNull
+import androidx.core.app.ActivityCompat
+import androidx.core.content.ContextCompat
+import io.flutter.embedding.engine.plugins.activity.ActivityAware
+import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding
+import io.flutter.plugin.common.PluginRegistry
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import java.nio.ShortBuffer
+
+enum class SoundStreamErrors {
+ FailedToRecord,
+ FailedToPlay,
+ FailedToStop,
+ FailedToWriteBuffer,
+ Unknown,
+}
+
+enum class SoundStreamStatus {
+ Unset,
+ Initialized,
+ Playing,
+ Stopped,
+}
+
+const val methodChannelName = "com.tfliteflutter.tflite_flutter_helper:methods"
/** TfliteFlutterHelperPlugin */
-class TfliteFlutterHelperPlugin: FlutterPlugin, MethodCallHandler {
- /// The MethodChannel that will the communication between Flutter and native Android
- ///
- /// This local reference serves to register the plugin with the Flutter Engine and unregister it
- /// when the Flutter Engine is detached from the Activity
- private lateinit var channel : MethodChannel
-
- override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) {
- channel = MethodChannel(flutterPluginBinding.binaryMessenger, "tflite_flutter_helper")
- channel.setMethodCallHandler(this)
- }
-
- override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) {
- if (call.method == "getPlatformVersion") {
- result.success("Android ${android.os.Build.VERSION.RELEASE}")
- } else {
- result.notImplemented()
- }
- }
-
- override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) {
- channel.setMethodCallHandler(null)
- }
+class TfliteFlutterHelperPlugin : FlutterPlugin,
+ MethodCallHandler,
+ PluginRegistry.RequestPermissionsResultListener,
+ ActivityAware {
+
+ private val logTag = "TfLiteFlutterHelperPlugin"
+ private val audioRecordPermissionCode = 14887
+
+ private lateinit var methodChannel: MethodChannel
+ private var currentActivity: Activity? = null
+ private var pluginContext: Context? = null
+ private var permissionToRecordAudio: Boolean = false
+ private var activeResult: Result? = null
+ private var debugLogging: Boolean = false
+
+ //========= Recorder's vars
+ private val mRecordFormat = AudioFormat.ENCODING_PCM_16BIT
+ private var mRecordSampleRate = 16000 // 16Khz
+ private var mRecorderBufferSize = 8192
+ private var mPeriodFrames = 8192
+ private var audioData: ShortArray? = null
+ private var mRecorder: AudioRecord? = null
+ private var mListener: OnRecordPositionUpdateListener? = null
+
+
+ override fun onAttachedToEngine(@NonNull flutterPluginBinding: FlutterPlugin.FlutterPluginBinding) {
+ pluginContext = flutterPluginBinding.applicationContext
+ methodChannel = MethodChannel(flutterPluginBinding.binaryMessenger, methodChannelName)
+ methodChannel.setMethodCallHandler(this)
+ }
+
+ override fun onMethodCall(@NonNull call: MethodCall, @NonNull result: Result) {
+ try {
+ when (call.method) {
+ "hasPermission" -> hasPermission(result)
+ "initializeRecorder" -> initializeRecorder(call, result)
+ "startRecording" -> startRecording(result)
+ "stopRecording" -> stopRecording(result)
+ else -> result.notImplemented()
+ }
+ } catch (e: Exception) {
+ Log.e(logTag, "Unexpected exception", e)
+ // TODO: implement result.error
+ }
+ }
+
+ override fun onDetachedFromEngine(@NonNull binding: FlutterPlugin.FlutterPluginBinding) {
+ methodChannel.setMethodCallHandler(null)
+ mListener?.onMarkerReached(null)
+ mListener?.onPeriodicNotification(null)
+ mListener = null
+ mRecorder?.stop()
+ mRecorder?.release()
+ mRecorder = null
+ }
+
+ override fun onDetachedFromActivity() {
+// currentActivity
+ }
+
+ override fun onReattachedToActivityForConfigChanges(binding: ActivityPluginBinding) {
+ currentActivity = binding.activity
+ binding.addRequestPermissionsResultListener(this)
+ }
+
+ override fun onAttachedToActivity(binding: ActivityPluginBinding) {
+ currentActivity = binding.activity
+ binding.addRequestPermissionsResultListener(this)
+ }
+
+ override fun onDetachedFromActivityForConfigChanges() {
+// currentActivity = null
+ }
+
+ /** ======== Plugin methods ======== **/
+
+ private fun hasRecordPermission(): Boolean {
+ if (permissionToRecordAudio) return true
+
+ val localContext = pluginContext
+ permissionToRecordAudio = localContext != null && ContextCompat.checkSelfPermission(localContext,
+ Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
+ return permissionToRecordAudio
+
+ }
+
+ private fun hasPermission(result: Result) {
+ result.success(hasRecordPermission())
+ }
+
+ private fun requestRecordPermission() {
+ val localActivity = currentActivity
+ if (!hasRecordPermission() && localActivity != null) {
+ debugLog("requesting RECORD_AUDIO permission")
+ ActivityCompat.requestPermissions(localActivity,
+ arrayOf(Manifest.permission.RECORD_AUDIO), audioRecordPermissionCode)
+ }
+ }
+
+ override fun onRequestPermissionsResult(requestCode: Int, permissions: Array?,
+ grantResults: IntArray?): Boolean {
+ when (requestCode) {
+ audioRecordPermissionCode -> {
+ if (grantResults != null) {
+ permissionToRecordAudio = grantResults.isNotEmpty() &&
+ grantResults[0] == PackageManager.PERMISSION_GRANTED
+ }
+ completeInitializeRecorder()
+ return true
+ }
+ }
+ return false
+ }
+
+ private fun initializeRecorder(@NonNull call: MethodCall, @NonNull result: Result) {
+ mRecordSampleRate = call.argument("sampleRate") ?: mRecordSampleRate
+ debugLogging = call.argument("showLogs") ?: false
+ mPeriodFrames = AudioRecord.getMinBufferSize(mRecordSampleRate, AudioFormat.CHANNEL_IN_MONO, mRecordFormat)
+ mRecorderBufferSize = mPeriodFrames * 2
+ audioData = ShortArray(mPeriodFrames)
+ activeResult = result
+
+ val localContext = pluginContext
+ if (null == localContext) {
+ completeInitializeRecorder()
+ return
+ }
+ permissionToRecordAudio = ContextCompat.checkSelfPermission(localContext,
+ Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
+ if (!permissionToRecordAudio) {
+ requestRecordPermission()
+ } else {
+ debugLog("has permission, completing")
+ completeInitializeRecorder()
+ }
+ debugLog("leaving initializeIfPermitted")
+ }
+
+ private fun initRecorder() {
+ if (mRecorder?.state == AudioRecord.STATE_INITIALIZED) {
+ return
+ }
+ mRecorder = AudioRecord(MediaRecorder.AudioSource.MIC, mRecordSampleRate, AudioFormat.CHANNEL_IN_MONO, mRecordFormat, mRecorderBufferSize)
+ if (mRecorder != null) {
+ mListener = createRecordListener()
+ mRecorder?.positionNotificationPeriod = mPeriodFrames
+ mRecorder?.setRecordPositionUpdateListener(mListener)
+ }
+ }
+
+ private fun completeInitializeRecorder() {
+
+ debugLog("completeInitialize")
+ val initResult: HashMap = HashMap()
+
+ if (permissionToRecordAudio) {
+ mRecorder?.release()
+ initRecorder()
+ initResult["isMeteringEnabled"] = true
+ sendRecorderStatus(SoundStreamStatus.Initialized)
+ }
+
+ initResult["success"] = permissionToRecordAudio
+ debugLog("sending result")
+ activeResult?.success(initResult)
+ debugLog("leaving complete")
+ activeResult = null
+ }
+
+ private fun sendEventMethod(name: String, data: Any) {
+ val eventData: HashMap = HashMap()
+ eventData["name"] = name
+ eventData["data"] = data
+ methodChannel.invokeMethod("platformEvent", eventData)
+ }
+
+ private fun debugLog(msg: String) {
+ if (debugLogging) {
+ Log.d(logTag, msg)
+ }
+ }
+
+ private fun startRecording(result: Result) {
+ try {
+ if (mRecorder!!.recordingState == AudioRecord.RECORDSTATE_RECORDING) {
+ result.success(true)
+ return
+ }
+ initRecorder()
+ mRecorder!!.startRecording()
+ sendRecorderStatus(SoundStreamStatus.Playing)
+ result.success(true)
+ } catch (e: IllegalStateException) {
+ debugLog("record() failed")
+ result.error(SoundStreamErrors.FailedToRecord.name, "Failed to start recording", e.localizedMessage)
+ }
+ }
+
+ private fun stopRecording(result: Result) {
+ try {
+ if (mRecorder!!.recordingState == AudioRecord.RECORDSTATE_STOPPED) {
+ result.success(true)
+ return
+ }
+ mRecorder!!.stop()
+ sendRecorderStatus(SoundStreamStatus.Stopped)
+ result.success(true)
+ } catch (e: IllegalStateException) {
+ debugLog("record() failed")
+ result.error(SoundStreamErrors.FailedToRecord.name, "Failed to start recording", e.localizedMessage)
+ }
+ }
+
+ private fun sendRecorderStatus(status: SoundStreamStatus) {
+ sendEventMethod("recorderStatus", status.name)
+ }
+
+ private fun createRecordListener(): OnRecordPositionUpdateListener? {
+ return object : OnRecordPositionUpdateListener {
+ override fun onMarkerReached(recorder: AudioRecord) {
+ recorder.read(audioData!!, 0, mRecorderBufferSize)
+ }
+
+ override fun onPeriodicNotification(recorder: AudioRecord) {
+ val data = audioData!!
+ val shortOut = recorder.read(data, 0, mPeriodFrames)
+ // https://flutter.io/platform-channels/#codec
+ // convert short to int because of platform-channel's limitation
+ val byteBuffer = ByteBuffer.allocate(shortOut * 2)
+ byteBuffer.order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().put(data)
+
+ sendEventMethod("dataPeriod", byteBuffer.array())
+ }
+ }
+ }
}
diff --git a/lib/src/audio/recorder_stream.dart b/lib/src/audio/recorder_stream.dart
new file mode 100644
index 0000000..1a9088d
--- /dev/null
+++ b/lib/src/audio/recorder_stream.dart
@@ -0,0 +1,70 @@
+import 'dart:async';
+import 'dart:typed_data';
+import 'sound_stream.dart';
+
+class RecorderStream {
+ static final RecorderStream _instance = RecorderStream._internal();
+ factory RecorderStream() => _instance;
+
+ final _audioStreamController = StreamController.broadcast();
+
+ final _recorderStatusController =
+ StreamController.broadcast();
+
+ RecorderStream._internal() {
+ SoundStream();
+ eventsStreamController.stream.listen(_eventListener);
+ _recorderStatusController.add(SoundStreamStatus.Unset);
+ _audioStreamController.add(Uint8List(0));
+ }
+
+ /// Initialize Recorder with specified [sampleRate]
+ Future initialize({int sampleRate = 16000, bool showLogs = false}) =>
+ methodChannel.invokeMethod("initializeRecorder", {
+ "sampleRate": sampleRate,
+ "showLogs": showLogs,
+ });
+
+ /// Start recording. Recorder will start pushing audio chunks (PCM 16bit data)
+ /// to audiostream as Uint8List
+ Future start() =>
+ methodChannel.invokeMethod("startRecording");
+
+ /// Recorder will stop recording and sending audio chunks to the [audioStream].
+ Future stop() =>
+ methodChannel.invokeMethod("stopRecording");
+
+ /// Current status of the [RecorderStream]
+ Stream get status => _recorderStatusController.stream;
+
+ /// Stream of PCM 16bit data from Microphone
+ Stream get audioStream => _audioStreamController.stream;
+
+ void _eventListener(dynamic event) {
+ if (event == null) return;
+ final String eventName = event["name"] ?? "";
+ switch (eventName) {
+ case "dataPeriod":
+ final Uint8List audioData =
+ Uint8List.fromList(event["data"] ?? []);
+ if (audioData.isNotEmpty) _audioStreamController.add(audioData);
+ break;
+ case "recorderStatus":
+ final String status = event["data"] ?? "Unset";
+ _recorderStatusController.add(SoundStreamStatus.values.firstWhere(
+ (value) => enumToString(value) == status,
+ orElse: () => SoundStreamStatus.Unset,
+ ));
+ break;
+ }
+ }
+
+ /// Stop and close all streams. This cannot be undone
+ /// Only call this method if you don't want to use this anymore
+ void dispose() {
+ stop();
+ eventsStreamController.close();
+ _recorderStatusController.close();
+ _audioStreamController.close();
+ }
+}
diff --git a/lib/src/audio/sound_stream.dart b/lib/src/audio/sound_stream.dart
new file mode 100644
index 0000000..1b870ed
--- /dev/null
+++ b/lib/src/audio/sound_stream.dart
@@ -0,0 +1,40 @@
+import 'dart:async';
+
+import 'package:flutter/services.dart';
+import 'recorder_stream.dart';
+
+const methodChannelName = 'com.tfliteflutter.tflite_flutter_helper:methods';
+
+const MethodChannel methodChannel = MethodChannel(methodChannelName);
+
+final eventsStreamController = StreamController.broadcast();
+
+enum SoundStreamStatus {
+ Unset,
+ Initialized,
+ Playing,
+ Stopped,
+}
+
+class SoundStream {
+ static final SoundStream _instance = SoundStream._internal();
+ factory SoundStream() => _instance;
+ SoundStream._internal() {
+ methodChannel.setMethodCallHandler(_onMethodCall);
+ }
+
+ /// Return [RecorderStream] instance (Singleton).
+ RecorderStream get recorder => RecorderStream();
+
+
+ Future _onMethodCall(MethodCall call) async {
+ switch (call.method) {
+ case "platformEvent":
+ eventsStreamController.add(call.arguments);
+ break;
+ }
+ return null;
+ }
+}
+
+String enumToString(Object o) => o.toString().split('.').last;
diff --git a/lib/src/audio/tensor_audio.dart b/lib/src/audio/tensor_audio.dart
new file mode 100644
index 0000000..f197f2f
--- /dev/null
+++ b/lib/src/audio/tensor_audio.dart
@@ -0,0 +1,173 @@
+import 'dart:math';
+import 'dart:typed_data';
+
+import 'package:quiver/check.dart';
+import 'package:tflite_flutter/tflite_flutter.dart';
+import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
+
+class TensorAudio {
+ final String TAG = "TensorAudioDart";
+ late final FloatRingBuffer buffer;
+ late final TensorAudioFormat format;
+
+ TensorAudio._(this.format, int sampleCount){
+ this.buffer = FloatRingBuffer._(sampleCount * format.channelCount);
+ }
+
+ static TensorAudio create(TensorAudioFormat format, int sampleCount) {
+ return TensorAudio._(format, sampleCount);
+ }
+
+ void loadDoubleList(List src) {
+ loadDoubleListOffset(src, 0, src.length);
+ }
+
+ void loadDoubleListOffset(List src, int offsetInFloat,
+ int sizeInFloat) {
+ checkArgument(
+ sizeInFloat % format.channelCount == 0,
+ message:
+ "Size ($sizeInFloat) needs to be a multiplier of the number of channels (${format
+ .channelCount})",
+ );
+ buffer.loadOffset(src, offsetInFloat, sizeInFloat);
+ }
+
+ void loadShortBytes(Uint8List shortBytes) {
+ ByteData byteData = ByteData.sublistView(shortBytes);
+ List shortList = [];
+ for (int i = 0; i < byteData.lengthInBytes; i += 2) {
+ shortList.add(byteData.getInt16(i, Endian.little));
+ }
+ loadList(shortList);
+ }
+
+ void loadFloatBytes(Uint8List floatBytes) {
+ ByteData byteData = ByteData.sublistView(floatBytes);
+ List doubleList = [];
+ for (int i = 0; i < byteData.lengthInBytes; i += 4) {
+ doubleList.add(byteData.getFloat32(i, Endian.little));
+ }
+ loadDoubleList(doubleList);
+ }
+
+ void loadList(List src) {
+ loadListOffset(src, 0, src.length);
+ }
+
+ void loadListOffset(List src, int offsetInShort, int sizeInShort) {
+ checkArgument(
+ offsetInShort + sizeInShort <= src.length,
+ message: "Index out of range. offset ($offsetInShort) + size ($sizeInShort) should <= newData.length (${src
+ .length})");
+ List floatData = List.filled(sizeInShort, 0.0);
+ for (int i = offsetInShort; i < sizeInShort; i++) {
+ // Convert the data to PCM Float encoding i.e. values between -1 and 1
+ floatData[i] = src[i] / (pow(2, 15) - 1);
+ }
+ loadDoubleList(floatData);
+ }
+
+
+ /// Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
+ /// android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
+ TensorBuffer get tensorBuffer {
+ ByteBuffer byteBuffer = buffer.buffer;
+ TensorBuffer tensorBuffer =
+ // TODO: Confirm Shape
+ TensorBuffer.createFixedSize(
+ [1, byteBuffer
+ .asFloat32List()
+ .length
+ ],
+ TfLiteType.float32);
+ tensorBuffer.loadBuffer(byteBuffer);
+ return tensorBuffer;
+ }
+
+ /// Returns the {@link TensorAudioFormat} associated with the tensor.
+ // TODO: Rename
+ TensorAudioFormat get gformat {
+ return format;
+ }
+}
+
+// TODO: Update documentation according to flutter
+/// Wraps a few constants describing the format of the incoming audio samples, namely number of
+/// channels and the sample rate. By default, channels is set to 1.
+class TensorAudioFormat {
+ static const int DEFAULT_CHANNELS = 1;
+ late final int _channelCount;
+ late final int _sampleRate;
+
+ TensorAudioFormat._(this._channelCount, this._sampleRate);
+
+ static TensorAudioFormat create(int channelCount, int sampleRate) {
+ checkArgument(channelCount > 0,
+ message: "Number of channels should be greater than 0");
+ checkArgument(
+ sampleRate > 0, message: "Sample rate should be greater than 0");
+ return TensorAudioFormat._(channelCount, sampleRate);
+ }
+
+ int get channelCount => _channelCount;
+
+ int get sampleRate => _sampleRate;
+}
+
+/// Actual implementation of the ring buffer. */
+class FloatRingBuffer {
+ late final List _buffer;
+ int _nextIndex = 0;
+
+ FloatRingBuffer._(int flatSize) {
+ _buffer = List.filled(flatSize, 0.0);
+ }
+
+ /// Loads the entire float array to the ring buffer. If the float array is longer than ring
+ /// buffer's capacity, samples with lower indicies in the array will be ignored.
+ void load(List newData) {
+ loadOffset(newData, 0, newData.length);
+ }
+
+ /// Loads a slice of the float array to the ring buffer. If the float array is longer than ring
+ /// buffer's capacity, samples with lower indicies in the array will be ignored.
+ void loadOffset(List newData, int offset, int size) {
+ checkArgument(
+ offset + size <= newData.length,
+ message:
+ "Index out of range. offset ($offset) + size ($size) should <= newData.length (${newData
+ .length})",
+ );
+ // If buffer can't hold all the data, only keep the most recent data of size buffer.length
+ if (size > _buffer.length) {
+ offset = size - _buffer.length;
+ size = _buffer.length;
+ }
+ if (_nextIndex + size < _buffer.length) {
+ // No need to wrap nextIndex, just copy newData[offset:offset + size]
+ // to buffer[nextIndex:nextIndex+size]
+ List.copyRange(_buffer, _nextIndex, newData, offset, offset + size);
+ } else {
+ // Need to wrap nextIndex, perform copy in two chunks.
+ int firstChunkSize = _buffer.length - _nextIndex;
+ // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
+ List.copyRange(
+ _buffer, _nextIndex, newData, offset, offset + firstChunkSize);
+ // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
+ List.copyRange(
+ _buffer, 0, newData, offset + firstChunkSize, offset + size);
+ }
+
+ _nextIndex = (_nextIndex + size) % _buffer.length;
+ }
+
+ ByteBuffer get buffer {
+ // TODO: Make sure there is no endianness issue
+ return Float32List
+ .fromList(_buffer)
+ .buffer;
+ }
+
+ int get capacity => _buffer.length;
+}
diff --git a/lib/tflite_flutter_helper.dart b/lib/tflite_flutter_helper.dart
index c97f7b6..f4a5300 100644
--- a/lib/tflite_flutter_helper.dart
+++ b/lib/tflite_flutter_helper.dart
@@ -36,3 +36,6 @@ export 'src/label/tensor_label.dart';
export 'src/tensorbuffer/tensorbuffer.dart';
export 'src/tensorbuffer/tensorbufferfloat.dart';
export 'src/tensorbuffer/tensorbufferuint8.dart';
+export 'src/audio/recorder_stream.dart';
+export 'src/audio/sound_stream.dart';
+export 'src/audio/tensor_audio.dart';
diff --git a/pubspec.lock b/pubspec.lock
index abb57da..8b748ce 100644
--- a/pubspec.lock
+++ b/pubspec.lock
@@ -239,7 +239,7 @@ packages:
name: tflite_flutter
url: "https://pub.dartlang.org"
source: hosted
- version: "0.8.0"
+ version: "0.9.0"
tuple:
dependency: "direct main"
description: