forked from am15h/tflite_flutter_helper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add bindings and dart API for nlclassifier
- Loading branch information
Showing
12 changed files
with
376 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import 'dart:ffi'; | ||
import 'dart:io'; | ||
|
||
const Set<String> _supported = {'linux', 'mac', 'win'}; | ||
|
||
String get binaryName { | ||
String os, ext; | ||
if (Platform.isLinux) { | ||
os = 'linux'; | ||
ext = 'so'; | ||
} else if (Platform.isMacOS) { | ||
os = 'mac'; | ||
ext = 'so'; | ||
} else if (Platform.isWindows) { | ||
os = 'win'; | ||
ext = 'dll'; | ||
} else { | ||
throw Exception('Unsupported platform!'); | ||
} | ||
|
||
if (!_supported.contains(os)) { | ||
throw UnsupportedError('Unsupported platform: $os!'); | ||
} | ||
|
||
return 'libtensorflowlite_c-$os.$ext'; | ||
} | ||
|
||
/// TensorFlowLite C library. | ||
// ignore: missing_return | ||
DynamicLibrary tflitelib = () { | ||
if (Platform.isAndroid) { | ||
return DynamicLibrary.open('libtensorflowlite_c.so'); | ||
} else if (Platform.isIOS) { | ||
return DynamicLibrary.process(); | ||
} else { | ||
final binaryPath = Platform.script.resolveUri(Uri.directory('.')).path + | ||
'blobs/$binaryName'; | ||
final binaryFilePath = Uri(path: binaryPath).toFilePath(); | ||
return DynamicLibrary.open(binaryFilePath); | ||
} | ||
}(); |
51 changes: 51 additions & 0 deletions
51
lib/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import 'dart:ffi'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
import 'types.dart'; | ||
|
||
import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart'; | ||
|
||
// ignore_for_file: non_constant_identifier_names, camel_case_types | ||
|
||
// Creates BertBertNLClassifier from model path and options, returns nullptr if the | ||
// file doesn't exist or is not a well formatted TFLite model path. | ||
Pointer<TfLiteBertNLClassifier> Function(Pointer<Utf8> modelPath, | ||
Pointer<TfLiteBertNLClassifierOptions> options) | ||
BertNLClassifierFromFileAndOptions = tflitelib | ||
.lookup<NativeFunction<_BertNLClassifierFromFileAndOptions_native_t>>( | ||
'BertNLClassifierFromFileAndOptions') | ||
.asFunction(); | ||
|
||
typedef _BertNLClassifierFromFileAndOptions_native_t = Pointer<TfLiteBertNLClassifier> Function( | ||
Pointer<Utf8> modelPath, | ||
Pointer<TfLiteBertNLClassifierOptions> options); | ||
|
||
// Creates BertNLClassifier from model path and default options, returns nullptr | ||
// if the file doesn't exist or is not a well formatted TFLite model path. | ||
Pointer<TfLiteBertNLClassifier> Function(Pointer<Utf8> modelPath) | ||
BertNLClassifierFromFile = tflitelib | ||
.lookup<NativeFunction<_BertNLClassifierFromFile_native_t>>( | ||
'BertNLClassifierFromFile') | ||
.asFunction(); | ||
|
||
typedef _BertNLClassifierFromFile_native_t = Pointer<TfLiteBertNLClassifier> Function( | ||
Pointer<Utf8> modelPath); | ||
|
||
// Invokes the encapsulated TFLite model and classifies the input text. | ||
Pointer<TfLiteCategories> Function(Pointer<TfLiteBertNLClassifier> classifier, | ||
Pointer<Utf8> text) | ||
BertNLClassifierClassify = tflitelib | ||
.lookup<NativeFunction<_BertNLClassifierClassify_native_t>>( | ||
'BertNLClassifierClassify') | ||
.asFunction(); | ||
|
||
typedef _BertNLClassifierClassify_native_t = Pointer<TfLiteCategories> Function(Pointer<TfLiteBertNLClassifier> classifier, | ||
Pointer<Utf8> text); | ||
|
||
// Deletes BertNLClassifer instance | ||
void Function(Pointer<TfLiteBertNLClassifier>) BertNLClassifierDelete = tflitelib | ||
.lookup<NativeFunction<_BertNLClassifierDelete_native_t>>( | ||
'BertNLClassifierDelete') | ||
.asFunction(); | ||
|
||
typedef _BertNLClassifierDelete_native_t = Void Function(Pointer<TfLiteBertNLClassifier>); |
40 changes: 40 additions & 0 deletions
40
lib/src/task/bindings/text/nl_classifier/nl_classifer.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import 'dart:ffi'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
import 'types.dart'; | ||
|
||
import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart'; | ||
|
||
// ignore_for_file: non_constant_identifier_names, camel_case_types | ||
|
||
// Creates NLClassifier from model path and options, returns nullptr if the file | ||
// doesn't exist or is not a well formatted TFLite model path. | ||
Pointer<TfLiteNLClassifier> Function(Pointer<Utf8> modelPath, | ||
Pointer<TfLiteNLClassifierOptions> options) | ||
NLClassifierFromFileAndOptions = tflitelib | ||
.lookup<NativeFunction<_NLClassifierFromFileAndOptions_native_t>>( | ||
'NLClassifierFromFileAndOptions') | ||
.asFunction(); | ||
|
||
typedef _NLClassifierFromFileAndOptions_native_t = Pointer<TfLiteNLClassifier> Function( | ||
Pointer<Utf8> modelPath, | ||
Pointer<TfLiteNLClassifierOptions> options); | ||
|
||
// Invokes the encapsulated TFLite model and classifies the input text. | ||
Pointer<TfLiteCategories> Function(Pointer<TfLiteNLClassifier> classifier, | ||
Pointer<Utf8> text) | ||
NLClassifierClassify = tflitelib | ||
.lookup<NativeFunction<_NLClassifierClassify_native_t>>( | ||
'NLClassifierClassify') | ||
.asFunction(); | ||
|
||
typedef _NLClassifierClassify_native_t = Pointer<TfLiteCategories> Function(Pointer<TfLiteNLClassifier> classifier, | ||
Pointer<Utf8> text); | ||
|
||
// Deletes NLClassifer instance | ||
void Function(Pointer<TfLiteNLClassifier>) NLClassifierDelete = tflitelib | ||
.lookup<NativeFunction<_NLClassifierDelete_native_t>>( | ||
'NLClassifierDelete') | ||
.asFunction(); | ||
|
||
typedef _NLClassifierDelete_native_t = Void Function(Pointer<TfLiteNLClassifier>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import 'dart:ffi'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
|
||
class TfLiteNLClassifier extends Opaque {} | ||
|
||
// struct NLClassifierOptions { | ||
// int input_tensor_index; | ||
// int output_score_tensor_index; | ||
// int output_label_tensor_index; | ||
// const char* input_tensor_name; | ||
// const char* output_score_tensor_name; | ||
// const char* output_label_tensor_name; | ||
// }; | ||
|
||
class TfLiteNLClassifierOptions extends Struct { | ||
@Int32() | ||
external int inputTensorIndex; | ||
|
||
@Int32() | ||
external int outputScoreTensorIndex; | ||
|
||
@Int32() | ||
external int outputLabelTensorIndex; | ||
|
||
external Pointer<Utf8> inputTensorName; | ||
|
||
external Pointer<Utf8> outputScoreTensorName; | ||
|
||
external Pointer<Utf8> outputLabelTensorName; | ||
|
||
static Pointer<TfLiteNLClassifierOptions> allocate( | ||
int inputTensorIndex, | ||
int outputScoreTensorIndex, | ||
int outputLabelTensorIndex, | ||
String inputTensorName, | ||
String outputScoreTensorName, | ||
String outputLabelTensorName, | ||
) { | ||
final result = calloc<TfLiteNLClassifierOptions>(); | ||
result.ref | ||
..inputTensorIndex = inputTensorIndex | ||
..outputScoreTensorIndex = outputScoreTensorIndex | ||
..outputLabelTensorIndex = outputLabelTensorIndex | ||
..inputTensorName = inputTensorName.toNativeUtf8() | ||
..outputScoreTensorName = outputScoreTensorName.toNativeUtf8() | ||
..outputLabelTensorName = outputLabelTensorName.toNativeUtf8(); | ||
return result; | ||
} | ||
} | ||
|
||
class TfLiteCategories extends Struct { | ||
@Int32() | ||
external int size; | ||
|
||
external Pointer<TfLiteCategory> categories; | ||
} | ||
|
||
class TfLiteCategory extends Struct { | ||
external Pointer<Utf8> text; | ||
|
||
@Double() | ||
external double score; | ||
} | ||
|
||
class TfLiteBertNLClassifier extends Opaque {} | ||
|
||
class TfLiteBertNLClassifierOptions extends Struct { | ||
@Int32() | ||
external int maxSeqLen; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import 'dart:ffi'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
import 'types.dart'; | ||
|
||
import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart'; | ||
|
||
// ignore_for_file: non_constant_identifier_names, camel_case_types | ||
|
||
// Creates BertQuestionAnswerer from model path, returns nullptr if the file | ||
// doesn't exist or is not a well formatted TFLite model path. | ||
Pointer<TfLiteBertQuestionAnswerer> Function(Pointer<Utf8> modelPath) | ||
BertQuestionAnswererFromFile = tflitelib | ||
.lookup<NativeFunction<_BertQuestionAnswererFromFile_native_t>>( | ||
'BertQuestionAnswererFromFile') | ||
.asFunction(); | ||
|
||
typedef _BertQuestionAnswererFromFile_native_t = Pointer<TfLiteBertQuestionAnswerer> Function( | ||
Pointer<Utf8> modelPath); | ||
|
||
// Invokes the encapsulated TFLite model and answers a question based on | ||
// context. | ||
Pointer<QaAnswers> Function(Pointer<TfLiteBertQuestionAnswerer> questionAnswerer, | ||
Pointer<Utf8> context, Pointer<Utf8> question) | ||
BertQuestionAnswererAnswer = tflitelib | ||
.lookup<NativeFunction<_BertQuestionAnswererAnswer_native_t>>( | ||
'BertQuestionAnswererAnswer') | ||
.asFunction(); | ||
|
||
typedef _BertQuestionAnswererAnswer_native_t = Pointer<QaAnswers> Function(Pointer<TfLiteBertQuestionAnswerer> questionAnswerer, | ||
Pointer<Utf8> context, Pointer<Utf8> question); | ||
|
||
// Deletes BertQuestionAnswerer instance | ||
void Function(Pointer<TfLiteBertQuestionAnswerer>) BertQuestionAnswererDelete = tflitelib | ||
.lookup<NativeFunction<_BertQuestionAnswererDelete_native_t>>( | ||
'BertQuestionAnswererDelete') | ||
.asFunction(); | ||
|
||
typedef _BertQuestionAnswererDelete_native_t = Void Function(Pointer<TfLiteBertQuestionAnswerer>); | ||
|
||
// Deletes BertQuestionAnswererQaAnswers instance | ||
void Function(Pointer<QaAnswers>) BertQuestionAnswererQaAnswersDelete = tflitelib | ||
.lookup<NativeFunction<_BertQuestionAnswererQaAnswersDelete_native_t>>( | ||
'BertQuestionAnswererQaAnswersDelete') | ||
.asFunction(); | ||
|
||
typedef _BertQuestionAnswererQaAnswersDelete_native_t = Void Function(Pointer<QaAnswers>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import 'dart:ffi'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
|
||
class TfLiteBertQuestionAnswerer extends Opaque {} | ||
|
||
class QaAnswer extends Struct { | ||
@Int32() | ||
external int start; | ||
@Int32() | ||
external int end; | ||
@Double() | ||
external double logit; | ||
|
||
external Pointer<Utf8> text; | ||
} | ||
|
||
class QaAnswers extends Struct { | ||
@Int32() | ||
external int size; | ||
|
||
Pointer<QaAnswer> answers; | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import 'dart:ffi'; | ||
import 'package:ffi/ffi.dart'; | ||
import 'package:quiver/check.dart'; | ||
import 'package:tflite_flutter_helper/src/label/category.dart'; | ||
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/nl_classifer.dart'; | ||
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart'; | ||
import 'package:tflite_flutter_helper/src/task/text/nl_classifier/nl_classifier_options.dart'; | ||
|
||
class NLClassifier { | ||
final Pointer<TfLiteNLClassifier> _classifier; | ||
bool _deleted = false; | ||
Pointer<TfLiteNLClassifier> get base => _classifier; | ||
|
||
NLClassifier._(this._classifier); | ||
|
||
factory NLClassifier._create(String modelPath, NLClassifierOptions options) { | ||
final classiferPtr = | ||
NLClassifierFromFileAndOptions(modelPath.toNativeUtf8(), options.base); | ||
return NLClassifier._(classiferPtr); | ||
} | ||
|
||
List<Category> classify(String text) { | ||
final ref = NLClassifierClassify(base, text.toNativeUtf8()).ref; | ||
var categoryList = List.generate( | ||
ref.size, | ||
(i) => Category( | ||
ref.categories[i].text.toDartString(), ref.categories[i].score), | ||
); | ||
return categoryList; | ||
} | ||
|
||
void delete() { | ||
checkState(!_deleted, message: 'NLCLassifier already deleted.'); | ||
NLClassifierDelete(base); | ||
_deleted = true; | ||
} | ||
} |
64 changes: 64 additions & 0 deletions
64
lib/src/task/text/nl_classifier/nl_classifier_options.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import 'dart:ffi'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
import 'package:quiver/check.dart'; | ||
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart'; | ||
|
||
class NLClassifierOptions { | ||
final Pointer<TfLiteNLClassifierOptions> _options; | ||
bool _deleted = false; | ||
|
||
Pointer<TfLiteNLClassifierOptions> get base => _options; | ||
NLClassifierOptions._(this._options); | ||
|
||
/// Creates a new options instance. | ||
factory NLClassifierOptions() { | ||
final optionsPtr = calloc<TfLiteNLClassifierOptions>(); | ||
return NLClassifierOptions._(optionsPtr); | ||
} | ||
|
||
int get inputTensorIndex => base.ref.inputTensorIndex; | ||
|
||
set inputTensorIndex(int value) { | ||
base.ref.inputTensorIndex = value; | ||
} | ||
|
||
int get outputScoreTensorIndex => base.ref.outputScoreTensorIndex; | ||
|
||
set outputScoreTensorIndex(int value) { | ||
base.ref.outputScoreTensorIndex = value; | ||
} | ||
|
||
int get outputLabelTensorIndex => base.ref.outputLabelTensorIndex; | ||
|
||
set outputLabelTensorIndex(int value) { | ||
base.ref.outputLabelTensorIndex = value; | ||
} | ||
|
||
String get inputTensorName => base.ref.inputTensorName.toDartString(); | ||
|
||
set inputTensorName(String value) { | ||
base.ref.inputTensorName = value.toNativeUtf8(); | ||
} | ||
|
||
String get outputScoreTensorName => | ||
base.ref.outputScoreTensorName.toDartString(); | ||
|
||
set outputScoreTensorName(String value) { | ||
base.ref.outputScoreTensorName = value.toNativeUtf8(); | ||
} | ||
|
||
String get outputLabelTensorName => | ||
base.ref.outputLabelTensorName.toDartString(); | ||
|
||
set outputLabelTensorName(String value) { | ||
base.ref.outputLabelTensorName = value.toNativeUtf8(); | ||
} | ||
|
||
/// Destroys the options instance. | ||
void delete() { | ||
checkState(!_deleted, message: 'NLClassifierOptions already deleted.'); | ||
calloc.free(_options); | ||
_deleted = true; | ||
} | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.