diff --git a/example/audio_classification/android/build.gradle b/example/audio_classification/android/build.gradle index 9b6ed06..3204979 100644 --- a/example/audio_classification/android/build.gradle +++ b/example/audio_classification/android/build.gradle @@ -6,7 +6,7 @@ buildscript { } dependencies { - classpath 'com.android.tools.build:gradle:4.1.0' + classpath 'com.android.tools.build:gradle:4.2.0' classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" } } diff --git a/example/audio_classification/android/gradle/wrapper/gradle-wrapper.properties b/example/audio_classification/android/gradle/wrapper/gradle-wrapper.properties index bc6a58a..939efa2 100644 --- a/example/audio_classification/android/gradle/wrapper/gradle-wrapper.properties +++ b/example/audio_classification/android/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.7.1-all.zip diff --git a/example/audio_classification/lib/classifier.dart b/example/audio_classification/lib/classifier.dart index ce705dd..b00b147 100644 --- a/example/audio_classification/lib/classifier.dart +++ b/example/audio_classification/lib/classifier.dart @@ -1,5 +1,6 @@ import 'dart:typed_data'; +import 'package:audio_classification/main.dart'; import 'package:flutter/services.dart'; import 'package:collection/collection.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; @@ -56,7 +57,8 @@ class Classifier { Category predict(List audioSample) { final pres = DateTime.now().millisecondsSinceEpoch; Uint8List bytes = Uint8List.fromList(audioSample); - TensorAudio tensorAudio = TensorAudio.create(TensorAudioFormat.create(1, _inputShape[0]), _inputShape[0]); + TensorAudio tensorAudio = TensorAudio.create( + TensorAudioFormat.create(1, sampleRate), _inputShape[0]); tensorAudio.loadShortBytes(bytes); final pre = DateTime.now().millisecondsSinceEpoch - pres; print('Time to load audio tensor: $pre ms'); @@ -65,18 +67,19 @@ class Classifier { print(tensorAudio.tensorBuffer.getShape()); final runs = DateTime.now().millisecondsSinceEpoch; - interpreter.run(tensorAudio.tensorBuffer.getBuffer(), _outputBuffer.getBuffer()); + interpreter.run( + tensorAudio.tensorBuffer.getBuffer(), _outputBuffer.getBuffer()); final run = DateTime.now().millisecondsSinceEpoch - runs; print(_outputBuffer.getDoubleList()); Map labeledProb = {}; - for(int i = 0; i < _outputBuffer.getDoubleList().length; i++) { + for (int i = 0; i < _outputBuffer.getDoubleList().length; i++) { labeledProb[labels[i]!] = _outputBuffer.getDoubleValue(i); } final top = getTopProbability(labeledProb); print(top); print('Time to run inference: $run ms'); - return Category(top.key, top.value); + return top.first; } void close() { @@ -84,11 +87,15 @@ class Classifier { } } -MapEntry getTopProbability(Map labeledProb) { +List getTopProbability(Map labeledProb) { var pq = PriorityQueue>(compare); pq.addAll(labeledProb.entries); - - return pq.first; + var result = []; + for (int i = 0; i < 5; i++) { + result.add(Category(pq.first.key, pq.first.value)); + pq.removeFirst(); + } + return result; } int compare(MapEntry e1, MapEntry e2) { diff --git a/example/audio_classification/lib/main.dart b/example/audio_classification/lib/main.dart index 532feb0..2fb9596 100644 --- a/example/audio_classification/lib/main.dart +++ b/example/audio_classification/lib/main.dart @@ -73,8 +73,7 @@ class _MyAppState extends State { await session.configure(AudioSessionConfiguration.speech()); // Listen to errors during playback. _player.playbackEventStream.listen( - (event) { - }, + (event) {}, onError: (Object e, StackTrace stackTrace) { print('A stream error occurred: $e'); }, @@ -99,7 +98,11 @@ class _MyAppState extends State { return MaterialApp( home: Scaffold( appBar: AppBar( - title: const Text('Audio Classification'), + backgroundColor: Colors.orange, + title: const Text( + 'TFL Audio Classification', + style: TextStyle(color: Colors.white), + ), ), body: Column( mainAxisAlignment: MainAxisAlignment.center, @@ -107,16 +110,25 @@ class _MyAppState extends State { Row( mainAxisAlignment: MainAxisAlignment.spaceAround, children: [ - IconButton( - iconSize: 64.0, - icon: Icon(_isRecording ? Icons.mic_off : Icons.mic), - onPressed: _isRecording ? () { - _recorder.stop(); - _loadMicChunks(); - } : () { - _micChunks = []; - _recorder.start(); - }, + Column( + children: [ + IconButton( + iconSize: 64.0, + icon: Icon(_isRecording ? Icons.mic_off : Icons.mic), + onPressed: _isRecording + ? () { + _recorder.stop(); + _loadMicChunks(); + } + : () { + _micChunks = []; + _recorder.start(); + }, + ), + Text( + 'Record', + ), + ], ), StreamBuilder( stream: _player.playerStateStream, @@ -166,17 +178,20 @@ class _MyAppState extends State { child: Text("Predict"), ), ), - if(prediction != null) - Padding( - padding: const EdgeInsets.all(32.0), - child: Column( - children: [ - Text(prediction!.label, style: TextStyle(fontWeight: FontWeight.bold),), - SizedBox(height: 4), - Text(prediction!.score.toString()), - ], + if (prediction != null) + Padding( + padding: const EdgeInsets.all(32.0), + child: Column( + children: [ + Text( + prediction!.label, + style: TextStyle(fontWeight: FontWeight.bold), + ), + SizedBox(height: 4), + Text(prediction!.score.toString()), + ], + ), ), - ), ], ), ), diff --git a/example/audio_classification/pubspec.lock b/example/audio_classification/pubspec.lock index 7d5cdcc..ae3b3b2 100644 --- a/example/audio_classification/pubspec.lock +++ b/example/audio_classification/pubspec.lock @@ -290,9 +290,9 @@ packages: tflite_flutter: dependency: transitive description: - name: tflite_flutter - url: "https://pub.dartlang.org" - source: hosted + path: "../../../tflite_flutter_plugin" + relative: true + source: path version: "0.9.0" tflite_flutter_helper: dependency: "direct main"