From a063b1c304dd0a01a2853dfa7b6d50c2a0b55546 Mon Sep 17 00:00:00 2001 From: am15h Date: Tue, 10 Aug 2021 03:20:50 +0530 Subject: [PATCH] add bindings and dart API for nlclassifier --- lib/src/task/bindings/dlib.dart | 41 +++++++++++ .../nl_classifier/bert_nl_classifier.dart | 51 +++++++++++++ .../text/nl_classifier/nl_classifer.dart | 40 +++++++++++ .../bindings/text/nl_classifier/types.dart | 71 +++++++++++++++++++ lib/src/task/bindings/text/qa/bert_qa.dart | 47 ++++++++++++ lib/src/task/bindings/text/qa/types.dart | 23 ++++++ .../nl_classifier/bert_nl_classifier.dart | 0 .../text/nl_classifier/nl_classifier.dart | 37 ++++++++++ .../nl_classifier/nl_classifier_options.dart | 64 +++++++++++++++++ lib/src/task/text/qa/bert_qa.dart | 0 pubspec.lock | 2 +- pubspec.yaml | 1 + 12 files changed, 376 insertions(+), 1 deletion(-) create mode 100644 lib/src/task/bindings/dlib.dart create mode 100644 lib/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart create mode 100644 lib/src/task/bindings/text/nl_classifier/nl_classifer.dart create mode 100644 lib/src/task/bindings/text/nl_classifier/types.dart create mode 100644 lib/src/task/bindings/text/qa/bert_qa.dart create mode 100644 lib/src/task/bindings/text/qa/types.dart create mode 100644 lib/src/task/text/nl_classifier/bert_nl_classifier.dart create mode 100644 lib/src/task/text/nl_classifier/nl_classifier.dart create mode 100644 lib/src/task/text/nl_classifier/nl_classifier_options.dart create mode 100644 lib/src/task/text/qa/bert_qa.dart diff --git a/lib/src/task/bindings/dlib.dart b/lib/src/task/bindings/dlib.dart new file mode 100644 index 0000000..8316722 --- /dev/null +++ b/lib/src/task/bindings/dlib.dart @@ -0,0 +1,41 @@ +import 'dart:ffi'; +import 'dart:io'; + +const Set _supported = {'linux', 'mac', 'win'}; + +String get binaryName { + String os, ext; + if (Platform.isLinux) { + os = 'linux'; + ext = 'so'; + } else if (Platform.isMacOS) { + os = 'mac'; + ext = 'so'; + } else if (Platform.isWindows) { + os = 'win'; + ext = 'dll'; + } else { + throw Exception('Unsupported platform!'); + } + + if (!_supported.contains(os)) { + throw UnsupportedError('Unsupported platform: $os!'); + } + + return 'libtensorflowlite_c-$os.$ext'; +} + +/// TensorFlowLite C library. +// ignore: missing_return +DynamicLibrary tflitelib = () { + if (Platform.isAndroid) { + return DynamicLibrary.open('libtensorflowlite_c.so'); + } else if (Platform.isIOS) { + return DynamicLibrary.process(); + } else { + final binaryPath = Platform.script.resolveUri(Uri.directory('.')).path + + 'blobs/$binaryName'; + final binaryFilePath = Uri(path: binaryPath).toFilePath(); + return DynamicLibrary.open(binaryFilePath); + } +}(); diff --git a/lib/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart b/lib/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart new file mode 100644 index 0000000..71fa57a --- /dev/null +++ b/lib/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart @@ -0,0 +1,51 @@ +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; +import 'types.dart'; + +import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart'; + +// ignore_for_file: non_constant_identifier_names, camel_case_types + +// Creates BertBertNLClassifier from model path and options, returns nullptr if the +// file doesn't exist or is not a well formatted TFLite model path. +Pointer Function(Pointer modelPath, + Pointer options) +BertNLClassifierFromFileAndOptions = tflitelib + .lookup>( + 'BertNLClassifierFromFileAndOptions') + .asFunction(); + +typedef _BertNLClassifierFromFileAndOptions_native_t = Pointer Function( + Pointer modelPath, + Pointer options); + +// Creates BertNLClassifier from model path and default options, returns nullptr +// if the file doesn't exist or is not a well formatted TFLite model path. +Pointer Function(Pointer modelPath) +BertNLClassifierFromFile = tflitelib + .lookup>( + 'BertNLClassifierFromFile') + .asFunction(); + +typedef _BertNLClassifierFromFile_native_t = Pointer Function( + Pointer modelPath); + +// Invokes the encapsulated TFLite model and classifies the input text. +Pointer Function(Pointer classifier, + Pointer text) +BertNLClassifierClassify = tflitelib + .lookup>( + 'BertNLClassifierClassify') + .asFunction(); + +typedef _BertNLClassifierClassify_native_t = Pointer Function(Pointer classifier, + Pointer text); + +// Deletes BertNLClassifer instance +void Function(Pointer) BertNLClassifierDelete = tflitelib + .lookup>( + 'BertNLClassifierDelete') + .asFunction(); + +typedef _BertNLClassifierDelete_native_t = Void Function(Pointer); \ No newline at end of file diff --git a/lib/src/task/bindings/text/nl_classifier/nl_classifer.dart b/lib/src/task/bindings/text/nl_classifier/nl_classifer.dart new file mode 100644 index 0000000..5c3f47c --- /dev/null +++ b/lib/src/task/bindings/text/nl_classifier/nl_classifer.dart @@ -0,0 +1,40 @@ +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; +import 'types.dart'; + +import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart'; + +// ignore_for_file: non_constant_identifier_names, camel_case_types + +// Creates NLClassifier from model path and options, returns nullptr if the file +// doesn't exist or is not a well formatted TFLite model path. +Pointer Function(Pointer modelPath, + Pointer options) +NLClassifierFromFileAndOptions = tflitelib + .lookup>( + 'NLClassifierFromFileAndOptions') + .asFunction(); + +typedef _NLClassifierFromFileAndOptions_native_t = Pointer Function( + Pointer modelPath, + Pointer options); + +// Invokes the encapsulated TFLite model and classifies the input text. +Pointer Function(Pointer classifier, + Pointer text) +NLClassifierClassify = tflitelib + .lookup>( + 'NLClassifierClassify') + .asFunction(); + +typedef _NLClassifierClassify_native_t = Pointer Function(Pointer classifier, + Pointer text); + +// Deletes NLClassifer instance +void Function(Pointer) NLClassifierDelete = tflitelib + .lookup>( + 'NLClassifierDelete') + .asFunction(); + +typedef _NLClassifierDelete_native_t = Void Function(Pointer); \ No newline at end of file diff --git a/lib/src/task/bindings/text/nl_classifier/types.dart b/lib/src/task/bindings/text/nl_classifier/types.dart new file mode 100644 index 0000000..7d98446 --- /dev/null +++ b/lib/src/task/bindings/text/nl_classifier/types.dart @@ -0,0 +1,71 @@ +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; + +class TfLiteNLClassifier extends Opaque {} + +// struct NLClassifierOptions { +// int input_tensor_index; +// int output_score_tensor_index; +// int output_label_tensor_index; +// const char* input_tensor_name; +// const char* output_score_tensor_name; +// const char* output_label_tensor_name; +// }; + +class TfLiteNLClassifierOptions extends Struct { + @Int32() + external int inputTensorIndex; + + @Int32() + external int outputScoreTensorIndex; + + @Int32() + external int outputLabelTensorIndex; + + external Pointer inputTensorName; + + external Pointer outputScoreTensorName; + + external Pointer outputLabelTensorName; + + static Pointer allocate( + int inputTensorIndex, + int outputScoreTensorIndex, + int outputLabelTensorIndex, + String inputTensorName, + String outputScoreTensorName, + String outputLabelTensorName, + ) { + final result = calloc(); + result.ref + ..inputTensorIndex = inputTensorIndex + ..outputScoreTensorIndex = outputScoreTensorIndex + ..outputLabelTensorIndex = outputLabelTensorIndex + ..inputTensorName = inputTensorName.toNativeUtf8() + ..outputScoreTensorName = outputScoreTensorName.toNativeUtf8() + ..outputLabelTensorName = outputLabelTensorName.toNativeUtf8(); + return result; + } +} + +class TfLiteCategories extends Struct { + @Int32() + external int size; + + external Pointer categories; +} + +class TfLiteCategory extends Struct { + external Pointer text; + + @Double() + external double score; +} + +class TfLiteBertNLClassifier extends Opaque {} + +class TfLiteBertNLClassifierOptions extends Struct { + @Int32() + external int maxSeqLen; +} diff --git a/lib/src/task/bindings/text/qa/bert_qa.dart b/lib/src/task/bindings/text/qa/bert_qa.dart new file mode 100644 index 0000000..7d78f1e --- /dev/null +++ b/lib/src/task/bindings/text/qa/bert_qa.dart @@ -0,0 +1,47 @@ +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; +import 'types.dart'; + +import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart'; + +// ignore_for_file: non_constant_identifier_names, camel_case_types + +// Creates BertQuestionAnswerer from model path, returns nullptr if the file +// doesn't exist or is not a well formatted TFLite model path. +Pointer Function(Pointer modelPath) +BertQuestionAnswererFromFile = tflitelib + .lookup>( + 'BertQuestionAnswererFromFile') + .asFunction(); + +typedef _BertQuestionAnswererFromFile_native_t = Pointer Function( + Pointer modelPath); + +// Invokes the encapsulated TFLite model and answers a question based on +// context. +Pointer Function(Pointer questionAnswerer, + Pointer context, Pointer question) +BertQuestionAnswererAnswer = tflitelib + .lookup>( + 'BertQuestionAnswererAnswer') + .asFunction(); + +typedef _BertQuestionAnswererAnswer_native_t = Pointer Function(Pointer questionAnswerer, + Pointer context, Pointer question); + +// Deletes BertQuestionAnswerer instance +void Function(Pointer) BertQuestionAnswererDelete = tflitelib + .lookup>( + 'BertQuestionAnswererDelete') + .asFunction(); + +typedef _BertQuestionAnswererDelete_native_t = Void Function(Pointer); + +// Deletes BertQuestionAnswererQaAnswers instance +void Function(Pointer) BertQuestionAnswererQaAnswersDelete = tflitelib + .lookup>( + 'BertQuestionAnswererQaAnswersDelete') + .asFunction(); + +typedef _BertQuestionAnswererQaAnswersDelete_native_t = Void Function(Pointer); \ No newline at end of file diff --git a/lib/src/task/bindings/text/qa/types.dart b/lib/src/task/bindings/text/qa/types.dart new file mode 100644 index 0000000..dec02db --- /dev/null +++ b/lib/src/task/bindings/text/qa/types.dart @@ -0,0 +1,23 @@ +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; + +class TfLiteBertQuestionAnswerer extends Opaque {} + +class QaAnswer extends Struct { + @Int32() + external int start; + @Int32() + external int end; + @Double() + external double logit; + + external Pointer text; +} + +class QaAnswers extends Struct { + @Int32() + external int size; + + Pointer answers; +} \ No newline at end of file diff --git a/lib/src/task/text/nl_classifier/bert_nl_classifier.dart b/lib/src/task/text/nl_classifier/bert_nl_classifier.dart new file mode 100644 index 0000000..e69de29 diff --git a/lib/src/task/text/nl_classifier/nl_classifier.dart b/lib/src/task/text/nl_classifier/nl_classifier.dart new file mode 100644 index 0000000..87ff484 --- /dev/null +++ b/lib/src/task/text/nl_classifier/nl_classifier.dart @@ -0,0 +1,37 @@ +import 'dart:ffi'; +import 'package:ffi/ffi.dart'; +import 'package:quiver/check.dart'; +import 'package:tflite_flutter_helper/src/label/category.dart'; +import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/nl_classifer.dart'; +import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart'; +import 'package:tflite_flutter_helper/src/task/text/nl_classifier/nl_classifier_options.dart'; + +class NLClassifier { + final Pointer _classifier; + bool _deleted = false; + Pointer get base => _classifier; + + NLClassifier._(this._classifier); + + factory NLClassifier._create(String modelPath, NLClassifierOptions options) { + final classiferPtr = + NLClassifierFromFileAndOptions(modelPath.toNativeUtf8(), options.base); + return NLClassifier._(classiferPtr); + } + + List classify(String text) { + final ref = NLClassifierClassify(base, text.toNativeUtf8()).ref; + var categoryList = List.generate( + ref.size, + (i) => Category( + ref.categories[i].text.toDartString(), ref.categories[i].score), + ); + return categoryList; + } + + void delete() { + checkState(!_deleted, message: 'NLCLassifier already deleted.'); + NLClassifierDelete(base); + _deleted = true; + } +} diff --git a/lib/src/task/text/nl_classifier/nl_classifier_options.dart b/lib/src/task/text/nl_classifier/nl_classifier_options.dart new file mode 100644 index 0000000..a604a52 --- /dev/null +++ b/lib/src/task/text/nl_classifier/nl_classifier_options.dart @@ -0,0 +1,64 @@ +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; +import 'package:quiver/check.dart'; +import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart'; + +class NLClassifierOptions { + final Pointer _options; + bool _deleted = false; + + Pointer get base => _options; + NLClassifierOptions._(this._options); + + /// Creates a new options instance. + factory NLClassifierOptions() { + final optionsPtr = calloc(); + return NLClassifierOptions._(optionsPtr); + } + + int get inputTensorIndex => base.ref.inputTensorIndex; + + set inputTensorIndex(int value) { + base.ref.inputTensorIndex = value; + } + + int get outputScoreTensorIndex => base.ref.outputScoreTensorIndex; + + set outputScoreTensorIndex(int value) { + base.ref.outputScoreTensorIndex = value; + } + + int get outputLabelTensorIndex => base.ref.outputLabelTensorIndex; + + set outputLabelTensorIndex(int value) { + base.ref.outputLabelTensorIndex = value; + } + + String get inputTensorName => base.ref.inputTensorName.toDartString(); + + set inputTensorName(String value) { + base.ref.inputTensorName = value.toNativeUtf8(); + } + + String get outputScoreTensorName => + base.ref.outputScoreTensorName.toDartString(); + + set outputScoreTensorName(String value) { + base.ref.outputScoreTensorName = value.toNativeUtf8(); + } + + String get outputLabelTensorName => + base.ref.outputLabelTensorName.toDartString(); + + set outputLabelTensorName(String value) { + base.ref.outputLabelTensorName = value.toNativeUtf8(); + } + + /// Destroys the options instance. + void delete() { + checkState(!_deleted, message: 'NLClassifierOptions already deleted.'); + calloc.free(_options); + _deleted = true; + } +} diff --git a/lib/src/task/text/qa/bert_qa.dart b/lib/src/task/text/qa/bert_qa.dart new file mode 100644 index 0000000..e69de29 diff --git a/pubspec.lock b/pubspec.lock index 8b748ce..3c64af1 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -65,7 +65,7 @@ packages: source: hosted version: "1.2.0" ffi: - dependency: transitive + dependency: "direct main" description: name: ffi url: "https://pub.dartlang.org" diff --git a/pubspec.yaml b/pubspec.yaml index 6709272..e0228c3 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -15,6 +15,7 @@ dependencies: tflite_flutter: ^0.9.0 image: ^3.0.2 tuple: ^2.0.0 + ffi: ^1.0.0 dev_dependencies: flutter_test: