From 24342d741b56c934c73134b0e5407f63e2003f96 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Sun, 22 Dec 2024 14:25:07 +0200 Subject: [PATCH] separate AddDemosPool from CreateDemosPool Signed-off-by: dafnapension --- src/unitxt/settings_utils.py | 1 + src/unitxt/standard.py | 156 ++++++++++++++++++----------------- tests/library/test_recipe.py | 22 +++-- 3 files changed, 98 insertions(+), 81 deletions(-) diff --git a/src/unitxt/settings_utils.py b/src/unitxt/settings_utils.py index 8e35afe08..75a3bd641 100644 --- a/src/unitxt/settings_utils.py +++ b/src/unitxt/settings_utils.py @@ -186,6 +186,7 @@ def __getattr__(self, key): constants.inference_stream = "__INFERENCE_STREAM__" constants.instance_stream = "__INSTANCE_STREAM__" constants.image_tag = "unitxt-img" + constants.demos_pool_field = "_demos_pool_" def get_settings() -> Settings: diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index a98793719..e7d4e3619 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -36,71 +36,50 @@ # Used to give meaningful name to recipe steps class CreateDemosPool(MultiStreamOperator): - given_demos_pool: List[Dict[str, Any]] = None from_stream: str = None demos_pool_size: int = None remove_targets_from_source_split: bool = None - to_field: str = "_demos_pool_" - - def verify(self): - assert ( - self.given_demos_pool is not None - and isoftype(self.given_demos_pool, List[Dict[str, Any]]) - ) != ( - self.from_stream is not None - and self.demos_pool_size is not None - and self.remove_targets_from_source_split is not None - ), ( - "The demos_pool must be specified by exactly one of two ways: explicitly, as parameter " - + "given_demos_pool, or via parameters from_stream, demos_pool_size, and remove_targets_from_source_split, " - + "that together direct its production." - ) + to_field: str = constants.demos_pool_field # flake8: noqa: B007 def process(self, multi_stream: MultiStream) -> MultiStream: - if not self.given_demos_pool: - # generate the demos_pool as a selection of demos_pool_size distinct instances - # (distinct by their "input_fields" field). The selection is taken from stream named from_stream. - # The selected instances are later treated as ordinary instances or not, depending on parameter - # remove_targets_from_source_split. - # The selection of instances is done from the first instances of the stream named from_stream. - # instances that are not distinct from previously selected demo instances, are kept aside, to be later - # treated like all the remaining instances of stream from_stream. - if self.from_stream not in multi_stream: - raise ValueError( - f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool." - ) - from_stream = multi_stream[self.from_stream] - demos_pool = [] - input_fields_of_demos_pool = [] - not_selected_from_from_stream = [] - for num_scanned, instance in enumerate(from_stream): - if "input_fields" not in instance: - raise ValueError( - f"'input_fields' field is missing from '{instance}'." - ) - input_fields_signature = json.dumps( - instance["input_fields"], sort_keys=True - ) - if input_fields_signature in input_fields_of_demos_pool: - not_selected_from_from_stream.append(instance) - continue - demos_pool.append(instance) - input_fields_of_demos_pool.append(input_fields_signature) - if len(demos_pool) >= self.demos_pool_size: - break - - # for backward compatibility, do not throw exception here if demos pool is smaller than expected. - # Delay that for the event (if occurs) that Sample is not be able to sample num_demos demos. + # generate the demos_pool as a selection of demos_pool_size distinct instances + # (distinct by their "input_fields" field). The selection is taken from stream named from_stream. + # The selected instances are later treated as ordinary instances or not, depending on parameter + # remove_targets_from_source_split. + # The selection of instances is done from the first instances of the stream named from_stream. + # instances that are not distinct from previously selected demo instances, are kept aside, to be later + # treated like all the remaining instances of stream from_stream. + if self.from_stream not in multi_stream: + raise ValueError( + f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool." + ) + from_stream = multi_stream[self.from_stream] + demos_pool = [] + input_fields_of_demos_pool = [] + not_selected_from_from_stream = [] + for num_scanned, instance in enumerate(from_stream): + if "input_fields" not in instance: + raise ValueError(f"'input_fields' field is missing from '{instance}'.") + input_fields_signature = json.dumps( + instance["input_fields"], sort_keys=True + ) + if input_fields_signature in input_fields_of_demos_pool: + not_selected_from_from_stream.append(instance) + continue + demos_pool.append(instance) + input_fields_of_demos_pool.append(input_fields_signature) + if len(demos_pool) >= self.demos_pool_size: + break - # to avoid endless recursion in case of not demos_removed_from_data - demos_pool = recursive_copy(demos_pool) + # for backward compatibility, do not throw exception here if demos pool is smaller than expected. + # Delay that for the event (if occurs) that Sample is not be able to sample num_demos demos. - else: - demos_pool = self.given_demos_pool + # to avoid endless recursion in case of not demos_removed_from_data + demos_pool = recursive_copy(demos_pool) set_demos_pool = Set(fields={self.to_field: demos_pool}) - if self.given_demos_pool or not self.remove_targets_from_source_split: + if not self.remove_targets_from_source_split: return set_demos_pool(multi_stream) def from_stream_generator( @@ -136,6 +115,15 @@ def from_stream_generator( return set_demos_pool(ms) +class AddDemosPool(MultiStreamOperator): + demos_pool: List[Dict[str, Any]] + demos_pool_field_name: str = constants.demos_pool_field + + def process(self, multi_stream: MultiStream) -> MultiStream: + set_demos_pool = Set(fields={self.demos_pool_field_name: self.demos_pool}) + return set_demos_pool(multi_stream) + + class DatasetRecipe(SourceSequentialOperator): """This class represents a standard recipe for data processing and preparation. @@ -174,19 +162,19 @@ class DatasetRecipe(SourceSequentialOperator): Maximum test instances for the refiner. demos_pool_size (int, optional): Size of the demos pool. - given_demos_pool(List[Dict[str, Any]], optional): + demos_pool(List[Dict[str, Any]], optional): a list of instances to make the demos_pool num_demos (int, optional): Number of demos to be used. demos_pool_name (str, optional): - Name of the demos pool. Default is "demos_pool". + Name of the demos pool. Default is constants.demos_pool_field demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train". demos_field (str, optional): Field name for demos. Default is "demos". demos_pool_field_name (str, optional): field name to maintain the demos_pool, until sampled from, to make the demos. - Defaults to "_demos_pool_". + Defaults to constants.demos_pool_field. demos_removed_from_data (bool, optional): whether to remove the demos from the source data, Default is True sampler (Sampler, optional): @@ -239,10 +227,10 @@ class DatasetRecipe(SourceSequentialOperator): test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner) demos_pool_size: int = None - given_demos_pool: List[Dict[str, Any]] = None + demos_pool: List[Dict[str, Any]] = None num_demos: Optional[Union[int, List[int]]] = 0 demos_removed_from_data: bool = True - demos_pool_field_name: str = "_demos_pool_" + demos_pool_field_name: str = constants.demos_pool_field demos_taken_from: str = "train" demos_field: str = "demos" @@ -509,17 +497,24 @@ def reset_pipeline(self): # one example of consume but not used, and indeed skips over a problematic (json-wise) input: # prepare/cards/rag/end_to_end/clapnq.py if self.has_custom_demos_pool: - self.processing.steps.append( - CreateDemosPool( - given_demos_pool=self.given_demos_pool, - from_stream=self.demos_taken_from, - demos_pool_size=self.demos_pool_size - if self.given_demos_pool is None - else None, - remove_targets_from_source_split=self.demos_removed_from_data, - to_field=self.demos_pool_field_name, + if self.demos_pool: + self.processing.steps.append( + AddDemosPool( + demos_pool=self.demos_pool, + demos_pool_field_name=self.demos_pool_field_name, + ) + ) + else: + self.processing.steps.append( + CreateDemosPool( + from_stream=self.demos_taken_from, + demos_pool_size=self.demos_pool_size + if self.demos_pool is None + else None, + remove_targets_from_source_split=self.demos_removed_from_data, + to_field=self.demos_pool_field_name, + ) ) - ) if self.use_demos: if self.sampler is None: @@ -665,13 +660,22 @@ def prepare(self): "No template was specified in the the 'template' or 'template_card_index' recipe arguments, and no default templates are defined the card or task" ) if self.use_demos: - if (self.demos_pool_size is None) == (self.given_demos_pool is None): - raise ValueError( - f"When using demonstrations, exactly one of either demos_pool_size or given_demos_pool must be set. Got both {'' if self.demos_pool_size else 'not '}set." - ) + assert ( + self.demos_pool is not None + and isoftype(self.demos_pool, List[Dict[str, Any]]) + ) != ( + self.demos_taken_from is not None + and self.demos_pool_size is not None + and self.demos_removed_from_data is not None + ), ( + "The demos_pool must be specified by exactly one of two ways: explicitly, as a list of instances coming through parameter " + + "'demos_pool', or via parameters 'demos_taken_from', 'demos_pool_size', and 'demos_removed_from_data', " + + "that together direct its production." + ) + # now set self.demos_pool_size for the checks of verify - if self.given_demos_pool: - self.demos_pool_size = len(self.given_demos_pool) + if self.demos_pool: + self.demos_pool_size = len(self.demos_pool) if isinstance(self.template, TemplatesList): self.template = self.template.items diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index 3fb39003f..349850db7 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -181,24 +181,36 @@ def test_dataset_recipe_with_given_demos(self): for_demos = recipe.processing(for_demos) for_demos = recursive_copy(list(for_demos["validation"])) + # for_demos is a list of instances, taken from stream 'validation' of the source of 'cards.wnli'. + # Having passed the steps of preprocessing, each of them complies now with the format of 'cards.wnli.task' + + # we now run a recipe with this for_demos, and see stream 'train' coming out with them as demos + recipe2 = DatasetRecipe( card="cards.wnli", template_card_index=0, - given_demos_pool=for_demos, + demos_pool=for_demos[0:5], + num_demos=3, ) trains = list(recipe2()["train"]) - assert "The entailment class is entailment" not in trains[0]["source"] + source_demos_input = trains[0]["source"] + + # the same result as when creating the demos while processing the recipe: recipe3 = DatasetRecipe( card="cards.wnli", template_card_index=0, - given_demos_pool=for_demos, + demos_taken_from="validation", + demos_pool_size=5, + demos_removed_from_data=True, num_demos=3, ) trains = list(recipe3()["train"]) - assert "The entailment class is entailment" in trains[0]["source"] + source_demos_selected = trains[0]["source"] + + self.assertEqual(source_demos_input, source_demos_selected) def test_dataset_recipe_not_duplicating_demos_pool(self): recipe = DatasetRecipe( @@ -212,7 +224,7 @@ def test_dataset_recipe_not_duplicating_demos_pool(self): recipe3 = DatasetRecipe( card="cards.wnli", template_card_index=0, - given_demos_pool=for_demos, + demos_pool=for_demos, num_demos=3, )