diff --git a/optimum_benchmark/generators/task_generator.py b/optimum_benchmark/generators/task_generator.py index 683d8963..f0b8c984 100644 --- a/optimum_benchmark/generators/task_generator.py +++ b/optimum_benchmark/generators/task_generator.py @@ -165,6 +165,25 @@ def __call__(self): return dummy +class Text2TextGenerationGenerator(TextGenerator): + def __call__(self): + dummy = {} + dummy["input_ids"] = self.input_ids() + dummy["decoder_input_ids"] = self.input_ids() + dummy["attention_mask"] = self.attention_mask() + + if self.requires_token_type_ids(): + dummy["token_type_ids"] = self.token_type_ids() + + if self.requires_position_ids(): + dummy["position_ids"] = self.position_ids() + + if self.with_labels: + dummy["labels"] = self.input_ids() + + return dummy + + class QuestionAnsweringGenerator(TextGenerator): def start_positions(self): return self.generate_random_integers( @@ -369,7 +388,7 @@ def __call__(self): "text-classification": TextClassificationGenerator, "token-classification": TokenClassificationGenerator, "text-generation": TextGenerationGenerator, - "text2text-generation": TextGenerationGenerator, + "text2text-generation": Text2TextGenerationGenerator, "question-answering": QuestionAnsweringGenerator, "fill-mask": MaskedLanguageModelingGenerator, "multiple-choice": MultipleChoiceGenerator,