Skip to content

Commit

Permalink
returned some recipe_tests
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <[email protected]>
  • Loading branch information
dafnapension committed Dec 15, 2024
1 parent d66c2dd commit a53f14e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 48 deletions.
14 changes: 7 additions & 7 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def set_pipelines(self):

self.inference_instance.steps = [
self.metadata,
self.processing,
# self.processing,
]

self.inference_demos = SourceSequentialOperator()
Expand All @@ -301,7 +301,7 @@ def set_pipelines(self):

self.inference = SequentialOperator()

self.inference.steps = [self.metadata, self.verbalization, self.finalize]
self.inference.steps = [self.verbalization, self.finalize]

def production_preprocess(self, task_instances):
ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
Expand All @@ -320,19 +320,19 @@ def produce(self, task_instances):
self.before_process_multi_stream()

ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
ms = self.inference_instance(ms)
if not self.use_demos:
# go with task_instances all the way, it does not need other streams:
ms = self.inference_instance(ms)
ms = self.processing(ms)
ms = self.verbalization(ms)
ms = self.finalize(ms)
return list(ms[constants.inference_stream])

ms = self.metadata(ms)
# ready to join, as if passed loading and standardization

streams = self.inference_demos()
# stopped before processing
# streams stopped before processing
# ms is ready to join, as if passed loading and standardization
streams[constants.inference_stream] = ms[constants.inference_stream]
multi_stream = MultiStream(streams)
multi_stream = self.processing(streams)
multi_stream = self.inference(multi_stream)
return list(multi_stream[constants.inference_stream])
Expand Down
69 changes: 28 additions & 41 deletions tests/library/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unitxt.templates import InputOutputTemplate, TemplatesList
from unitxt.text_utils import print_dict
from unitxt.types import Table
from unitxt.utils import recursive_copy

from tests.utils import UnitxtTestCase

Expand Down Expand Up @@ -106,47 +107,33 @@ def test_standard_recipe_production_without_demos(self):

self.assertDictEqual(result, target)

# def test_standard_recipe_production_consistency(self):
# recipe = StandardRecipe(
# card="cards.mmlu.marketing",
# system_prompt="system_prompts.models.llama",
# template="templates.qa.multiple_choice.with_topic.lm_eval_harness",
# format="formats.user_agent",
# demos_pool_size=5,
# num_demos=1,
# )

# instances = [
# {
# "question": "what?",
# "choices": ["yes", "not", "maybe"],
# "answer": "maybe",
# "topic": "testing",
# }
# ]

# self.assertListEqual(
# recipe.production_demos_pool(), recipe.production_demos_pool()
# )

# self.assertDictEqual(
# recipe.produce(instances)[0],
# recipe.produce(instances)[0],
# )

# i1 = recipe.production_preprocess(instances)[0]
# i2 = recipe.production_preprocess(instances)[0]
# for meta_data in ["card", "template", "format", "system_prompt"]:
# if meta_data in i1["recipe_metadata"]:
# i1["recipe_metadata"][meta_data] = i1["recipe_metadata"][
# meta_data
# ]._to_raw_dict()
# if not isinstance(i2["recipe_metadata"][meta_data], dict):
# i2["recipe_metadata"][meta_data] = i2["recipe_metadata"][
# meta_data
# ]._to_raw_dict()

# self.assertDictEqual(i1, i2)
def test_standard_recipe_production_consistency(self):
recipe = StandardRecipe(
card="cards.mmlu.marketing",
system_prompt="system_prompts.models.llama",
template="templates.qa.multiple_choice.with_topic.lm_eval_harness",
format="formats.user_agent",
demos_pool_size=5,
num_demos=1,
)

instances = [
{
"question": "what?",
"choices": ["yes", "not", "maybe"],
"answer": "maybe",
"topic": "testing",
}
]

self.assertDictEqual(
recipe.produce(recursive_copy(instances))[0],
recipe.produce(recursive_copy(instances))[0],
)

i1 = recipe.production_preprocess(recursive_copy(instances))[0]
i2 = recipe.production_preprocess(recursive_copy(instances))[0]
self.assertDictEqual(i1, i2)

def test_standard_recipe_production_with_demos(self):
recipe = StandardRecipe(
Expand Down

0 comments on commit a53f14e

Please sign in to comment.