forked from am15h/tflite_flutter_helper
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add documentation for question answerer
- Loading branch information
Showing
7 changed files
with
200 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import 'dart:ffi'; | ||
import 'dart:io'; | ||
|
||
import 'package:quiver/check.dart'; | ||
import 'package:tflite_flutter_helper/src/common/file_util.dart'; | ||
import 'package:tflite_flutter_helper/src/task/bindings/text/qa/bert_qa.dart'; | ||
import 'package:tflite_flutter_helper/src/task/bindings/text/qa/types.dart'; | ||
|
||
import 'package:ffi/ffi.dart'; | ||
import 'package:tflite_flutter_helper/src/task/text/qa/question_answerer.dart'; | ||
|
||
import 'qa_answer.dart'; | ||
|
||
/// Task API for BertQA models. */ | ||
class BertQuestionAnswerer implements QuestionAnswerer { | ||
final Pointer<TfLiteBertQuestionAnswerer> _classifier; | ||
bool _deleted = false; | ||
Pointer<TfLiteBertQuestionAnswerer> get base => _classifier; | ||
|
||
BertQuestionAnswerer._(this._classifier); | ||
|
||
/// Generic API to create the QuestionAnswerer for bert models with metadata populated. The API | ||
/// expects a Bert based TFLite model with metadata containing the following information: | ||
/// | ||
/// <ul> | ||
/// <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be | ||
/// used for a <a | ||
/// href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> | ||
/// model, Sentencepiece Tokenizer Tokenizer can be used for an <a | ||
/// href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> | ||
/// model. | ||
/// <li>3 input tensors with names "ids", "mask" and "segment_ids". | ||
/// <li>2 output tensors with names "end_logits" and "start_logits". | ||
/// </ul> | ||
/// | ||
/// Creates [BertQuestionAnswerer] from [modelPath]. | ||
/// | ||
/// [modelPath] is the path of the .tflite model loaded on device. | ||
/// | ||
/// throws [FileSystemException] If model file fails to load. | ||
static BertQuestionAnswerer create(String modelPath) { | ||
final nativePtr = BertQuestionAnswererFromFile(modelPath.toNativeUtf8()); | ||
if (nativePtr == nullptr) { | ||
throw FileSystemException( | ||
"Failed to create BertQuestionAnswerer.", modelPath); | ||
} | ||
return BertQuestionAnswerer._(nativePtr); | ||
} | ||
|
||
/// Generic API to create the QuestionAnswerer for bert models with metadata populated. The API | ||
/// expects a Bert based TFLite model with metadata containing the following information: | ||
/// | ||
/// <ul> | ||
/// <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be | ||
/// used for a <a | ||
/// href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> | ||
/// model, Sentencepiece Tokenizer Tokenizer can be used for an <a | ||
/// href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> | ||
/// model. | ||
/// <li>3 input tensors with names "ids", "mask" and "segment_ids". | ||
/// <li>2 output tensors with names "end_logits" and "start_logits". | ||
/// </ul> | ||
/// | ||
/// Create [BertQuestionAnswerer] from [modelFile]. | ||
/// | ||
/// throws [FileSystemException] If model file fails to load. | ||
static BertQuestionAnswerer createFromFile(File modelFile) { | ||
return create(modelFile.path); | ||
} | ||
|
||
/// Create [BertQuestionAnswerer] directly from [assetPath]. | ||
/// | ||
/// [assetPath] must the full path to assets. Eg. 'assets/my_model.tflite'. | ||
/// | ||
/// throws [FileSystemException] If model file fails to load. | ||
static Future<BertQuestionAnswerer> createFromAsset(String assetPath) async { | ||
final modelFile = await FileUtil.loadFileOnDevice(assetPath); | ||
return create(modelFile.path); | ||
} | ||
|
||
@override | ||
List<QaAnswer> answer(String context, String question) { | ||
final ref = BertQuestionAnswererAnswer( | ||
base, context.toNativeUtf8(), question.toNativeUtf8()) | ||
.ref; | ||
final qaList = List.generate( | ||
ref.size, | ||
(i) => QaAnswer( | ||
Pos(ref.answers[i].start, ref.answers[i].end, ref.answers[i].logit), | ||
ref.answers[i].text.toDartString(), | ||
), | ||
); | ||
return qaList; | ||
} | ||
|
||
/// Deletes BertQuestionAnswerer native instance. | ||
void delete() { | ||
checkState(!_deleted, message: 'NLCLassifier already deleted.'); | ||
BertQuestionAnswererDelete(base); | ||
_deleted = true; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/// Answers to [QuestionAnswerer]. Contains information about the answer and its relative | ||
/// position information to the context. | ||
class QaAnswer { | ||
Pos pos; | ||
String text; | ||
|
||
QaAnswer(this.pos, this.text); | ||
} | ||
|
||
/// Position information of the answer relative to context. It is sortable in descending order | ||
/// based on logit. | ||
class Pos implements Comparable<Pos> { | ||
int start; | ||
int end; | ||
double logit; | ||
|
||
Pos(this.start, this.end, this.logit); | ||
|
||
@override | ||
int compareTo(Pos other) { | ||
return other.logit.compareTo(this.logit); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import 'package:tflite_flutter_helper/src/task/text/qa/qa_answer.dart'; | ||
|
||
/// API to answer questions based on context. */ | ||
abstract class QuestionAnswerer { | ||
/// Answers [question] based on [context], and returns a list of possible [QaAnswer]s. Could be | ||
/// empty if no answer was found from the given context. | ||
List<QaAnswer> answer(String context, String question); | ||
} |