Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
am15h committed Aug 19, 2021
1 parent 3fa5ecd commit 1403d94
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 35 deletions.
2 changes: 1 addition & 1 deletion example/audio_classification/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ buildscript {
}

dependencies {
classpath 'com.android.tools.build:gradle:4.1.0'
classpath 'com.android.tools.build:gradle:4.2.0'
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.7-all.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-6.7.1-all.zip
21 changes: 14 additions & 7 deletions example/audio_classification/lib/classifier.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import 'dart:typed_data';

import 'package:audio_classification/main.dart';
import 'package:flutter/services.dart';
import 'package:collection/collection.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
Expand Down Expand Up @@ -56,7 +57,8 @@ class Classifier {
Category predict(List<int> audioSample) {
final pres = DateTime.now().millisecondsSinceEpoch;
Uint8List bytes = Uint8List.fromList(audioSample);
TensorAudio tensorAudio = TensorAudio.create(TensorAudioFormat.create(1, _inputShape[0]), _inputShape[0]);
TensorAudio tensorAudio = TensorAudio.create(
TensorAudioFormat.create(1, sampleRate), _inputShape[0]);
tensorAudio.loadShortBytes(bytes);
final pre = DateTime.now().millisecondsSinceEpoch - pres;
print('Time to load audio tensor: $pre ms');
Expand All @@ -65,30 +67,35 @@ class Classifier {
print(tensorAudio.tensorBuffer.getShape());

final runs = DateTime.now().millisecondsSinceEpoch;
interpreter.run(tensorAudio.tensorBuffer.getBuffer(), _outputBuffer.getBuffer());
interpreter.run(
tensorAudio.tensorBuffer.getBuffer(), _outputBuffer.getBuffer());
final run = DateTime.now().millisecondsSinceEpoch - runs;

print(_outputBuffer.getDoubleList());
Map<String, double> labeledProb = {};
for(int i = 0; i < _outputBuffer.getDoubleList().length; i++) {
for (int i = 0; i < _outputBuffer.getDoubleList().length; i++) {
labeledProb[labels[i]!] = _outputBuffer.getDoubleValue(i);
}
final top = getTopProbability(labeledProb);
print(top);
print('Time to run inference: $run ms');
return Category(top.key, top.value);
return top.first;
}

void close() {
interpreter.close();
}
}

MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
List<Category> getTopProbability(Map<String, double> labeledProb) {
var pq = PriorityQueue<MapEntry<String, double>>(compare);
pq.addAll(labeledProb.entries);

return pq.first;
var result = <Category>[];
for (int i = 0; i < 5; i++) {
result.add(Category(pq.first.key, pq.first.value));
pq.removeFirst();
}
return result;
}

int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
Expand Down
61 changes: 38 additions & 23 deletions example/audio_classification/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ class _MyAppState extends State<MyApp> {
await session.configure(AudioSessionConfiguration.speech());
// Listen to errors during playback.
_player.playbackEventStream.listen(
(event) {
},
(event) {},
onError: (Object e, StackTrace stackTrace) {
print('A stream error occurred: $e');
},
Expand All @@ -99,24 +98,37 @@ class _MyAppState extends State<MyApp> {
return MaterialApp(
home: Scaffold(
appBar: AppBar(
title: const Text('Audio Classification'),
backgroundColor: Colors.orange,
title: const Text(
'TFL Audio Classification',
style: TextStyle(color: Colors.white),
),
),
body: Column(
mainAxisAlignment: MainAxisAlignment.center,
children: [
Row(
mainAxisAlignment: MainAxisAlignment.spaceAround,
children: [
IconButton(
iconSize: 64.0,
icon: Icon(_isRecording ? Icons.mic_off : Icons.mic),
onPressed: _isRecording ? () {
_recorder.stop();
_loadMicChunks();
} : () {
_micChunks = [];
_recorder.start();
},
Column(
children: [
IconButton(
iconSize: 64.0,
icon: Icon(_isRecording ? Icons.mic_off : Icons.mic),
onPressed: _isRecording
? () {
_recorder.stop();
_loadMicChunks();
}
: () {
_micChunks = [];
_recorder.start();
},
),
Text(
'Record',
),
],
),
StreamBuilder<PlayerState>(
stream: _player.playerStateStream,
Expand Down Expand Up @@ -166,17 +178,20 @@ class _MyAppState extends State<MyApp> {
child: Text("Predict"),
),
),
if(prediction != null)
Padding(
padding: const EdgeInsets.all(32.0),
child: Column(
children: [
Text(prediction!.label, style: TextStyle(fontWeight: FontWeight.bold),),
SizedBox(height: 4),
Text(prediction!.score.toString()),
],
if (prediction != null)
Padding(
padding: const EdgeInsets.all(32.0),
child: Column(
children: [
Text(
prediction!.label,
style: TextStyle(fontWeight: FontWeight.bold),
),
SizedBox(height: 4),
Text(prediction!.score.toString()),
],
),
),
),
],
),
),
Expand Down
6 changes: 3 additions & 3 deletions example/audio_classification/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ packages:
tflite_flutter:
dependency: transitive
description:
name: tflite_flutter
url: "https://pub.dartlang.org"
source: hosted
path: "../../../tflite_flutter_plugin"
relative: true
source: path
version: "0.9.0"
tflite_flutter_helper:
dependency: "direct main"
Expand Down

0 comments on commit 1403d94

Please sign in to comment.