diff --git a/src/unitxt/dataset_utils.py b/src/unitxt/dataset_utils.py index b32fe86fc..c837b37f0 100644 --- a/src/unitxt/dataset_utils.py +++ b/src/unitxt/dataset_utils.py @@ -5,7 +5,7 @@ from .parsing_utils import parse_key_equals_value_string_to_dict from .register import _reset_env_local_catalogs, register_all_artifacts from .settings_utils import get_settings -from .standard import BaseRecipe +from .standard import DatasetRecipe logger = get_logger() settings = get_settings() @@ -24,7 +24,7 @@ def parse(query: str): def get_dataset_artifact(dataset): - if isinstance(dataset, BaseRecipe): + if isinstance(dataset, DatasetRecipe): return dataset assert isinstance( dataset, str diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index eadd4ee40..a98793719 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -136,7 +136,7 @@ def from_stream_generator( return set_demos_pool(ms) -class BaseRecipe(SourceSequentialOperator): +class DatasetRecipe(SourceSequentialOperator): """This class represents a standard recipe for data processing and preparation. This class can be used to prepare a recipe. @@ -676,11 +676,3 @@ def prepare(self): if isinstance(self.template, TemplatesList): self.template = self.template.items self.reset_pipeline() - - -class StandardRecipeWithIndexes(BaseRecipe): - pass - - -class DatasetRecipe(StandardRecipeWithIndexes): - pass diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index 3c4440daa..3fb39003f 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -10,7 +10,7 @@ from unitxt.formats import SystemFormat from unitxt.loaders import LoadFromDictionary from unitxt.serializers import SingleTypeSerializer, TableSerializer -from unitxt.standard import DatasetRecipe, StandardRecipeWithIndexes +from unitxt.standard import DatasetRecipe from unitxt.task import Task from unitxt.templates import InputOutputTemplate, TemplatesList from unitxt.text_utils import print_dict @@ -293,7 +293,7 @@ def test_dataset_recipe_with_demoed_instances(self): processed_input_instance["source"], ) - def test_dataset_recipe_with_indexes_with_catalog(self): + def test_dataset_recipe_with_catalog_wnli(self): recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", @@ -325,7 +325,7 @@ def test_dataset_recipe_with_demos_not_removed_from_data(self): "demos_pool_size" ] - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", template_card_index=0, demos_pool_size=100, @@ -346,7 +346,7 @@ def test_dataset_recipe_with_demos_not_removed_from_data(self): self.assertEqual(n_demos_keep_demos, n_demos_remove_demos) def test_empty_template(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.empty", @@ -378,7 +378,7 @@ def test_empty_template(self): self.assertDictEqual(target_task_data, result_task_data) def test_key_val_template(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.key_val", @@ -456,7 +456,7 @@ def test_key_val_template(self): self.assertDictEqual(target_task_data, result_task_data) def test_random_template(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template=[ @@ -497,7 +497,7 @@ def test_random_template_with_templates_list(self): "templates.classification.multi_class.relation.truthfulness.flan_5", ] ) - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template=templates, @@ -529,7 +529,7 @@ def test_random_template_with_templates_list(self): self.assertDictEqual(result, target) def test_random_num_demos(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.key_val", @@ -548,7 +548,7 @@ def test_random_num_demos(self): self.assertEqual(len(lengths), 4) def test_dataset_recipe_with_balancer(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.key_val", @@ -566,7 +566,7 @@ def test_dataset_recipe_with_balancer(self): self.assertEqual(counts["entailment"], counts["not entailment"]) def test_dataset_recipe_with_loader_limit(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.key_val", @@ -584,7 +584,7 @@ def test_dataset_recipe_with_loader_limit(self): def test_dataset_recipe_with_loader_limit_errors(self): with self.assertRaises(ValueError): - StandardRecipeWithIndexes( + DatasetRecipe( card="cards.wnli", template="templates.key_val", max_test_instances=10, @@ -592,14 +592,14 @@ def test_dataset_recipe_with_loader_limit_errors(self): ) with self.assertRaises(ValueError): - StandardRecipeWithIndexes( + DatasetRecipe( card="cards.wnli", template="templates.key_val", max_train_instances=10, loader_limit=9, ) with self.assertRaises(ValueError): - StandardRecipeWithIndexes( + DatasetRecipe( template="templates.key_val", card="cards.wnli", max_validation_instances=10, @@ -607,7 +607,7 @@ def test_dataset_recipe_with_loader_limit_errors(self): ) with self.assertRaises(ValueError): - StandardRecipeWithIndexes( + DatasetRecipe( template="templates.key_val", card="cards.wnli", num_demos=3, @@ -616,7 +616,7 @@ def test_dataset_recipe_with_loader_limit_errors(self): ) def test_dataset_recipe_with_no_demos_to_take(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( template="templates.key_val", card="cards.xwinogrande.pt", num_demos=3, @@ -632,7 +632,7 @@ def test_dataset_recipe_with_no_demos_to_take(self): ) with self.assertRaises(Exception) as cm: - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( template="templates.key_val", card="cards.xwinogrande.pt", num_demos=3, @@ -645,7 +645,7 @@ def test_dataset_recipe_with_no_demos_to_take(self): ) with self.assertRaises(Exception) as cm: - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( template="templates.key_val", card="cards.xwinogrande.pt", num_demos=30, @@ -658,7 +658,7 @@ def test_dataset_recipe_with_no_demos_to_take(self): ) def test_dataset_recipe_with_no_test(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( template="templates.key_val", card="cards.xwinogrande.pt", num_demos=3, @@ -671,7 +671,7 @@ def test_dataset_recipe_with_no_test(self): def test_dataset_recipe_with_template_errors(self): # Check either template or template index was specified , but not both with self.assertRaises(AssertionError) as cm: - StandardRecipeWithIndexes( + DatasetRecipe( card="cards.wnli", template="templates.key_val", template_card_index=100 ) self.assertTrue( @@ -684,7 +684,7 @@ def test_dataset_recipe_with_template_errors(self): # Also check if string index is used with self.assertRaises(AssertionError) as cm: - StandardRecipeWithIndexes( + DatasetRecipe( card="cards.wnli", template="templates.key_val", template_card_index="illegal_template", @@ -699,17 +699,15 @@ def test_dataset_recipe_with_template_errors(self): # Return an error if index is not found in card with self.assertRaises(ValueError) as cm: - StandardRecipeWithIndexes( - card="cards.wnli", template_card_index="illegal_template" - ) + DatasetRecipe(card="cards.wnli", template_card_index="illegal_template") self.assertTrue("not defined in card." in str(cm.exception)) with self.assertRaises(ValueError) as cm: - StandardRecipeWithIndexes(card="cards.wnli", template_card_index=100) + DatasetRecipe(card="cards.wnli", template_card_index=100) self.assertTrue("not defined in card." in str(cm.exception)) def test_dataset_recipe_with_balancer_and_size_limit(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.key_val", @@ -729,7 +727,7 @@ def test_dataset_recipe_with_balancer_and_size_limit(self): self.assertEqual(counts["entailment"], counts["not entailment"], 10) def test_dataset_recipe_with_augmentor_on_task_input(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.sst2", augmentor="augmentors.text.white_space", template_card_index=0, @@ -754,7 +752,7 @@ def test_dataset_recipe_with_augmentor_on_task_input(self): ), f"{normalized_output_source} is not equal to f{normalized_input_source}" def test_dataset_recipe_with_train_size_limit(self): - recipe = StandardRecipeWithIndexes( + recipe = DatasetRecipe( card="cards.wnli", system_prompt="system_prompts.models.llama", template="templates.key_val", @@ -775,7 +773,7 @@ def test_recipe_with_hf_with_twice_the_same_instance_demos(self): d = load_dataset( dataset_file, - "__type__=dataset_recipe_with_indexes,card=cards.wnli,template=templates.classification.multi_class.relation.default,system_prompt=system_prompts.models.llama,demos_pool_size=5,num_demos=1", + "__type__=dataset_recipe,card=cards.wnli,template=templates.classification.multi_class.relation.default,system_prompt=system_prompts.models.llama,demos_pool_size=5,num_demos=1", streaming=True, trust_remote_code=True, ) @@ -800,7 +798,7 @@ def test_dataset_recipe_with_a_missing_sampler(self): task_card, _ = copy.deepcopy(fetch_artifact("cards.sst2")) task_card.sampler = None with self.assertRaises(ValueError) as e: - StandardRecipeWithIndexes( + DatasetRecipe( card=task_card, template_card_index=0, max_train_instances=0,