Skip to content

Commit

Permalink
Merge pull request am15h#34 from am15h/task_library
Browse files Browse the repository at this point in the history
Update Support and Task library
  • Loading branch information
am15h authored Aug 20, 2021
2 parents 5b76352 + 9cd35d9 commit a7dd0d0
Show file tree
Hide file tree
Showing 19 changed files with 886 additions and 214 deletions.
69 changes: 54 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
# TensorFlow Lite Flutter Helper Library

Makes use of TensorFlow Lite Interpreter on Flutter easier by
providing simple architecture for processing and manipulating
input and output of TFLite Models.

API design and documentation is identical to the TensorFlow Lite
Android Support Library.
TFLite Flutter Helper Library brings [TFLite Support Library](https://www.tensorflow.org/lite/inference_with_metadata/lite_support) and [TFLite Support Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview) to Flutter and helps users to develop ML and deploy TFLite models onto mobile devices quickly without compromising on performance.

## Getting Started

### Setup TFLite Flutter Plugin

Include `tflite_flutter: ^<latest_version>` in your pubspec.yaml. Follow the initial setup
instructions given [here](https://github.com/am15h/tflite_flutter_plugin#most-important-initial-setup)
Follow the initial setup instructions given [here](https://github.com/am15h/tflite_flutter_plugin#most-important-initial-setup)

## Image Processing
### Basic image manipulation and conversion

TFLite Helper depends on [flutter image package](https://pub.dev/packages/image) internally for
Image Processing.

### Basic image manipulation and conversion

The TensorFlow Lite Support Library has a suite of basic image manipulation methods such as crop
and resize. To use it, create an `ImageProcessor` and add the required operations.
To convert the image into the tensor format required by the TensorFlow Lite interpreter,
Expand All @@ -42,6 +34,22 @@ TensorImage tensorImage = TensorImage.fromFile(imageFile);
tensorImage = imageProcessor.process(tensorImage);
```

Sample app: [Image Classification](https://github.com/am15h/tflite_flutter_helper/tree/master/example/image_classification)

### Basic audio data processing

The TensorFlow Lite Support Library also defines a TensorAudio class wrapping some basic audio data processing methods.

```dart
TensorAudio tensorAudio = TensorAudio.create(
TensorAudioFormat.create(1, sampleRate), size);
tensorAudio.loadShortBytes(audioBytes);
TensorBuffer inputBuffer = tensorAudio.tensorBuffer;
```

Sample app: [Audio Classification](https://github.com/am15h/tflite_flutter_helper/tree/master/example/audio_classification)

### Create output objects and run the model

```dart
Expand Down Expand Up @@ -141,8 +149,39 @@ QuantizationParams inputParams = interpreter.getInputTensor(0).params;
QuantizationParams outputParams = interpreter.getOutputTensor(0).params;
```

## Coming Soon
## Task Library

Currently, Text based models like `NLClassifier`, `BertNLClassifier` and `BertQuestionAnswerer` are available to use with the Flutter Task Library.

### Integrate Natural Langugae Classifier

The Task Library's `NLClassifier` API classifies input text into different categories, and is a versatile and configurable API that can handle most text classification models. Detailed guide is available [here](https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier).

```dart
final classifier = await NLClassifier.createFromAsset('assets/$_modelFileName',
options: NLClassifierOptions());
List<Category> predictions = classifier.classify(rawText);
```

Sample app: [Text Classification](https://github.com/am15h/tflite_flutter_plugin/tree/master/example/lib) using Task Library.

### Integrate BERT natural language classifier

The Task Library `BertNLClassifier` API is very similar to the `NLClassifier` that classifies input text into different categories, except that this API is specially tailored for Bert related models that require Wordpiece and Sentencepiece tokenizations outside the TFLite model. Detailed guide is available [here](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_nl_classifier).

```dart
final classifier = await BertNLClassifier.createFromAsset('assets/$_modelFileName',
options: BertNLClassifierOptions());
List<Category> predictions = classifier.classify(rawText);
```

### Integrate BERT question answerer

The Task Library `BertQuestionAnswerer` API loads a Bert model and answers questions based on the content of a given passage. For more information, see the documentation for the Question-Answer model [here](https://www.tensorflow.org/lite/models/bert_qa/overview). Detailed guide is available [here](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer).

```dart
final bertQuestionAnswerer = await BertQuestionAnswerer.createFromAsset('assets/$_modelFileName');
List<QaAnswer> answeres = bertQuestionAnswerer.answer(context, question);
```

* More image operations
* Support for text-related applications.
* Support for audio-related applications.
Sample app: [Bert Question Answerer Sample](https://github.com/am15h/tflite_flutter_helper/tree/master/example/bert_question_answer)
37 changes: 35 additions & 2 deletions example/audio_classification/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,36 @@
# Audio Classification Flutter App
# Real-time Audio Classification Flutter

Demonstrates the usage of TensorAudio API.
Real-time Audio Classification in flutter. It uses:

* Interpreter API from TFLite Flutter Plugin.
* TensorAudio API from TFLite Flutter Support Library.
* [YAMNet](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1),
an audio event classification model.

<p align="center">
<img src="audio_demo.gif" alt="animated" />
</p>

## Build and run

### Step 1. Clone TFLite Flutter Helper repository

Clone TFLite Flutter Helper repository to your computer to get the demo
application.

```
git clone https://github.com/am15h/tflite_flutter_helper
```

### Step 2. Run the application

```
cd example/audio_classification/
flutter run
```

## Resources used:

* [TensorFlow Lite](https://www.tensorflow.org/lite)
* [Audio Classification using TensorFlow Lite](https://www.tensorflow.org/lite/examples/audio_classification/overview)
* [YAMNet audio classification model](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1)
Binary file added example/audio_classification/audio_demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 36 additions & 1 deletion example/audio_classification/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "2.1.0"
camera:
dependency: transitive
description:
name: camera
url: "https://pub.dartlang.org"
source: hosted
version: "0.8.1+7"
camera_platform_interface:
dependency: transitive
description:
name: camera_platform_interface
url: "https://pub.dartlang.org"
source: hosted
version: "2.1.0"
characters:
dependency: transitive
description:
Expand Down Expand Up @@ -50,6 +64,13 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "1.15.0"
cross_file:
dependency: transitive
description:
name: cross_file
url: "https://pub.dartlang.org"
source: hosted
version: "0.3.1+4"
crypto:
dependency: transitive
description:
Expand Down Expand Up @@ -158,6 +179,13 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "2.0.1"
pedantic:
dependency: transitive
description:
name: pedantic
url: "https://pub.dartlang.org"
source: hosted
version: "1.11.1"
petitparser:
dependency: transitive
description:
Expand Down Expand Up @@ -226,6 +254,13 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "2.1.0"
stream_transform:
dependency: transitive
description:
name: stream_transform
url: "https://pub.dartlang.org"
source: hosted
version: "2.0.0"
string_scanner:
dependency: transitive
description:
Expand Down Expand Up @@ -305,4 +340,4 @@ packages:
version: "5.1.2"
sdks:
dart: ">=2.13.0 <3.0.0"
flutter: ">=1.26.0-17.6.pre"
flutter: ">=2.0.0"
1 change: 1 addition & 0 deletions example/bert_question_answer
Submodule bert_question_answer added at 2c380f
2 changes: 1 addition & 1 deletion example/image_classification/android/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ android {
defaultConfig {
// TODO: Specify your own unique Application ID (https://developer.android.com/studio/build/application-id.html).
applicationId "com.example.imageclassification"
minSdkVersion 16
minSdkVersion 21
targetSdkVersion 28
versionCode flutterVersionCode.toInteger()
versionName flutterVersionName
Expand Down
14 changes: 6 additions & 8 deletions example/image_classification/lib/classifier.dart
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ abstract class Classifier {
late TensorImage _inputImage;
late TensorBuffer _outputBuffer;

TfLiteType _outputType = TfLiteType.uint8;
late TfLiteType _inputType;
late TfLiteType _outputType;

final String _labelsFileName = 'assets/labels.txt';

Expand Down Expand Up @@ -52,6 +53,7 @@ abstract class Classifier {

_inputShape = interpreter.getInputTensor(0).shape;
_outputShape = interpreter.getOutputTensor(0).shape;
_inputType = interpreter.getInputTensor(0).type;
_outputType = interpreter.getOutputTensor(0).type;

_outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);
Expand Down Expand Up @@ -83,11 +85,9 @@ abstract class Classifier {
}

Category predict(Image image) {
if (interpreter == null) {
throw StateError('Cannot run inference, Intrepreter is null');
}
final pres = DateTime.now().millisecondsSinceEpoch;
_inputImage = TensorImage.fromImage(image);
_inputImage = TensorImage(_inputType);
_inputImage.loadImage(image);
_inputImage = _preProcess();
final pre = DateTime.now().millisecondsSinceEpoch - pres;

Expand All @@ -108,9 +108,7 @@ abstract class Classifier {
}

void close() {
if (interpreter != null) {
interpreter.close();
}
interpreter.close();
}
}

Expand Down
38 changes: 33 additions & 5 deletions example/image_classification/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "2.1.0"
camera:
dependency: transitive
description:
name: camera
url: "https://pub.dartlang.org"
source: hosted
version: "0.8.1+7"
camera_platform_interface:
dependency: transitive
description:
name: camera_platform_interface
url: "https://pub.dartlang.org"
source: hosted
version: "2.1.0"
characters:
dependency: transitive
description:
Expand Down Expand Up @@ -50,6 +64,13 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "1.15.0"
cross_file:
dependency: transitive
description:
name: cross_file
url: "https://pub.dartlang.org"
source: hosted
version: "0.3.1+4"
crypto:
dependency: transitive
description:
Expand Down Expand Up @@ -302,6 +323,13 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "2.1.0"
stream_transform:
dependency: transitive
description:
name: stream_transform
url: "https://pub.dartlang.org"
source: hosted
version: "2.0.0"
string_scanner:
dependency: transitive
description:
Expand Down Expand Up @@ -331,11 +359,11 @@ packages:
source: hosted
version: "0.3.0"
tflite_flutter:
dependency: "direct main"
dependency: transitive
description:
path: "../../../tflite_flutter_plugin"
relative: true
source: path
name: tflite_flutter
url: "https://pub.dartlang.org"
source: hosted
version: "0.9.0"
tflite_flutter_helper:
dependency: "direct main"
Expand Down Expand Up @@ -402,4 +430,4 @@ packages:
version: "5.1.0"
sdks:
dart: ">=2.12.0 <3.0.0"
flutter: ">=1.26.0-17.6.pre"
flutter: ">=2.0.0"
3 changes: 0 additions & 3 deletions example/image_classification/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ dependencies:
logger: ^1.0.0

path_provider:
tflite_flutter:
path:
../../../tflite_flutter_plugin
tflite_flutter_helper:
path:
../../
Expand Down
30 changes: 30 additions & 0 deletions lib/src/image/base_image_container.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import 'package:camera/camera.dart';
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/src/image/color_space_type.dart';
import 'package:tflite_flutter_helper/src/tensorbuffer/tensorbuffer.dart';

abstract class BaseImageContainer {

/// Performs deep copy of the {@link ImageContainer}. */
BaseImageContainer clone();

/// Returns the width of the image. */
int get width;

/// Returns the height of the image. */
int get height;

/// Gets the {@link Image} representation of the underlying image format. */
Image get image;

/// Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
/// underlying image format.
TensorBuffer getTensorBuffer(TfLiteType dataType);

/// Gets the {@link Image} representation of the underlying image format. */
CameraImage get mediaImage;

/// Returns the color space type of the image. */
ColorSpaceType get colorSpaceType;
}
Loading

0 comments on commit a7dd0d0

Please sign in to comment.