Skip to content

Commit

Permalink
separate AddDemosPool from CreateDemosPool
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <[email protected]>
  • Loading branch information
dafnapension committed Dec 22, 2024
1 parent 028a910 commit 24342d7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 81 deletions.
1 change: 1 addition & 0 deletions src/unitxt/settings_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __getattr__(self, key):
constants.inference_stream = "__INFERENCE_STREAM__"
constants.instance_stream = "__INSTANCE_STREAM__"
constants.image_tag = "unitxt-img"
constants.demos_pool_field = "_demos_pool_"


def get_settings() -> Settings:
Expand Down
156 changes: 80 additions & 76 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,71 +36,50 @@

# Used to give meaningful name to recipe steps
class CreateDemosPool(MultiStreamOperator):
given_demos_pool: List[Dict[str, Any]] = None
from_stream: str = None
demos_pool_size: int = None
remove_targets_from_source_split: bool = None
to_field: str = "_demos_pool_"

def verify(self):
assert (
self.given_demos_pool is not None
and isoftype(self.given_demos_pool, List[Dict[str, Any]])
) != (
self.from_stream is not None
and self.demos_pool_size is not None
and self.remove_targets_from_source_split is not None
), (
"The demos_pool must be specified by exactly one of two ways: explicitly, as parameter "
+ "given_demos_pool, or via parameters from_stream, demos_pool_size, and remove_targets_from_source_split, "
+ "that together direct its production."
)
to_field: str = constants.demos_pool_field

# flake8: noqa: B007
def process(self, multi_stream: MultiStream) -> MultiStream:
if not self.given_demos_pool:
# generate the demos_pool as a selection of demos_pool_size distinct instances
# (distinct by their "input_fields" field). The selection is taken from stream named from_stream.
# The selected instances are later treated as ordinary instances or not, depending on parameter
# remove_targets_from_source_split.
# The selection of instances is done from the first instances of the stream named from_stream.
# instances that are not distinct from previously selected demo instances, are kept aside, to be later
# treated like all the remaining instances of stream from_stream.
if self.from_stream not in multi_stream:
raise ValueError(
f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool."
)
from_stream = multi_stream[self.from_stream]
demos_pool = []
input_fields_of_demos_pool = []
not_selected_from_from_stream = []
for num_scanned, instance in enumerate(from_stream):
if "input_fields" not in instance:
raise ValueError(
f"'input_fields' field is missing from '{instance}'."
)
input_fields_signature = json.dumps(
instance["input_fields"], sort_keys=True
)
if input_fields_signature in input_fields_of_demos_pool:
not_selected_from_from_stream.append(instance)
continue
demos_pool.append(instance)
input_fields_of_demos_pool.append(input_fields_signature)
if len(demos_pool) >= self.demos_pool_size:
break

# For backward compatibility, do not throw an exception here if the demos pool is smaller than expected.
# Defer that to the event (if it occurs) that Sample is not able to sample num_demos demos.
# generate the demos_pool as a selection of demos_pool_size distinct instances
# (distinct by their "input_fields" field). The selection is taken from stream named from_stream.
# The selected instances are later treated as ordinary instances or not, depending on parameter
# remove_targets_from_source_split.
# The selection of instances is done from the first instances of the stream named from_stream.
# instances that are not distinct from previously selected demo instances, are kept aside, to be later
# treated like all the remaining instances of stream from_stream.
if self.from_stream not in multi_stream:
raise ValueError(
f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool."
)
from_stream = multi_stream[self.from_stream]
demos_pool = []
input_fields_of_demos_pool = []
not_selected_from_from_stream = []
for num_scanned, instance in enumerate(from_stream):
if "input_fields" not in instance:
raise ValueError(f"'input_fields' field is missing from '{instance}'.")
input_fields_signature = json.dumps(
instance["input_fields"], sort_keys=True
)
if input_fields_signature in input_fields_of_demos_pool:
not_selected_from_from_stream.append(instance)
continue
demos_pool.append(instance)
input_fields_of_demos_pool.append(input_fields_signature)
if len(demos_pool) >= self.demos_pool_size:
break

# copy the pool, to avoid endless recursion when demos_removed_from_data is False
demos_pool = recursive_copy(demos_pool)
# For backward compatibility, do not throw an exception here if the demos pool is smaller than expected.
# Defer that to the event (if it occurs) that Sample is not able to sample num_demos demos.

else:
demos_pool = self.given_demos_pool
# copy the pool, to avoid endless recursion when demos_removed_from_data is False
demos_pool = recursive_copy(demos_pool)

set_demos_pool = Set(fields={self.to_field: demos_pool})
if self.given_demos_pool or not self.remove_targets_from_source_split:
if not self.remove_targets_from_source_split:
return set_demos_pool(multi_stream)

def from_stream_generator(
Expand Down Expand Up @@ -136,6 +115,15 @@ def from_stream_generator(
return set_demos_pool(ms)


class AddDemosPool(MultiStreamOperator):
    """Add an explicitly given demos pool to the incoming multi-stream.

    Unlike CreateDemosPool, nothing is selected from any stream here: the
    caller supplies the pool ready-made, and this step only hands it to a
    ``Set`` operator that writes it under ``demos_pool_field_name``.

    Attributes:
        demos_pool: The list of instances that make up the demos pool.
        demos_pool_field_name: Name of the field the pool is stored under.
            Defaults to ``constants.demos_pool_field``.
    """

    demos_pool: List[Dict[str, Any]]
    demos_pool_field_name: str = constants.demos_pool_field

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return ``multi_stream`` after applying ``Set`` with the given pool."""
        pool_fields = {self.demos_pool_field_name: self.demos_pool}
        return Set(fields=pool_fields)(multi_stream)


class DatasetRecipe(SourceSequentialOperator):
"""This class represents a standard recipe for data processing and preparation.
Expand Down Expand Up @@ -174,19 +162,19 @@ class DatasetRecipe(SourceSequentialOperator):
Maximum test instances for the refiner.
demos_pool_size (int, optional):
Size of the demos pool.
given_demos_pool(List[Dict[str, Any]], optional):
demos_pool(List[Dict[str, Any]], optional):
a list of instances to make the demos_pool
num_demos (int, optional):
Number of demos to be used.
demos_pool_name (str, optional):
Name of the demos pool. Default is "demos_pool".
Name of the demos pool. Default is constants.demos_pool_field
demos_taken_from (str, optional):
Specifies from where the demos are taken. Default is "train".
demos_field (str, optional):
Field name for demos. Default is "demos".
demos_pool_field_name (str, optional):
field name to maintain the demos_pool, until sampled from, to make the demos.
Defaults to "_demos_pool_".
Defaults to constants.demos_pool_field.
demos_removed_from_data (bool, optional):
Whether to remove the demos from the source data. Default is True.
sampler (Sampler, optional):
Expand Down Expand Up @@ -239,10 +227,10 @@ class DatasetRecipe(SourceSequentialOperator):
test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)

demos_pool_size: int = None
given_demos_pool: List[Dict[str, Any]] = None
demos_pool: List[Dict[str, Any]] = None
num_demos: Optional[Union[int, List[int]]] = 0
demos_removed_from_data: bool = True
demos_pool_field_name: str = "_demos_pool_"
demos_pool_field_name: str = constants.demos_pool_field

demos_taken_from: str = "train"
demos_field: str = "demos"
Expand Down Expand Up @@ -509,17 +497,24 @@ def reset_pipeline(self):
# one example of consume but not used, and indeed skips over a problematic (json-wise) input:
# prepare/cards/rag/end_to_end/clapnq.py
if self.has_custom_demos_pool:
self.processing.steps.append(
CreateDemosPool(
given_demos_pool=self.given_demos_pool,
from_stream=self.demos_taken_from,
demos_pool_size=self.demos_pool_size
if self.given_demos_pool is None
else None,
remove_targets_from_source_split=self.demos_removed_from_data,
to_field=self.demos_pool_field_name,
if self.demos_pool:
self.processing.steps.append(
AddDemosPool(
demos_pool=self.demos_pool,
demos_pool_field_name=self.demos_pool_field_name,
)
)
else:
self.processing.steps.append(
CreateDemosPool(
from_stream=self.demos_taken_from,
demos_pool_size=self.demos_pool_size
if self.demos_pool is None
else None,
remove_targets_from_source_split=self.demos_removed_from_data,
to_field=self.demos_pool_field_name,
)
)
)

if self.use_demos:
if self.sampler is None:
Expand Down Expand Up @@ -665,13 +660,22 @@ def prepare(self):
"No template was specified in the the 'template' or 'template_card_index' recipe arguments, and no default templates are defined the card or task"
)
if self.use_demos:
if (self.demos_pool_size is None) == (self.given_demos_pool is None):
raise ValueError(
f"When using demonstrations, exactly one of either demos_pool_size or given_demos_pool must be set. Got both {'' if self.demos_pool_size else 'not '}set."
)
assert (
self.demos_pool is not None
and isoftype(self.demos_pool, List[Dict[str, Any]])
) != (
self.demos_taken_from is not None
and self.demos_pool_size is not None
and self.demos_removed_from_data is not None
), (
"The demos_pool must be specified by exactly one of two ways: explicitly, as a list of instances coming through parameter "
+ "'demos_pool', or via parameters 'demos_taken_from', 'demos_pool_size', and 'demos_removed_from_data', "
+ "that together direct its production."
)

# now set self.demos_pool_size for the checks of verify
if self.given_demos_pool:
self.demos_pool_size = len(self.given_demos_pool)
if self.demos_pool:
self.demos_pool_size = len(self.demos_pool)

if isinstance(self.template, TemplatesList):
self.template = self.template.items
Expand Down
22 changes: 17 additions & 5 deletions tests/library/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,24 +181,36 @@ def test_dataset_recipe_with_given_demos(self):
for_demos = recipe.processing(for_demos)
for_demos = recursive_copy(list(for_demos["validation"]))

# for_demos is a list of instances, taken from stream 'validation' of the source of 'cards.wnli'.
# Having passed the steps of preprocessing, each of them complies now with the format of 'cards.wnli.task'

# we now run a recipe with this for_demos, and see stream 'train' coming out with them as demos

recipe2 = DatasetRecipe(
card="cards.wnli",
template_card_index=0,
given_demos_pool=for_demos,
demos_pool=for_demos[0:5],
num_demos=3,
)

trains = list(recipe2()["train"])
assert "The entailment class is entailment" not in trains[0]["source"]
source_demos_input = trains[0]["source"]

# the same result as when creating the demos while processing the recipe:

recipe3 = DatasetRecipe(
card="cards.wnli",
template_card_index=0,
given_demos_pool=for_demos,
demos_taken_from="validation",
demos_pool_size=5,
demos_removed_from_data=True,
num_demos=3,
)

trains = list(recipe3()["train"])
assert "The entailment class is entailment" in trains[0]["source"]
source_demos_selected = trains[0]["source"]

self.assertEqual(source_demos_input, source_demos_selected)

def test_dataset_recipe_not_duplicating_demos_pool(self):
recipe = DatasetRecipe(
Expand All @@ -212,7 +224,7 @@ def test_dataset_recipe_not_duplicating_demos_pool(self):
recipe3 = DatasetRecipe(
card="cards.wnli",
template_card_index=0,
given_demos_pool=for_demos,
demos_pool=for_demos,
num_demos=3,
)

Expand Down

0 comments on commit 24342d7

Please sign in to comment.