From a53f14e05f4083e9ae011d1755e687d3ad15ad34 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sun, 15 Dec 2024 08:45:17 +0200 Subject: [PATCH] returned some recipe_tests Signed-off-by: dafnapension --- src/unitxt/standard.py | 14 ++++---- tests/library/test_recipe.py | 69 +++++++++++++++--------------------- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index c0c959e39..f8deb0880 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -287,7 +287,7 @@ def set_pipelines(self): self.inference_instance.steps = [ self.metadata, - self.processing, + # self.processing, ] self.inference_demos = SourceSequentialOperator() @@ -301,7 +301,7 @@ def set_pipelines(self): self.inference = SequentialOperator() - self.inference.steps = [self.metadata, self.verbalization, self.finalize] + self.inference.steps = [self.verbalization, self.finalize] def production_preprocess(self, task_instances): ms = MultiStream.from_iterables({constants.inference_stream: task_instances}) @@ -320,19 +320,19 @@ def produce(self, task_instances): self.before_process_multi_stream() ms = MultiStream.from_iterables({constants.inference_stream: task_instances}) + ms = self.inference_instance(ms) if not self.use_demos: # go with task_instances all the way, it does not need other streams: - ms = self.inference_instance(ms) + ms = self.processing(ms) ms = self.verbalization(ms) ms = self.finalize(ms) return list(ms[constants.inference_stream]) - ms = self.metadata(ms) - # ready to join, as if passed loading and standardization - streams = self.inference_demos() - # stopped before processing + # streams stopped before processing + # ms is ready to join, as if passed loading and standardization streams[constants.inference_stream] = ms[constants.inference_stream] + multi_stream = MultiStream(streams) multi_stream = self.processing(streams) multi_stream = self.inference(multi_stream) return list(multi_stream[constants.inference_stream]) diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index 757a5378c..8282b7295 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -15,6 +15,7 @@ from unitxt.templates import InputOutputTemplate, TemplatesList from unitxt.text_utils import print_dict from unitxt.types import Table +from unitxt.utils import recursive_copy from tests.utils import UnitxtTestCase @@ -106,47 +107,33 @@ def test_standard_recipe_production_without_demos(self): self.assertDictEqual(result, target) - # def test_standard_recipe_production_consistency(self): - # recipe = StandardRecipe( - # card="cards.mmlu.marketing", - # system_prompt="system_prompts.models.llama", - # template="templates.qa.multiple_choice.with_topic.lm_eval_harness", - # format="formats.user_agent", - # demos_pool_size=5, - # num_demos=1, - # ) - - # instances = [ - # { - # "question": "what?", - # "choices": ["yes", "not", "maybe"], - # "answer": "maybe", - # "topic": "testing", - # } - # ] - - # self.assertListEqual( - # recipe.production_demos_pool(), recipe.production_demos_pool() - # ) - - # self.assertDictEqual( - # recipe.produce(instances)[0], - # recipe.produce(instances)[0], - # ) - - # i1 = recipe.production_preprocess(instances)[0] - # i2 = recipe.production_preprocess(instances)[0] - # for meta_data in ["card", "template", "format", "system_prompt"]: - # if meta_data in i1["recipe_metadata"]: - # i1["recipe_metadata"][meta_data] = i1["recipe_metadata"][ - # meta_data - # ]._to_raw_dict() - # if not isinstance(i2["recipe_metadata"][meta_data], dict): - # i2["recipe_metadata"][meta_data] = i2["recipe_metadata"][ - # meta_data - # ]._to_raw_dict() - - # self.assertDictEqual(i1, i2) + def test_standard_recipe_production_consistency(self): + recipe = StandardRecipe( + card="cards.mmlu.marketing", + system_prompt="system_prompts.models.llama", + template="templates.qa.multiple_choice.with_topic.lm_eval_harness", + format="formats.user_agent", + demos_pool_size=5, + num_demos=1, + ) + + instances = [ + { + "question": "what?", + "choices": ["yes", "not", "maybe"], + "answer": "maybe", + "topic": "testing", + } + ] + + self.assertDictEqual( + recipe.produce(recursive_copy(instances))[0], + recipe.produce(recursive_copy(instances))[0], + ) + + i1 = recipe.production_preprocess(recursive_copy(instances))[0] + i2 = recipe.production_preprocess(recursive_copy(instances))[0] + self.assertDictEqual(i1, i2) def test_standard_recipe_production_with_demos(self): recipe = StandardRecipe(