Skip to content

Commit

Permalink
add documentation for question answerer
Browse files Browse the repository at this point in the history
  • Loading branch information
am15h committed Aug 19, 2021
1 parent 3a4481a commit 0052afa
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 62 deletions.
65 changes: 59 additions & 6 deletions lib/src/task/text/nl_classifier/bert_nl_classifier.dart
Original file line number Diff line number Diff line change
@@ -1,38 +1,91 @@
import 'dart:ffi';
import 'dart:io';
import 'package:ffi/ffi.dart';
import 'package:quiver/check.dart';
import 'package:tflite_flutter_helper/src/common/file_util.dart';
import 'package:tflite_flutter_helper/src/label/category.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/bert_nl_classifier.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart';
import 'package:tflite_flutter_helper/src/task/text/nl_classifier/bert_nl_classifier_options.dart';

/// Classifier API for NLClassification tasks with Bert models, categorizes string into different
/// classes. The API expects a Bert based TFLite model with metadata populated.
///
/// <p>The metadata should contain the following information:
///
/// <ul>
/// <li>1 input_process_unit for Wordpiece/Sentencepiece Tokenizer.
/// <li>3 input tensors with names "ids", "mask" and "segment_ids".
/// <li>1 output tensor of type float32[1, 2], with a optionally attached label file. If a label
/// file is attached, the file should be a plain text file with one label per line, the number
/// of labels should match the number of categories the model outputs.
/// </ul>
class BertNLClassifier {
final Pointer<TfLiteBertNLClassifier> _classifier;
bool _deleted = false;
Pointer<TfLiteBertNLClassifier> get base => _classifier;

BertNLClassifier._(this._classifier);

factory BertNLClassifier.create(String modelPath,
/// Create [BertNLClassifier] from [modelPath] and optional [options].
///
/// [modelPath] is the path of the .tflite model loaded on device.
///
/// throws [FileSystemException] If model file fails to load.
static BertNLClassifier create(String modelPath,
{BertNLClassifierOptions? options}) {
if(options == null) {
if (options == null) {
options = BertNLClassifierOptions();
}
final classiferPtr =
BertNLClassifierFromFileAndOptions(modelPath.toNativeUtf8(), options.base);
return BertNLClassifier._(classiferPtr);
final nativePtr = BertNLClassifierFromFileAndOptions(
modelPath.toNativeUtf8(), options.base);
if (nativePtr == nullptr) {
throw FileSystemException(
"Failed to create BertNLClassifier.", modelPath);
}
return BertNLClassifier._(nativePtr);
}

/// Create [BertNLClassifier] from [modelFile].
///
/// throws [FileSystemException] If model file fails to load.
static BertNLClassifier createFromFile(File modelFile) {
return create(modelFile.path);
}

/// Create [BertNLClassifier] from [modelFile] and [options].
///
/// throws [FileSystemException] If model file fails to load.
static BertNLClassifier createFromFileAndOptions(
File modelFile, BertNLClassifierOptions options) {
return create(modelFile.path, options: options);
}

/// Create [BertNLClassifier] directly from [assetPath] and optional [options].
///
/// [assetPath] must the full path to assets. Eg. 'assets/my_model.tflite'.
///
/// throws [FileSystemException] If model file fails to load.
static Future<BertNLClassifier> createFromAsset(String assetPath,
{BertNLClassifierOptions? options}) async {
final modelFile = await FileUtil.loadFileOnDevice(assetPath);
return create(modelFile.path, options: options);
}

/// Perform classification on a string input [text],
///
/// Returns classified [Category]s as List.
List<Category> classify(String text) {
final ref = BertNLClassifierClassify(base, text.toNativeUtf8()).ref;
final categoryList = List.generate(
ref.size,
(i) => Category(
(i) => Category(
ref.categories[i].text.toDartString(), ref.categories[i].score),
);
return categoryList;
}

/// Deletes BertNLClassifier Instance.
void delete() {
checkState(!_deleted, message: 'BertNLClassifier already deleted.');
BertNLClassifierDelete(base);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ import 'package:ffi/ffi.dart';
import 'package:quiver/check.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/nl_classifier/types.dart';

/// Options to configure BertNLClassifier.
class BertNLClassifierOptions {
final Pointer<TfLiteBertNLClassifierOptions> _options;
bool _deleted = false;

Pointer<TfLiteBertNLClassifierOptions> get base => _options;

BertNLClassifierOptions._(this._options);

static const int DEFAULT_MAX_SEQ_LEN = 128;

/// Creates a new options instance.
factory BertNLClassifierOptions() {
final optionsPtr = TfLiteBertNLClassifierOptions.allocate(0);
final optionsPtr = TfLiteBertNLClassifierOptions.allocate(DEFAULT_MAX_SEQ_LEN);
return BertNLClassifierOptions._(optionsPtr);
}

Expand Down
6 changes: 3 additions & 3 deletions lib/src/task/text/nl_classifier/nl_classifier.dart
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ class NLClassifier {
if (options == null) {
options = NLClassifierOptions();
}
final classiferPtr =
final nativePtr =
NLClassifierFromFileAndOptions(modelPath.toNativeUtf8(), options.base);
if (classiferPtr == nullptr) {
if (nativePtr == nullptr) {
throw FileSystemException("Failed to create NLClassifier.", modelPath);
}
return NLClassifier._(classiferPtr);
return NLClassifier._(nativePtr);
}

/// Create [NLClassifier] from [modelFile].
Expand Down
51 changes: 0 additions & 51 deletions lib/src/task/text/qa/bert_qa.dart

This file was deleted.

102 changes: 102 additions & 0 deletions lib/src/task/text/qa/bert_question_answerer.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import 'dart:ffi';
import 'dart:io';

import 'package:quiver/check.dart';
import 'package:tflite_flutter_helper/src/common/file_util.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/qa/bert_qa.dart';
import 'package:tflite_flutter_helper/src/task/bindings/text/qa/types.dart';

import 'package:ffi/ffi.dart';
import 'package:tflite_flutter_helper/src/task/text/qa/question_answerer.dart';

import 'qa_answer.dart';

/// Task API for BertQA models. */
class BertQuestionAnswerer implements QuestionAnswerer {
final Pointer<TfLiteBertQuestionAnswerer> _classifier;
bool _deleted = false;
Pointer<TfLiteBertQuestionAnswerer> get base => _classifier;

BertQuestionAnswerer._(this._classifier);

/// Generic API to create the QuestionAnswerer for bert models with metadata populated. The API
/// expects a Bert based TFLite model with metadata containing the following information:
///
/// <ul>
/// <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
/// used for a <a
/// href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
/// model, Sentencepiece Tokenizer Tokenizer can be used for an <a
/// href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
/// model.
/// <li>3 input tensors with names "ids", "mask" and "segment_ids".
/// <li>2 output tensors with names "end_logits" and "start_logits".
/// </ul>
///
/// Creates [BertQuestionAnswerer] from [modelPath].
///
/// [modelPath] is the path of the .tflite model loaded on device.
///
/// throws [FileSystemException] If model file fails to load.
static BertQuestionAnswerer create(String modelPath) {
final nativePtr = BertQuestionAnswererFromFile(modelPath.toNativeUtf8());
if (nativePtr == nullptr) {
throw FileSystemException(
"Failed to create BertQuestionAnswerer.", modelPath);
}
return BertQuestionAnswerer._(nativePtr);
}

/// Generic API to create the QuestionAnswerer for bert models with metadata populated. The API
/// expects a Bert based TFLite model with metadata containing the following information:
///
/// <ul>
/// <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
/// used for a <a
/// href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
/// model, Sentencepiece Tokenizer Tokenizer can be used for an <a
/// href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
/// model.
/// <li>3 input tensors with names "ids", "mask" and "segment_ids".
/// <li>2 output tensors with names "end_logits" and "start_logits".
/// </ul>
///
/// Create [BertQuestionAnswerer] from [modelFile].
///
/// throws [FileSystemException] If model file fails to load.
static BertQuestionAnswerer createFromFile(File modelFile) {
return create(modelFile.path);
}

/// Create [BertQuestionAnswerer] directly from [assetPath].
///
/// [assetPath] must the full path to assets. Eg. 'assets/my_model.tflite'.
///
/// throws [FileSystemException] If model file fails to load.
static Future<BertQuestionAnswerer> createFromAsset(String assetPath) async {
final modelFile = await FileUtil.loadFileOnDevice(assetPath);
return create(modelFile.path);
}

@override
List<QaAnswer> answer(String context, String question) {
final ref = BertQuestionAnswererAnswer(
base, context.toNativeUtf8(), question.toNativeUtf8())
.ref;
final qaList = List.generate(
ref.size,
(i) => QaAnswer(
Pos(ref.answers[i].start, ref.answers[i].end, ref.answers[i].logit),
ref.answers[i].text.toDartString(),
),
);
return qaList;
}

/// Deletes BertQuestionAnswerer native instance.
void delete() {
checkState(!_deleted, message: 'NLCLassifier already deleted.');
BertQuestionAnswererDelete(base);
_deleted = true;
}
}
23 changes: 23 additions & 0 deletions lib/src/task/text/qa/qa_answer.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/// Answers to [QuestionAnswerer]. Contains information about the answer and its relative
/// position information to the context.
class QaAnswer {
Pos pos;
String text;

QaAnswer(this.pos, this.text);
}

/// Position information of the answer relative to context. It is sortable in descending order
/// based on logit.
class Pos implements Comparable<Pos> {
int start;
int end;
double logit;

Pos(this.start, this.end, this.logit);

@override
int compareTo(Pos other) {
return other.logit.compareTo(this.logit);
}
}
8 changes: 8 additions & 0 deletions lib/src/task/text/qa/question_answerer.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import 'package:tflite_flutter_helper/src/task/text/qa/qa_answer.dart';

/// API to answer questions based on context. */
abstract class QuestionAnswerer {
/// Answers [question] based on [context], and returns a list of possible [QaAnswer]s. Could be
/// empty if no answer was found from the given context.
List<QaAnswer> answer(String context, String question);
}

0 comments on commit 0052afa

Please sign in to comment.