Skip to content

Commit

Permalink
add bindings and dart API for nlclassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
am15h committed Aug 9, 2021
1 parent 25da4e0 commit a063b1c
Show file tree
Hide file tree
Showing 12 changed files with 376 additions and 1 deletion.
41 changes: 41 additions & 0 deletions lib/src/task/bindings/dlib.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import 'dart:ffi';
import 'dart:io';

const Set<String> _supported = {'linux', 'mac', 'win'};

String get binaryName {
String os, ext;
if (Platform.isLinux) {
os = 'linux';
ext = 'so';
} else if (Platform.isMacOS) {
os = 'mac';
ext = 'so';
} else if (Platform.isWindows) {
os = 'win';
ext = 'dll';
} else {
throw Exception('Unsupported platform!');
}

if (!_supported.contains(os)) {
throw UnsupportedError('Unsupported platform: $os!');
}

return 'libtensorflowlite_c-$os.$ext';
}

/// TensorFlowLite C library.
// ignore: missing_return
DynamicLibrary tflitelib = () {
if (Platform.isAndroid) {
return DynamicLibrary.open('libtensorflowlite_c.so');
} else if (Platform.isIOS) {
return DynamicLibrary.process();
} else {
final binaryPath = Platform.script.resolveUri(Uri.directory('.')).path +
'blobs/$binaryName';
final binaryFilePath = Uri(path: binaryPath).toFilePath();
return DynamicLibrary.open(binaryFilePath);
}
}();
51 changes: 51 additions & 0 deletions lib/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import 'dart:ffi';

import 'package:ffi/ffi.dart';
import 'types.dart';

import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart';

// ignore_for_file: non_constant_identifier_names, camel_case_types

// Creates BertBertNLClassifier from model path and options, returns nullptr if the
// file doesn't exist or is not a well formatted TFLite model path.
Pointer<TfLiteBertNLClassifier> Function(Pointer<Utf8> modelPath,
Pointer<TfLiteBertNLClassifierOptions> options)
BertNLClassifierFromFileAndOptions = tflitelib
.lookup<NativeFunction<_BertNLClassifierFromFileAndOptions_native_t>>(
'BertNLClassifierFromFileAndOptions')
.asFunction();

typedef _BertNLClassifierFromFileAndOptions_native_t = Pointer<TfLiteBertNLClassifier> Function(
Pointer<Utf8> modelPath,
Pointer<TfLiteBertNLClassifierOptions> options);

// Creates BertNLClassifier from model path and default options, returns nullptr
// if the file doesn't exist or is not a well formatted TFLite model path.
Pointer<TfLiteBertNLClassifier> Function(Pointer<Utf8> modelPath)
BertNLClassifierFromFile = tflitelib
.lookup<NativeFunction<_BertNLClassifierFromFile_native_t>>(
'BertNLClassifierFromFile')
.asFunction();

typedef _BertNLClassifierFromFile_native_t = Pointer<TfLiteBertNLClassifier> Function(
Pointer<Utf8> modelPath);

// Invokes the encapsulated TFLite model and classifies the input text.
Pointer<TfLiteCategories> Function(Pointer<TfLiteBertNLClassifier> classifier,
Pointer<Utf8> text)
BertNLClassifierClassify = tflitelib
.lookup<NativeFunction<_BertNLClassifierClassify_native_t>>(
'BertNLClassifierClassify')
.asFunction();

typedef _BertNLClassifierClassify_native_t = Pointer<TfLiteCategories> Function(Pointer<TfLiteBertNLClassifier> classifier,
Pointer<Utf8> text);

// Deletes BertNLClassifer instance
void Function(Pointer<TfLiteBertNLClassifier>) BertNLClassifierDelete = tflitelib
.lookup<NativeFunction<_BertNLClassifierDelete_native_t>>(
'BertNLClassifierDelete')
.asFunction();

typedef _BertNLClassifierDelete_native_t = Void Function(Pointer<TfLiteBertNLClassifier>);
40 changes: 40 additions & 0 deletions lib/src/task/bindings/text/nl_classifier/nl_classifer.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import 'dart:ffi';

import 'package:ffi/ffi.dart';
import 'types.dart';

import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart';

// ignore_for_file: non_constant_identifier_names, camel_case_types

// Creates NLClassifier from model path and options, returns nullptr if the file
// doesn't exist or is not a well formatted TFLite model path.
Pointer<TfLiteNLClassifier> Function(Pointer<Utf8> modelPath,
Pointer<TfLiteNLClassifierOptions> options)
NLClassifierFromFileAndOptions = tflitelib
.lookup<NativeFunction<_NLClassifierFromFileAndOptions_native_t>>(
'NLClassifierFromFileAndOptions')
.asFunction();

typedef _NLClassifierFromFileAndOptions_native_t = Pointer<TfLiteNLClassifier> Function(
Pointer<Utf8> modelPath,
Pointer<TfLiteNLClassifierOptions> options);

// Invokes the encapsulated TFLite model and classifies the input text.
Pointer<TfLiteCategories> Function(Pointer<TfLiteNLClassifier> classifier,
Pointer<Utf8> text)
NLClassifierClassify = tflitelib
.lookup<NativeFunction<_NLClassifierClassify_native_t>>(
'NLClassifierClassify')
.asFunction();

typedef _NLClassifierClassify_native_t = Pointer<TfLiteCategories> Function(Pointer<TfLiteNLClassifier> classifier,
Pointer<Utf8> text);

// Deletes NLClassifer instance
void Function(Pointer<TfLiteNLClassifier>) NLClassifierDelete = tflitelib
.lookup<NativeFunction<_NLClassifierDelete_native_t>>(
'NLClassifierDelete')
.asFunction();

typedef _NLClassifierDelete_native_t = Void Function(Pointer<TfLiteNLClassifier>);
71 changes: 71 additions & 0 deletions lib/src/task/bindings/text/nl_classifier/types.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import 'dart:ffi';

import 'package:ffi/ffi.dart';

class TfLiteNLClassifier extends Opaque {}

// struct NLClassifierOptions {
// int input_tensor_index;
// int output_score_tensor_index;
// int output_label_tensor_index;
// const char* input_tensor_name;
// const char* output_score_tensor_name;
// const char* output_label_tensor_name;
// };

class TfLiteNLClassifierOptions extends Struct {
@Int32()
external int inputTensorIndex;

@Int32()
external int outputScoreTensorIndex;

@Int32()
external int outputLabelTensorIndex;

external Pointer<Utf8> inputTensorName;

external Pointer<Utf8> outputScoreTensorName;

external Pointer<Utf8> outputLabelTensorName;

static Pointer<TfLiteNLClassifierOptions> allocate(
int inputTensorIndex,
int outputScoreTensorIndex,
int outputLabelTensorIndex,
String inputTensorName,
String outputScoreTensorName,
String outputLabelTensorName,
) {
final result = calloc<TfLiteNLClassifierOptions>();
result.ref
..inputTensorIndex = inputTensorIndex
..outputScoreTensorIndex = outputScoreTensorIndex
..outputLabelTensorIndex = outputLabelTensorIndex
..inputTensorName = inputTensorName.toNativeUtf8()
..outputScoreTensorName = outputScoreTensorName.toNativeUtf8()
..outputLabelTensorName = outputLabelTensorName.toNativeUtf8();
return result;
}
}

class TfLiteCategories extends Struct {
@Int32()
external int size;

external Pointer<TfLiteCategory> categories;
}

class TfLiteCategory extends Struct {
external Pointer<Utf8> text;

@Double()
external double score;
}

class TfLiteBertNLClassifier extends Opaque {}

class TfLiteBertNLClassifierOptions extends Struct {
@Int32()
external int maxSeqLen;
}
47 changes: 47 additions & 0 deletions lib/src/task/bindings/text/qa/bert_qa.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import 'dart:ffi';

import 'package:ffi/ffi.dart';
import 'types.dart';

import 'package:tflite_flutter_helper/src/task/bindings/dlib.dart';

// ignore_for_file: non_constant_identifier_names, camel_case_types

// Creates BertQuestionAnswerer from model path, returns nullptr if the file
// doesn't exist or is not a well formatted TFLite model path.
Pointer<TfLiteBertQuestionAnswerer> Function(Pointer<Utf8> modelPath)
BertQuestionAnswererFromFile = tflitelib
.lookup<NativeFunction<_BertQuestionAnswererFromFile_native_t>>(
'BertQuestionAnswererFromFile')
.asFunction();

typedef _BertQuestionAnswererFromFile_native_t = Pointer<TfLiteBertQuestionAnswerer> Function(
Pointer<Utf8> modelPath);

// Invokes the encapsulated TFLite model and answers a question based on
// context.
Pointer<QaAnswers> Function(Pointer<TfLiteBertQuestionAnswerer> questionAnswerer,
Pointer<Utf8> context, Pointer<Utf8> question)
BertQuestionAnswererAnswer = tflitelib
.lookup<NativeFunction<_BertQuestionAnswererAnswer_native_t>>(
'BertQuestionAnswererAnswer')
.asFunction();

typedef _BertQuestionAnswererAnswer_native_t = Pointer<QaAnswers> Function(Pointer<TfLiteBertQuestionAnswerer> questionAnswerer,
Pointer<Utf8> context, Pointer<Utf8> question);

// Deletes BertQuestionAnswerer instance
void Function(Pointer<TfLiteBertQuestionAnswerer>) BertQuestionAnswererDelete = tflitelib
.lookup<NativeFunction<_BertQuestionAnswererDelete_native_t>>(
'BertQuestionAnswererDelete')
.asFunction();

typedef _BertQuestionAnswererDelete_native_t = Void Function(Pointer<TfLiteBertQuestionAnswerer>);

// Deletes BertQuestionAnswererQaAnswers instance
void Function(Pointer<QaAnswers>) BertQuestionAnswererQaAnswersDelete = tflitelib
.lookup<NativeFunction<_BertQuestionAnswererQaAnswersDelete_native_t>>(
'BertQuestionAnswererQaAnswersDelete')
.asFunction();

typedef _BertQuestionAnswererQaAnswersDelete_native_t = Void Function(Pointer<QaAnswers>);
23 changes: 23 additions & 0 deletions lib/src/task/bindings/text/qa/types.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import 'dart:ffi';

import 'package:ffi/ffi.dart';

class TfLiteBertQuestionAnswerer extends Opaque {}

class QaAnswer extends Struct {
@Int32()
external int start;
@Int32()
external int end;
@Double()
external double logit;

external Pointer<Utf8> text;
}

class QaAnswers extends Struct {
@Int32()
external int size;

Pointer<QaAnswer> answers;
}
Empty file.
37 changes: 37 additions & 0 deletions lib/src/task/text/nl_classifier/nl_classifier.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import 'dart:ffi';
import 'package:ffi/ffi.dart';
import 'package:quiver/check.dart';
import 'package:tflite_flutter_helper/src/label/category.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/nl_classifer.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart';
import 'package:tflite_flutter_helper/src/task/text/nl_classifier/nl_classifier_options.dart';

class NLClassifier {
final Pointer<TfLiteNLClassifier> _classifier;
bool _deleted = false;
Pointer<TfLiteNLClassifier> get base => _classifier;

NLClassifier._(this._classifier);

factory NLClassifier._create(String modelPath, NLClassifierOptions options) {
final classiferPtr =
NLClassifierFromFileAndOptions(modelPath.toNativeUtf8(), options.base);
return NLClassifier._(classiferPtr);
}

List<Category> classify(String text) {
final ref = NLClassifierClassify(base, text.toNativeUtf8()).ref;
var categoryList = List.generate(
ref.size,
(i) => Category(
ref.categories[i].text.toDartString(), ref.categories[i].score),
);
return categoryList;
}

void delete() {
checkState(!_deleted, message: 'NLCLassifier already deleted.');
NLClassifierDelete(base);
_deleted = true;
}
}
64 changes: 64 additions & 0 deletions lib/src/task/text/nl_classifier/nl_classifier_options.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import 'dart:ffi';

import 'package:ffi/ffi.dart';
import 'package:quiver/check.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart';

class NLClassifierOptions {
final Pointer<TfLiteNLClassifierOptions> _options;
bool _deleted = false;

Pointer<TfLiteNLClassifierOptions> get base => _options;
NLClassifierOptions._(this._options);

/// Creates a new options instance.
factory NLClassifierOptions() {
final optionsPtr = calloc<TfLiteNLClassifierOptions>();
return NLClassifierOptions._(optionsPtr);
}

int get inputTensorIndex => base.ref.inputTensorIndex;

set inputTensorIndex(int value) {
base.ref.inputTensorIndex = value;
}

int get outputScoreTensorIndex => base.ref.outputScoreTensorIndex;

set outputScoreTensorIndex(int value) {
base.ref.outputScoreTensorIndex = value;
}

int get outputLabelTensorIndex => base.ref.outputLabelTensorIndex;

set outputLabelTensorIndex(int value) {
base.ref.outputLabelTensorIndex = value;
}

String get inputTensorName => base.ref.inputTensorName.toDartString();

set inputTensorName(String value) {
base.ref.inputTensorName = value.toNativeUtf8();
}

String get outputScoreTensorName =>
base.ref.outputScoreTensorName.toDartString();

set outputScoreTensorName(String value) {
base.ref.outputScoreTensorName = value.toNativeUtf8();
}

String get outputLabelTensorName =>
base.ref.outputLabelTensorName.toDartString();

set outputLabelTensorName(String value) {
base.ref.outputLabelTensorName = value.toNativeUtf8();
}

/// Destroys the options instance.
void delete() {
checkState(!_deleted, message: 'NLClassifierOptions already deleted.');
calloc.free(_options);
_deleted = true;
}
}
Empty file.
2 changes: 1 addition & 1 deletion pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ packages:
source: hosted
version: "1.2.0"
ffi:
dependency: transitive
dependency: "direct main"
description:
name: ffi
url: "https://pub.dartlang.org"
Expand Down
Loading

0 comments on commit a063b1c

Please sign in to comment.