Skip to content

Commit

Permalink
all recipes are DatasetRecipe
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <[email protected]>
  • Loading branch information
dafnapension committed Dec 20, 2024
1 parent 5284d1e commit 934c9a4
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 40 deletions.
4 changes: 2 additions & 2 deletions src/unitxt/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .parsing_utils import parse_key_equals_value_string_to_dict
from .register import _reset_env_local_catalogs, register_all_artifacts
from .settings_utils import get_settings
from .standard import BaseRecipe
from .standard import DatasetRecipe

logger = get_logger()
settings = get_settings()
Expand All @@ -24,7 +24,7 @@ def parse(query: str):


def get_dataset_artifact(dataset):
if isinstance(dataset, BaseRecipe):
if isinstance(dataset, DatasetRecipe):
return dataset
assert isinstance(
dataset, str
Expand Down
10 changes: 1 addition & 9 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def from_stream_generator(
return set_demos_pool(ms)


class BaseRecipe(SourceSequentialOperator):
class DatasetRecipe(SourceSequentialOperator):
"""This class represents a standard recipe for data processing and preparation.
This class can be used to prepare a recipe.
Expand Down Expand Up @@ -676,11 +676,3 @@ def prepare(self):
if isinstance(self.template, TemplatesList):
self.template = self.template.items
self.reset_pipeline()


class StandardRecipeWithIndexes(BaseRecipe):
pass


class DatasetRecipe(StandardRecipeWithIndexes):
pass
56 changes: 27 additions & 29 deletions tests/library/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from unitxt.formats import SystemFormat
from unitxt.loaders import LoadFromDictionary
from unitxt.serializers import SingleTypeSerializer, TableSerializer
from unitxt.standard import DatasetRecipe, StandardRecipeWithIndexes
from unitxt.standard import DatasetRecipe
from unitxt.task import Task
from unitxt.templates import InputOutputTemplate, TemplatesList
from unitxt.text_utils import print_dict
Expand Down Expand Up @@ -293,7 +293,7 @@ def test_dataset_recipe_with_demoed_instances(self):
processed_input_instance["source"],
)

def test_dataset_recipe_with_indexes_with_catalog(self):
def test_dataset_recipe_with_catalog_wnli(self):
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_dataset_recipe_with_demos_not_removed_from_data(self):
"demos_pool_size"
]

recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
template_card_index=0,
demos_pool_size=100,
Expand All @@ -346,7 +346,7 @@ def test_dataset_recipe_with_demos_not_removed_from_data(self):
self.assertEqual(n_demos_keep_demos, n_demos_remove_demos)

def test_empty_template(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.empty",
Expand Down Expand Up @@ -378,7 +378,7 @@ def test_empty_template(self):
self.assertDictEqual(target_task_data, result_task_data)

def test_key_val_template(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.key_val",
Expand Down Expand Up @@ -456,7 +456,7 @@ def test_key_val_template(self):
self.assertDictEqual(target_task_data, result_task_data)

def test_random_template(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template=[
Expand Down Expand Up @@ -497,7 +497,7 @@ def test_random_template_with_templates_list(self):
"templates.classification.multi_class.relation.truthfulness.flan_5",
]
)
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template=templates,
Expand Down Expand Up @@ -529,7 +529,7 @@ def test_random_template_with_templates_list(self):
self.assertDictEqual(result, target)

def test_random_num_demos(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.key_val",
Expand All @@ -548,7 +548,7 @@ def test_random_num_demos(self):
self.assertEqual(len(lengths), 4)

def test_dataset_recipe_with_balancer(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.key_val",
Expand All @@ -566,7 +566,7 @@ def test_dataset_recipe_with_balancer(self):
self.assertEqual(counts["entailment"], counts["not entailment"])

def test_dataset_recipe_with_loader_limit(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.key_val",
Expand All @@ -584,30 +584,30 @@ def test_dataset_recipe_with_loader_limit(self):

def test_dataset_recipe_with_loader_limit_errors(self):
with self.assertRaises(ValueError):
StandardRecipeWithIndexes(
DatasetRecipe(
card="cards.wnli",
template="templates.key_val",
max_test_instances=10,
loader_limit=9,
)

with self.assertRaises(ValueError):
StandardRecipeWithIndexes(
DatasetRecipe(
card="cards.wnli",
template="templates.key_val",
max_train_instances=10,
loader_limit=9,
)
with self.assertRaises(ValueError):
StandardRecipeWithIndexes(
DatasetRecipe(
template="templates.key_val",
card="cards.wnli",
max_validation_instances=10,
loader_limit=9,
)

with self.assertRaises(ValueError):
StandardRecipeWithIndexes(
DatasetRecipe(
template="templates.key_val",
card="cards.wnli",
num_demos=3,
Expand All @@ -616,7 +616,7 @@ def test_dataset_recipe_with_loader_limit_errors(self):
)

def test_dataset_recipe_with_no_demos_to_take(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
template="templates.key_val",
card="cards.xwinogrande.pt",
num_demos=3,
Expand All @@ -632,7 +632,7 @@ def test_dataset_recipe_with_no_demos_to_take(self):
)

with self.assertRaises(Exception) as cm:
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
template="templates.key_val",
card="cards.xwinogrande.pt",
num_demos=3,
Expand All @@ -645,7 +645,7 @@ def test_dataset_recipe_with_no_demos_to_take(self):
)

with self.assertRaises(Exception) as cm:
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
template="templates.key_val",
card="cards.xwinogrande.pt",
num_demos=30,
Expand All @@ -658,7 +658,7 @@ def test_dataset_recipe_with_no_demos_to_take(self):
)

def test_dataset_recipe_with_no_test(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
template="templates.key_val",
card="cards.xwinogrande.pt",
num_demos=3,
Expand All @@ -671,7 +671,7 @@ def test_dataset_recipe_with_no_test(self):
def test_dataset_recipe_with_template_errors(self):
# Check either template or template index was specified , but not both
with self.assertRaises(AssertionError) as cm:
StandardRecipeWithIndexes(
DatasetRecipe(
card="cards.wnli", template="templates.key_val", template_card_index=100
)
self.assertTrue(
Expand All @@ -684,7 +684,7 @@ def test_dataset_recipe_with_template_errors(self):

# Also check if string index is used
with self.assertRaises(AssertionError) as cm:
StandardRecipeWithIndexes(
DatasetRecipe(
card="cards.wnli",
template="templates.key_val",
template_card_index="illegal_template",
Expand All @@ -699,17 +699,15 @@ def test_dataset_recipe_with_template_errors(self):

# Return an error if index is not found in card
with self.assertRaises(ValueError) as cm:
StandardRecipeWithIndexes(
card="cards.wnli", template_card_index="illegal_template"
)
DatasetRecipe(card="cards.wnli", template_card_index="illegal_template")
self.assertTrue("not defined in card." in str(cm.exception))

with self.assertRaises(ValueError) as cm:
StandardRecipeWithIndexes(card="cards.wnli", template_card_index=100)
DatasetRecipe(card="cards.wnli", template_card_index=100)
self.assertTrue("not defined in card." in str(cm.exception))

def test_dataset_recipe_with_balancer_and_size_limit(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.key_val",
Expand All @@ -729,7 +727,7 @@ def test_dataset_recipe_with_balancer_and_size_limit(self):
self.assertEqual(counts["entailment"], counts["not entailment"], 10)

def test_dataset_recipe_with_augmentor_on_task_input(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.sst2",
augmentor="augmentors.text.white_space",
template_card_index=0,
Expand All @@ -754,7 +752,7 @@ def test_dataset_recipe_with_augmentor_on_task_input(self):
), f"{normalized_output_source} is not equal to f{normalized_input_source}"

def test_dataset_recipe_with_train_size_limit(self):
recipe = StandardRecipeWithIndexes(
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template="templates.key_val",
Expand All @@ -775,7 +773,7 @@ def test_recipe_with_hf_with_twice_the_same_instance_demos(self):

d = load_dataset(
dataset_file,
"__type__=dataset_recipe_with_indexes,card=cards.wnli,template=templates.classification.multi_class.relation.default,system_prompt=system_prompts.models.llama,demos_pool_size=5,num_demos=1",
"__type__=dataset_recipe,card=cards.wnli,template=templates.classification.multi_class.relation.default,system_prompt=system_prompts.models.llama,demos_pool_size=5,num_demos=1",
streaming=True,
trust_remote_code=True,
)
Expand All @@ -800,7 +798,7 @@ def test_dataset_recipe_with_a_missing_sampler(self):
task_card, _ = copy.deepcopy(fetch_artifact("cards.sst2"))
task_card.sampler = None
with self.assertRaises(ValueError) as e:
StandardRecipeWithIndexes(
DatasetRecipe(
card=task_card,
template_card_index=0,
max_train_instances=0,
Expand Down

0 comments on commit 934c9a4

Please sign in to comment.