From ccebc838f9066df89705aa611600aa316757d8af Mon Sep 17 00:00:00 2001 From: dafnapension Date: Fri, 20 Dec 2024 19:50:15 +0200 Subject: [PATCH] make demos_pool a variable, rather than a separate stream. Needs to restart train repeatedly. Fix needed in Loaders. Allow a given demos_pool, or input stream-instances already loaded with demos Signed-off-by: dafnapension --- src/unitxt/loaders.py | 148 ++++++++++----------- src/unitxt/schema.py | 3 + src/unitxt/splitters.py | 61 ++++----- src/unitxt/standard.py | 221 ++++++++++++++++++++++++++----- src/unitxt/task.py | 3 + src/unitxt/test_utils/metrics.py | 2 +- tests/library/test_api.py | 18 +-- tests/library/test_benchmark.py | 8 +- tests/library/test_recipe.py | 172 ++++++++++++++++++++---- utils/.secrets.baseline | 4 +- 10 files changed, 454 insertions(+), 186 deletions(-) diff --git a/src/unitxt/loaders.py b/src/unitxt/loaders.py index 89ede9c38c..1ca35ffd4a 100644 --- a/src/unitxt/loaders.py +++ b/src/unitxt/loaders.py @@ -41,6 +41,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union import pandas as pd +from datasets import IterableDatasetDict from datasets import load_dataset as hf_load_dataset from huggingface_hub import HfApi from tqdm import tqdm @@ -51,7 +52,7 @@ from .operator import SourceOperator from .operators import Set from .settings_utils import get_settings -from .stream import DynamicStream, MultiStream +from .stream import MultiStream from .type_utils import isoftype from .utils import LRUCache @@ -122,7 +123,7 @@ def add_data_classification(self, multi_stream: MultiStream) -> MultiStream: ) return operator(multi_stream) - def sef_default_data_classification( + def set_default_data_classification( self, default_data_classification_policy, additional_info ): if self.data_classification_policy is None: @@ -161,23 +162,24 @@ class LoadHF(Loader): and it can filter datasets upon loading. Args: - path: The path or identifier of the dataset on the HuggingFace Hub. - - name: An optional dataset name. - - data_dir: Optional directory to store downloaded data. - - split: Optional specification of which split to load. - - data_files: Optional specification of particular data files to load. - - revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version. - - streaming (bool): indicating if streaming should be used. - - filtering_lambda: A lambda function for filtering the data after loading. - - num_proc (int): Optional integer to specify the number of processes to use for parallel dataset loading. + path: + The path or identifier of the dataset on the HuggingFace Hub. + name: + An optional dataset name. + data_dir: + Optional directory to store downloaded data. + split: + Optional specification of which split to load. + data_files: + Optional specification of particular data files to load. + revision: + Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version. + streaming (bool): + indicating if streaming should be used. + filtering_lambda (str, optional): + A lambda function for filtering the data after loading. + num_proc (int, optional): + Specifies the number of processes to use for parallel dataset loading. Example: Loading glue's mrpc dataset @@ -277,39 +279,21 @@ def load_dataset(self): for split in dataset.keys(): dataset[split] = dataset[split].to_iterable_dataset() else: - dataset = {self.split: dataset} - - if self.filtering_lambda is not None: - dataset = self.filter_load(dataset) + dataset = {self.split: dataset.to_iterable_dataset()} return dataset - def split_limited_load(self, dataset, split_name): - yield from itertools.islice(dataset[split_name], self.get_limit()) - - def limited_load(self, dataset): - self.log_limited_loading() - return MultiStream( - { - name: DynamicStream( - generator=self.split_limited_load, - gen_kwargs={"dataset": dataset, "split_name": name}, - ) - for name in dataset.keys() - } - ) - def _maybe_set_classification_policy(self): if os.path.exists(self.path): - self.sef_default_data_classification( + self.set_default_data_classification( ["proprietary"], "when loading from local files" ) else: - self.sef_default_data_classification( + self.set_default_data_classification( ["public"], "when loading from Huggingface hub" ) - def load_iterables(self): + def load_iterables(self) -> IterableDatasetDict: try: dataset = self.stream_dataset() except ( @@ -317,8 +301,15 @@ def load_iterables(self): ): # streaming is not supported for zipped files so we load without streaming dataset = self.load_dataset() + if self.filtering_lambda is not None: + dataset = self.filter_load(dataset) + if self.get_limit() is not None: - return self.limited_load(dataset=dataset) + self.log_limited_loading() + return { + split_name: dataset[split_name].take(self.get_limit()) + for split_name in dataset + } return dataset @@ -350,7 +341,7 @@ class LoadCSV(Loader): sep: str = "," def _maybe_set_classification_policy(self): - self.sef_default_data_classification( + self.set_default_data_classification( ["proprietary"], "when loading from local files" ) @@ -363,9 +354,7 @@ def load_iterables(self): file_path, nrows=self.get_limit(), sep=self.sep ).to_dict("records") else: - iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict( - "records" - ) + iterables[split_name] = pd.read_csv(file_path).to_dict("records") return iterables @@ -473,14 +462,22 @@ class LoadFromIBMCloud(Loader): 3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]} Args: - endpoint_url_env: Environment variable name for the IBM Cloud endpoint URL. - aws_access_key_id_env: Environment variable name for the AWS access key ID. - aws_secret_access_key_env: Environment variable name for the AWS secret access key. - bucket_name: Name of the S3 bucket from which to load data. - data_dir: Optional directory path within the bucket. - data_files: Union type allowing either a list of file names or a mapping of splits to file names. - data_field: The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file - caching: Bool indicating if caching is enabled to avoid re-downloading data. + endpoint_url_env: + Environment variable name for the IBM Cloud endpoint URL. + aws_access_key_id_env: + Environment variable name for the AWS access key ID. + aws_secret_access_key_env: + Environment variable name for the AWS secret access key. + bucket_name: + Name of the S3 bucket from which to load data. + data_dir: + Optional directory path within the bucket. + data_files: + Union type allowing either a list of file names or a mapping of splits to file names. + data_field: + The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file + caching (bool): + indicating if caching is enabled to avoid re-downloading data. Example: Loading from IBM Cloud @@ -576,7 +573,7 @@ def lazy_verify(self): raise NotImplementedError("LoadFromKaggle cannot load with streaming.") def _maybe_set_classification_policy(self): - self.sef_default_data_classification( + self.set_default_data_classification( ["proprietary"], "when loading from IBM COS" ) @@ -727,7 +724,7 @@ def verify(self): ) def _maybe_set_classification_policy(self): - self.sef_default_data_classification( + self.set_default_data_classification( ["proprietary"], "when loading from python dictionary" ) @@ -742,25 +739,24 @@ class LoadFromHFSpace(LoadHF): from the given space and then reads them as a HuggingFace Dataset. Args: - space_name (str): Name of the HuggingFace Space to be accessed. - - data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative - paths to files within a given repository. If given as a mapping, paths should - be values, while keys should represent the type of respective files - (training, testing etc.). - - path (str, optional): Absolute path to a directory where data should be downloaded. - - revision (str, optional): ID of a Git branch or commit to be used. By default, it is - set to None, thus data is downloaded from the main branch of the accessed - repository. - - use_token (bool, optional): Whether a token is used for authentication when accessing - the HuggingFace Space. If necessary, the token is read from the HuggingFace - config folder. - - token_env (str, optional): Key of an env variable which value will be used for - authentication when accessing the HuggingFace Space - if necessary. + space_name (str): + Name of the HuggingFace Space to be accessed. + data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): + Relative paths to files within a given repository. If given as a mapping, + paths should be values, while keys should represent the type of respective files + (training, testing etc.). + path (str, optional): + Absolute path to a directory where data should be downloaded. + revision (str, optional): + ID of a Git branch or commit to be used. By default, it is set to None, + thus data is downloaded from the main branch of the accessed repository. + use_token (bool, optional): + Whether a token is used for authentication when accessing + the HuggingFace Space. If necessary, the token is read from the HuggingFace + config folder. + token_env (str, optional): + Key of an env variable which value will be used for + authentication when accessing the HuggingFace Space - if necessary. Example: Loading from a HuggingFace Space @@ -908,7 +904,7 @@ def _map_wildcard_path_to_full_paths(self): ) def _maybe_set_classification_policy(self): - self.sef_default_data_classification( + self.set_default_data_classification( ["public"], "when loading from Huggingface spaces" ) diff --git a/src/unitxt/schema.py b/src/unitxt/schema.py index cde289124d..5b913d11a7 100644 --- a/src/unitxt/schema.py +++ b/src/unitxt/schema.py @@ -130,6 +130,9 @@ def process( ) task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"] + task_data["metadata"]["demos_pool_size"] = instance["recipe_metadata"][ + "demos_pool_size" + ] task_data["metadata"]["template"] = self.artifact_to_jsonable( instance["recipe_metadata"]["template"] ) diff --git a/src/unitxt/splitters.py b/src/unitxt/splitters.py index dd6ee45b1e..521872f882 100644 --- a/src/unitxt/splitters.py +++ b/src/unitxt/splitters.py @@ -1,11 +1,11 @@ import itertools from abc import abstractmethod from difflib import get_close_matches -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from .artifact import Artifact from .dict_utils import dict_get -from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator +from .operator import InstanceOperator, MultiStreamOperator from .random_utils import new_random_generator from .split_utils import ( parse_random_mix_string, @@ -14,7 +14,7 @@ rename_split, slice_streams, ) -from .stream import EmptyStreamError, FaultyStreamError, MultiStream +from .stream import MultiStream from .type_utils import isoftype from .utils import recursive_copy @@ -118,14 +118,14 @@ class Sampler(Artifact): def sample( self, sample_size: int, - instances_pool: List[Dict[str, object]], - instance: Dict[str, object], - ) -> List[Dict[str, object]]: + instances_pool: List[Dict[str, Any]], + instance: Dict[str, Any], + ) -> List[Dict[str, Any]]: pass def filter_source_by_instance( - self, instances_pool: List[Dict[str, object]], instance: Dict[str, object] - ) -> List[Dict[str, object]]: + self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any] + ) -> List[Dict[str, Any]]: if "input_fields" not in instance: raise ValueError(f"'input_fields' field is missing from '{instance}'.") try: @@ -336,10 +336,11 @@ def sample( return result -class Sample(InstanceOperatorWithMultiStreamAccess): - from_stream: str +class Sample(InstanceOperator): + from_field: str to_field: str sampler: Sampler + skip_demoed_instances: bool = False def prepare(self): self.local_cache = None @@ -350,30 +351,26 @@ def get_sample_size(self, instance) -> int: pass def process( - self, instance: Dict[str, object], multi_stream: MultiStream - ) -> Dict[str, object]: - sample_size = self.get_sample_size(instance) - try: - if self.local_cache is None: - self.local_cache = recursive_copy(list(multi_stream[self.from_stream])) + self, instance: Dict[str, Any], multi_stream: MultiStream + ) -> Dict[str, Any]: + if self.skip_demoed_instances and self.to_field in instance: + if self.from_field in instance: + instance.pop(self.from_field) + return instance - source_stream = self.local_cache - source_stream = self.sampler.filter_source_by_instance( - source_stream, instance - ) - if len(source_stream) < sample_size: - raise ValueError( - f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}." - ) - sampled_instances = self.sampler.sample( - sample_size=sample_size, instances_pool=source_stream, instance=instance + demos_pool = instance[self.from_field] + sample_size = self.get_sample_size(instance) + source_stream = self.sampler.filter_source_by_instance(demos_pool, instance) + if len(source_stream) < sample_size: + raise ValueError( + f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering." ) - instance[self.to_field] = sampled_instances - return instance - except FaultyStreamError as e: - raise EmptyStreamError( - f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}" - ) from e + sampled_instances = self.sampler.sample( + sample_size=sample_size, instances_pool=source_stream, instance=instance + ) + instance[self.to_field] = recursive_copy(sampled_instances) + instance.pop(self.from_field) # pop the field pointing to the demos_pool + return instance class ConstantSizeSample(Sample): diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index d18eaab904..92a4defdaa 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -1,4 +1,6 @@ -from typing import List, Optional, Union +import itertools +import json +from typing import Any, Dict, Generator, List, Optional, Union from .artifact import fetch_artifact from .augmentors import Augmentor, NullAugmentor @@ -7,20 +9,26 @@ from .dataclass import Field, InternalField, NonPositionalField, OptionalField from .error_utils import UnitxtError from .formats import Format, SystemFormat +from .generator_utils import ReusableGenerator from .logging_utils import get_logger -from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator +from .operator import ( + MultiStreamOperator, + SequentialOperator, + SourceSequentialOperator, + StreamingOperator, +) from .operators import Set, StreamRefiner from .recipe import Recipe from .schema import FinalizeDataset from .serializers import SingleTypeSerializer from .settings_utils import get_constants, get_settings -from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit +from .splitters import ConstantSizeSample, RandomSizeSample, Sampler from .stream import MultiStream from .system_prompts import EmptySystemPrompt, SystemPrompt from .task import Task from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList from .type_utils import isoftype -from .utils import LRUCache +from .utils import LRUCache, recursive_copy constants = get_constants() settings = get_settings() @@ -28,8 +36,105 @@ # Used to give meaningful name to recipe steps -class CreateDemosPool(SeparateSplit): - pass +class CreateDemosPool(MultiStreamOperator): + given_demos_pool: List[Dict[str, Any]] = None + from_stream: str = None + demos_pool_size: int = None + remove_targets_from_source_split: bool = None + to_field: str = "_demos_pool_" + + def verify(self): + assert ( + self.given_demos_pool is not None + and isoftype(self.given_demos_pool, List[Dict[str, Any]]) + ) != ( + self.from_stream is not None + and self.demos_pool_size is not None + and self.remove_targets_from_source_split is not None + ), ( + "The demos_pool must be specified by exactly one of two ways: explicitly, as parameter " + + "given_demos_pool, or via parameters from_stream, demos_pool_size, and remove_targets_from_source_split, " + + "that together direct its production." + ) + + # flake8: noqa: B007 + def process(self, multi_stream: MultiStream) -> MultiStream: + if not self.given_demos_pool: + # generate the demos_pool as a selection of demos_pool_size distinct instances + # (distinct by their "input_fields" field). The selection is taken from stream named from_stream. + # The selected instances are later treated as ordinary instances or not, depending on parameter + # remove_targets_from_source_split. + # The selection of instances is done from the first instances of the stream named from_stream. + # instances that are not distinct from previously selected demo instances, are kept aside, to be later + # treated like all the remaining instances of stream from_stream. + if self.from_stream not in multi_stream: + raise ValueError( + f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool." + ) + from_stream = multi_stream[self.from_stream] + demos_pool = [] + input_fields_of_demos_pool = [] + not_selected_from_from_stream = [] + for num_scanned, instance in enumerate(from_stream): + if "input_fields" not in instance: + raise ValueError( + f"'input_fields' field is missing from '{instance}'." + ) + input_fields_signature = json.dumps( + instance["input_fields"], sort_keys=True + ) + if input_fields_signature in input_fields_of_demos_pool: + not_selected_from_from_stream.append(instance) + continue + demos_pool.append(instance) + input_fields_of_demos_pool.append(input_fields_signature) + if len(demos_pool) >= self.demos_pool_size: + break + + # for backward compatibility, do not throw exception here if demos pool is smaller than expected. + # Delay that for the event (if occurs) that Sample is not be able to sample num_demos demos. + + # to avoid endless recursion in case of not demos_removed_from_data + demos_pool = recursive_copy(demos_pool) + + else: + demos_pool = self.given_demos_pool + + set_demos_pool = Set(fields={self.to_field: demos_pool}) + if self.given_demos_pool or not self.remove_targets_from_source_split: + return set_demos_pool(multi_stream) + + def from_stream_generator( + first_layer: list, ms: MultiStream, stream_name: str, start: int + ) -> Generator: + yield from first_layer + yield from itertools.islice(ms[stream_name], start, None) + + new_streams = {} + for stream_name in multi_stream: + if stream_name == self.from_stream: + new_streams[stream_name] = ReusableGenerator( + generator=from_stream_generator, + gen_kwargs={ + "first_layer": not_selected_from_from_stream, + "ms": multi_stream, + "stream_name": self.from_stream, + "start": num_scanned + 1, + }, + ) + else: + new_streams[stream_name] = ReusableGenerator( + generator=from_stream_generator, + gen_kwargs={ + "first_layer": [], + "ms": multi_stream, + "stream_name": stream_name, + "start": 0, + }, + ) + + ms = MultiStream.from_generators(new_streams) + return set_demos_pool(ms) class BaseRecipe(Recipe, SourceSequentialOperator): @@ -59,14 +164,18 @@ class BaseRecipe(Recipe, SourceSequentialOperator): test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner) demos_pool_size: int = None + given_demos_pool: List[Dict[str, Any]] = None num_demos: Optional[Union[int, List[int]]] = 0 demos_removed_from_data: bool = True + demos_pool_field_name: str = "_demos_pool_" - demos_pool_name: str = "demos_pool" demos_taken_from: str = "train" demos_field: str = "demos" sampler: Sampler = None + # do not push demos to instances whose "demos" field is already populated + skip_demoed_instances: bool = False + augmentor: Union[Augmentor, List[Augmentor]] = OptionalField(default=None) steps: List[StreamingOperator] = InternalField(default_factory=list) @@ -101,9 +210,9 @@ def verify(self): raise ValueError( "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers." ) - if self.demos_pool_size < self.max_demos_size: + if self.demos_pool_size < self.max_demos_size + 1: raise ValueError( - f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size (got: {self.demos_pool_size})" + f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size - 1 (got: {self.demos_pool_size}), (-1: to always allow filtering of a demo identical to the processed instance)." ) if self.loader_limit and self.demos_pool_size > self.loader_limit: raise ValueError( @@ -220,25 +329,15 @@ def set_pipelines(self): self.loading, self.metadata, self.standardization, - self.processing, ] self.inference = SequentialOperator() - self.inference.steps = [self.metadata, self.verbalization, self.finalize] + self.inference.steps = [self.processing, self.verbalization, self.finalize] def production_preprocess(self, task_instances): ms = MultiStream.from_iterables({constants.inference_stream: task_instances}) - return list(self.inference_instance(ms)[constants.inference_stream]) - - def production_demos_pool(self): - if self.use_demos: - demos_pool = self.__class__._demos_pool_cache.get(str(self), None) - if demos_pool is None: - demos_pool = list(self.inference_demos()[self.demos_pool_name]) - self.__class__._demos_pool_cache[str(self)] = demos_pool - return demos_pool - return [] + return list(self.metadata(ms)[constants.inference_stream]) @property def has_custom_demos_pool(self): @@ -251,13 +350,22 @@ def use_demos(self): def produce(self, task_instances): """Use the recipe in production to produce model ready query from standard task instance.""" self.before_process_multi_stream() - streams = { - constants.inference_stream: self.production_preprocess(task_instances), - } - if self.use_demos: - streams[self.demos_pool_name] = self.production_demos_pool() - multi_stream = MultiStream.from_iterables(streams) - multi_stream = self.inference(multi_stream) + + ms = MultiStream.from_iterables({constants.inference_stream: task_instances}) + # does not hurt to set metadata + # task_instances are assumed to be as if passed through self.standardization + ms = self.metadata(ms) + if not self.use_demos: + # go with task_instances all the way, it does not need other streams: + ms = self.inference(ms) + return list(ms[constants.inference_stream]) + + streams = self.inference_demos() + # streams stopped before processing + # ms is ready to join, it will get the demos from streams + streams[constants.inference_stream] = ms[constants.inference_stream] + # multi_stream = MultiStream(streams) + multi_stream = self.inference(streams) return list(multi_stream[constants.inference_stream]) def reset(self): @@ -321,13 +429,20 @@ def reset_pipeline(self): augmentor.set_fields(self.card.task.augmentable_inputs) self.processing.steps.append(augmentor) + # for backward compatibility, consume the demos instances even if not pushed into demos field of the ordinary instances, + # in order to use the very same ordinary instances as in back releases. + # one example of consume but not used, and indeed skips over a problematic (json-wise) input: + # prepare/cards/rag/end_to_end/clapnq.py if self.has_custom_demos_pool: self.processing.steps.append( CreateDemosPool( - from_split=self.demos_taken_from, - to_split_names=[self.demos_pool_name, self.demos_taken_from], - to_split_sizes=[int(self.demos_pool_size)], + given_demos_pool=self.given_demos_pool, + from_stream=self.demos_taken_from, + demos_pool_size=self.demos_pool_size + if self.given_demos_pool is None + else None, remove_targets_from_source_split=self.demos_removed_from_data, + to_field=self.demos_pool_field_name, ) ) @@ -346,28 +461,41 @@ def reset_pipeline(self): if isinstance(self.num_demos, int): self.verbalization.steps.append( ConstantSizeSample( - from_stream=self.demos_pool_name, + from_field=self.demos_pool_field_name, to_field=self.demos_field, sampler=self.sampler, sample_size=self.num_demos, + skip_demoed_instances=self.skip_demoed_instances, ) ) self.verbalization.steps.append( - Set(fields={"recipe_metadata/num_demos": self.num_demos}) + Set( + fields={ + "recipe_metadata/num_demos": self.num_demos, + "recipe_metadata/demos_pool_size": self.demos_pool_size, + } + ) ) elif isinstance(self.num_demos, list): self.verbalization.steps.append( RandomSizeSample( - from_stream=self.demos_pool_name, + from_field=self.demos_pool_field_name, to_field=self.demos_field, sampler=self.sampler, sample_sizes=self.num_demos, + skip_demoed_instances=self.skip_demoed_instances, ) ) self.verbalization.steps.append( GetLength(field="demos", to_field="recipe_metadata/num_demos") ) + self.verbalization.steps.append( + Set( + fields={"recipe_metadata/demos_pool_size": self.demos_pool_size} + ) + ) + else: raise ValueError("num_demos must be int or List[int]") @@ -383,9 +511,15 @@ def reset_pipeline(self): template=self.template, demos_field=self.demos_field ) ) + else: self.verbalization.steps.append( - Set(fields={"recipe_metadata/num_demos": 0}) + Set( + fields={ + "recipe_metadata/num_demos": 0, + "recipe_metadata/demos_pool_size": 0, + } + ) ) if isinstance(self.template, list): self.verbalization.steps.append( @@ -410,6 +544,15 @@ def reset_pipeline(self): self.finalize.steps.append(FinalizeDataset(group_by=self.group_by)) def prepare(self): + if self.use_demos: + if (self.demos_pool_size is None) == (self.given_demos_pool is None): + raise ValueError( + f"When using demonstrations, exactly one of either demos_pool_size or given_demos_pool must be set. Got both {'' if self.demos_pool_size else 'not '}set." + ) + # now set self.demos_pool_size for the checks of verify + if self.given_demos_pool: + self.demos_pool_size = len(self.given_demos_pool) + if isinstance(self.template, TemplatesList): self.template = self.template.items self.reset_pipeline() @@ -506,6 +649,8 @@ class StandardRecipe(StandardRecipeWithIndexes): Maximum test instances for the refiner. demos_pool_size (int, optional): Size of the demos pool. + given_demos_pool(List[Dict[str, Any]], optional): + a list of instances to make the demos_pool num_demos (int, optional): Number of demos to be used. demos_pool_name (str, optional): @@ -514,10 +659,16 @@ class StandardRecipe(StandardRecipeWithIndexes): Specifies from where the demos are taken. Default is "train". demos_field (str, optional): Field name for demos. Default is "demos". + demos_pool_field_name (str, optional): + field name to maintain the demos_pool, until sampled from, to make the demos. + Defaults to "_demos_pool_". demos_removed_from_data (bool, optional): whether to remove the demos from the source data, Default is True sampler (Sampler, optional): The Sampler used to select the demonstrations when num_demos > 0. + skip_demoed_instances (bool, optional): + whether to skip pushing demos to an instance whose demos_field is + already populated. Defaults to False. steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe. augmentor (Augmentor) : diff --git a/src/unitxt/task.py b/src/unitxt/task.py index 5add3af729..cc36dd84a4 100644 --- a/src/unitxt/task.py +++ b/src/unitxt/task.py @@ -290,6 +290,9 @@ def process( "media": instance.get("media", {}), "recipe_metadata": instance.get("recipe_metadata", {}), } + if "demos" in instance: + # for the case of recipe.skip_demoed_instances + result["demos"] = instance["demos"] if stream_name == constants.inference_stream: return result diff --git a/src/unitxt/test_utils/metrics.py b/src/unitxt/test_utils/metrics.py index b3723447b6..ee0d356150 100644 --- a/src/unitxt/test_utils/metrics.py +++ b/src/unitxt/test_utils/metrics.py @@ -74,7 +74,7 @@ def apply_metric( ti = [] for instance in test_iterable: ti.append(deepcopy(instance)) - multi_stream = MultiStream.from_iterables({"test": ti}) + multi_stream = MultiStream.from_iterables({"test": ti}, copying=True) output_multi_stream = metric(multi_stream) output_stream = output_multi_stream["test"] diff --git a/tests/library/test_api.py b/tests/library/test_api.py index 1dbefeb3a1..47bdaacdb0 100644 --- a/tests/library/test_api.py +++ b/tests/library/test_api.py @@ -34,7 +34,7 @@ def test_load_dataset(self): "target": "5.0", "references": ["5.0"], "source": "Given this sentence: 'A plane is taking off.', on a scale of 1.0 to 5.0, what is the similarity to this text 'An air plane is taking off.'?\n", - "task_data": '{"text1": "A plane is taking off.", "text2": "An air plane is taking off.", "attribute_name": "similarity", "min_value": 1.0, "max_value": 5.0, "attribute_value": 5.0, "metadata": {"data_classification_policy": ["public"], "template": "templates.regression.two_texts.simple", "num_demos": 0}}', + "task_data": '{"text1": "A plane is taking off.", "text2": "An air plane is taking off.", "attribute_name": "similarity", "min_value": 1.0, "max_value": 5.0, "attribute_value": 5.0, "metadata": {"data_classification_policy": ["public"], "template": "templates.regression.two_texts.simple", "demos_pool_size": 0, "num_demos": 0}}', "groups": [], "media": {"audios": [], "images": []}, "subset": [], @@ -65,7 +65,7 @@ def test_load_dataset_with_multi_num_demos(self): "processors.take_first_non_empty_line", "processors.cast_to_float_return_zero_if_failed", ], - "task_data": '{"text1": "A man is spreading shreded cheese on a pizza.", "text2": "A man is spreading shredded cheese on an uncooked pizza.", "attribute_name": "similarity", "min_value": 1.0, "max_value": 5.0, "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "template": "templates.regression.two_texts.simple"}, "attribute_value": 3.799999952316284, "demos": []}', + "task_data": '{"text1": "A man is spreading shreded cheese on a pizza.", "text2": "A man is spreading shredded cheese on an uncooked pizza.", "attribute_name": "similarity", "min_value": 1.0, "max_value": 5.0, "metadata": {"data_classification_policy": ["public"], "demos_pool_size": 2, "num_demos": 0, "template": "templates.regression.two_texts.simple"}, "attribute_value": 3.799999952316284, "demos": []}', "data_classification_policy": ["public"], } self.assertEqual(len(dataset["train"]), 5) @@ -86,7 +86,7 @@ def test_load_dataset_with_multi_templates(self): "target": "5.0", "references": ["5.0"], "source": "text1: A plane is taking off., text2: An air plane is taking off., attribute_name: similarity, min_value: 1.0, max_value: 5.0\n", - "task_data": '{"text1": "A plane is taking off.", "text2": "An air plane is taking off.", "attribute_name": "similarity", "min_value": 1.0, "max_value": 5.0, "attribute_value": 5.0, "metadata": {"data_classification_policy": ["public"], "template": "templates.key_val", "num_demos": 0}}', + "task_data": '{"text1": "A plane is taking off.", "text2": "An air plane is taking off.", "attribute_name": "similarity", "min_value": 1.0, "max_value": 5.0, "attribute_value": 5.0, "metadata": {"data_classification_policy": ["public"], "template": "templates.key_val", "demos_pool_size": 0, "num_demos": 0}}', "groups": [], "media": {"audios": [], "images": []}, "subset": [], @@ -117,7 +117,7 @@ def test_load_dataset_with_benchmark(self): "processors.lower_case_till_punc", ], "source": "Classify the grammatical acceptability of the following text to one of these options: unacceptable, acceptable.\ntext: The sailors rode the breeze clear of the rocks.\nThe grammatical acceptability is ", - "task_data": '{"text": "The sailors rode the breeze clear of the rocks.", "text_type": "text", "classes": ["unacceptable", "acceptable"], "type_of_class": "grammatical acceptability", "label": "acceptable", "metadata": {"data_classification_policy": ["public"], "template": "templates.classification.multi_class.instruction", "num_demos": 0}}', + "task_data": '{"text": "The sailors rode the breeze clear of the rocks.", "text_type": "text", "classes": ["unacceptable", "acceptable"], "type_of_class": "grammatical acceptability", "label": "acceptable", "metadata": {"data_classification_policy": ["public"], "template": "templates.classification.multi_class.instruction", "demos_pool_size": 0, "num_demos": 0}}', "groups": [], "media": {"audios": [], "images": []}, "subset": ["cola"], @@ -132,7 +132,7 @@ def test_load_dataset_with_benchmark(self): "processors.lower_case_till_punc", ], "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: The drain is clogged with hair. It has to be cleaned.\nhypothesis: The hair has to be cleaned.\nThe entailment class is ", - "task_data": '{"text_a": "The drain is clogged with hair. It has to be cleaned.", "text_a_type": "premise", "text_b": "The hair has to be cleaned.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "label": "entailment", "metadata": {"data_classification_policy": ["public"], "template": "templates.classification.multi_class.relation.default", "num_demos": 0}}', + "task_data": '{"text_a": "The drain is clogged with hair. It has to be cleaned.", "text_a_type": "premise", "text_b": "The hair has to be cleaned.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "label": "entailment", "metadata": {"data_classification_policy": ["public"], "template": "templates.classification.multi_class.relation.default", "demos_pool_size": 0, "num_demos": 0}}', "groups": [], "media": {"audios": [], "images": []}, "subset": ["wnli"], @@ -201,6 +201,7 @@ def test_evaluate(self): "metadata": { "data_classification_policy": ["public"], "template": "templates.regression.two_texts.simple", + "demos_pool_size": 0, "num_demos": 0, }, "source": "Given this sentence: 'A plane is taking off.', on a scale of 1.0 to 5.0, what is the similarity to this text 'An air plane is taking off.'?\n", @@ -255,6 +256,7 @@ def test_evaluate_with_groups(self): "metadata": { "data_classification_policy": ["public"], "template": "templates.regression.two_texts.simple", + "demos_pool_size": 0, "num_demos": 0, }, "source": "Given this sentence: 'A plane is taking off.', on a scale of 1.0 to 5.0, what is the similarity to this text 'An air plane is taking off.'?\n", @@ -335,7 +337,7 @@ def test_produce_with_recipe(self): "processors.lower_case_till_punc", ], "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.\nhypothesis: mother was careful not to disturb her, undressing and climbing back into her berth.\nThe entailment class is entailment\n\npremise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", - "task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": [], "num_demos": 2, "template": "templates.classification.multi_class.relation.default"}, "demos": [{"text_a": "When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.", "text_a_type": "premise", "text_b": "mother was careful not to disturb her, undressing and climbing back into her berth.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}, {"text_a": "Steve follows Fred\'s example in everything. He influences him hugely.", "text_a_type": "premise", "text_b": "Steve influences him hugely.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}]}', + "task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": [], "num_demos": 2, "demos_pool_size": 5, "template": "templates.classification.multi_class.relation.default"}, "demos": [{"text_a": "When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.", "text_a_type": "premise", "text_b": "mother was careful not to disturb her, undressing and climbing back into her berth.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}, {"text_a": "Steve follows Fred\'s example in everything. He influences him hugely.", "text_a_type": "premise", "text_b": "Steve influences him hugely.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}]}', "groups": [], "subset": [], "media": {"images": [], "audios": []}, @@ -364,7 +366,7 @@ def test_produce_with_task(self): "processors.lower_case_till_punc", ], "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", - "task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": [], "num_demos": 0, "template": "templates.classification.multi_class.relation.default"}}', + "task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": [], "num_demos": 0, "demos_pool_size": 0, "template": "templates.classification.multi_class.relation.default"}}', "groups": [], "subset": [], "media": {"images": [], "audios": []}, @@ -395,7 +397,7 @@ def test_produce_with_recipe_with_list_of_instances(self): "processors.lower_case_till_punc", ], "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\npremise: When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.\nhypothesis: mother was careful not to disturb her, undressing and climbing back into her berth.\nThe entailment class is entailment\n\npremise: Steve follows Fred's example in everything. He influences him hugely.\nhypothesis: Steve influences him hugely.\nThe entailment class is entailment\n\npremise: It works perfectly\nhypothesis: It works!\nThe entailment class is ", - "task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": [], "num_demos": 2, "template": "templates.classification.multi_class.relation.default"}, "demos": [{"text_a": "When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.", "text_a_type": "premise", "text_b": "mother was careful not to disturb her, undressing and climbing back into her berth.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}, {"text_a": "Steve follows Fred\'s example in everything. He influences him hugely.", "text_a_type": "premise", "text_b": "Steve influences him hugely.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}]}', + "task_data": '{"text_a": "It works perfectly", "text_a_type": "premise", "text_b": "It works!", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": [], "num_demos": 2, "demos_pool_size": 5, "template": "templates.classification.multi_class.relation.default"}, "demos": [{"text_a": "When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.", "text_a_type": "premise", "text_b": "mother was careful not to disturb her, undressing and climbing back into her berth.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}, {"text_a": "Steve follows Fred\'s example in everything. He influences him hugely.", "text_a_type": "premise", "text_b": "Steve influences him hugely.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}]}', "groups": [], "subset": [], "media": {"images": [], "audios": []}, diff --git a/tests/library/test_benchmark.py b/tests/library/test_benchmark.py index 21579d97eb..9a6d0056ad 100644 --- a/tests/library/test_benchmark.py +++ b/tests/library/test_benchmark.py @@ -38,7 +38,7 @@ def test_benchmark(self): "target": "acceptable", "references": ["acceptable"], "source": "Classify the grammatical acceptability of the following text to one of these options: unacceptable, acceptable.\n\nUser:text: The sailors rode the breeze clear of the rocks.\nAgent:The grammatical acceptability is ", - "task_data": '{"text": "The sailors rode the breeze clear of the rocks.", "text_type": "text", "classes": ["unacceptable", "acceptable"], "type_of_class": "grammatical acceptability", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "template": "templates.classification.multi_class.instruction"}, "label": "acceptable"}', + "task_data": '{"text": "The sailors rode the breeze clear of the rocks.", "text_type": "text", "classes": ["unacceptable", "acceptable"], "type_of_class": "grammatical acceptability", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "demos_pool_size": 0, "template": "templates.classification.multi_class.instruction"}, "label": "acceptable"}', "groups": [], "subset": ["cola"], }, @@ -53,7 +53,7 @@ def test_benchmark(self): "target": "acceptable", "references": ["acceptable"], "source": "Classify the grammatical acceptability of the following text to one of these options: unacceptable, acceptable.\n\nUser:text: The weights made the rope stretch over the pulley.\nAgent:The grammatical acceptability is ", - "task_data": '{"text": "The weights made the rope stretch over the pulley.", "text_type": "text", "classes": ["unacceptable", "acceptable"], "type_of_class": "grammatical acceptability", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "template": "templates.classification.multi_class.instruction"}, "label": "acceptable"}', + "task_data": '{"text": "The weights made the rope stretch over the pulley.", "text_type": "text", "classes": ["unacceptable", "acceptable"], "type_of_class": "grammatical acceptability", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "demos_pool_size": 0, "template": "templates.classification.multi_class.instruction"}, "label": "acceptable"}', "groups": [], "subset": ["cola"], }, @@ -72,7 +72,7 @@ def test_benchmark(self): "target": "entailment", "references": ["entailment"], "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\n\nUser:premise: The drain is clogged with hair. It has to be cleaned.\nhypothesis: The hair has to be cleaned.\nAgent:The entailment class is ", - "task_data": '{"text_a": "The drain is clogged with hair. It has to be cleaned.", "text_a_type": "premise", "text_b": "The hair has to be cleaned.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "template": "templates.classification.multi_class.relation.default"}, "label": "entailment"}', + "task_data": '{"text_a": "The drain is clogged with hair. It has to be cleaned.", "text_a_type": "premise", "text_b": "The hair has to be cleaned.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "demos_pool_size": 0, "template": "templates.classification.multi_class.relation.default"}, "label": "entailment"}', "groups": [], "subset": ["wnli"], }, @@ -91,7 +91,7 @@ def test_benchmark(self): "target": "not entailment", "references": ["not entailment"], "source": "Given a premise and hypothesis classify the entailment of the hypothesis to one of entailment, not entailment.\n\nUser:premise: Jane knocked on Susan's door but she did not answer.\nhypothesis: Susan did not answer.\nAgent:The entailment class is ", - "task_data": '{"text_a": "Jane knocked on Susan\'s door but she did not answer.", "text_a_type": "premise", "text_b": "Susan did not answer.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "template": "templates.classification.multi_class.relation.default"}, "label": "not entailment"}', + "task_data": '{"text_a": "Jane knocked on Susan\'s door but she did not answer.", "text_a_type": "premise", "text_b": "Susan did not answer.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"], "num_demos": 0, "demos_pool_size": 0, "template": "templates.classification.multi_class.relation.default"}, "label": "not entailment"}', "groups": [], "subset": ["wnli"], }, diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index 38a130572b..a3225fae3b 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -15,6 +15,7 @@ from unitxt.templates import InputOutputTemplate, TemplatesList from unitxt.text_utils import print_dict from unitxt.types import Table +from unitxt.utils import recursive_copy from tests.utils import UnitxtTestCase @@ -125,27 +126,13 @@ def test_standard_recipe_production_consistency(self): } ] - self.assertListEqual( - recipe.production_demos_pool(), recipe.production_demos_pool() - ) - self.assertDictEqual( - recipe.produce(instances)[0], - recipe.produce(instances)[0], - ) - - i1 = recipe.production_preprocess(instances)[0] - i2 = recipe.production_preprocess(instances)[0] - for meta_data in ["card", "template", "format", "system_prompt"]: - if meta_data in i1["recipe_metadata"]: - i1["recipe_metadata"][meta_data] = i1["recipe_metadata"][ - meta_data - ]._to_raw_dict() - if not isinstance(i2["recipe_metadata"][meta_data], dict): - i2["recipe_metadata"][meta_data] = i2["recipe_metadata"][ - meta_data - ]._to_raw_dict() + recipe.produce(recursive_copy(instances))[0], + recipe.produce(recursive_copy(instances))[0], + ) + i1 = recipe.production_preprocess(recursive_copy(instances))[0] + i2 = recipe.production_preprocess(recursive_copy(instances))[0] self.assertDictEqual(i1, i2) def test_standard_recipe_production_with_demos(self): @@ -173,7 +160,7 @@ def test_standard_recipe_production_with_demos(self): "data_classification_policy": [], "postprocessors": ["processors.first_character"], "source": "<>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n\n\n\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\nAlthough the content and quality can be as controlled as direct mail, response rates of this medium are lower because of the lack of a personal address mechanism. This media format is known as:\nA. Care lines.\nB. Direct mail.\nC. Inserts.\nD. Door to door.\nAnswer:\nAgent: D\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\n _____________ is a natural outcome when combining demographic and geographic variables.\nA. Geodemographics\nB. Product differentiation.\nC. ANSOFF matrix.\nD. Brand management.\nAnswer:\nAgent: A\n\nUser: The following are multiple choice questions (with answers) about marketing.\n\nIn an organization, the group of people tasked with buying decisions is referred to as the _______________.\nA. Outsourcing unit.\nB. Procurement centre.\nC. Chief executive unit.\nD. Decision-making unit.\nAnswer:\nAgent: D\n\n\nUser:The following are multiple choice questions (with answers) about testing.\n\nwhat?\nA. yes\nB. not\nC. maybe\nAnswer:\nAgent:", - "task_data": '{"topic": "testing", "question": "what?", "choices": ["yes", "not", "maybe"], "options": [" A", " B", " C"], "metadata": {"data_classification_policy": [], "num_demos": 3, "template": "templates.qa.multiple_choice.with_topic.lm_eval_harness"}, "demos": [{"topic": "marketing", "question": "Although the content and quality can be as controlled as direct mail, response rates of this medium are lower because of the lack of a personal address mechanism. This media format is known as:", "choices": ["Care lines.", "Direct mail.", "Inserts.", "Door to door."], "options": [" A", " B", " C", " D"], "metadata": {"data_classification_policy": ["public"]}, "answer": 3}, {"topic": "marketing", "question": " _____________ is a natural outcome when combining demographic and geographic variables.", "choices": ["Geodemographics", "Product differentiation.", "ANSOFF matrix.", "Brand management."], "options": [" A", " B", " C", " D"], "metadata": {"data_classification_policy": ["public"]}, "answer": 0}, {"topic": "marketing", "question": "In an organization, the group of people tasked with buying decisions is referred to as the _______________.", "choices": ["Outsourcing unit.", "Procurement centre.", "Chief executive unit.", "Decision-making unit."], "options": [" A", " B", " C", " D"], "metadata": {"data_classification_policy": ["public"]}, "answer": 3}]}', + "task_data": '{"topic": "testing", "question": "what?", "choices": ["yes", "not", "maybe"], "options": [" A", " B", " C"], "metadata": {"data_classification_policy": [], "demos_pool_size": 5, "num_demos": 3, "template": "templates.qa.multiple_choice.with_topic.lm_eval_harness"}, "demos": [{"topic": "marketing", "question": "Although the content and quality can be as controlled as direct mail, response rates of this medium are lower because of the lack of a personal address mechanism. This media format is known as:", "choices": ["Care lines.", "Direct mail.", "Inserts.", "Door to door."], "options": [" A", " B", " C", " D"], "metadata": {"data_classification_policy": ["public"]}, "answer": 3}, {"topic": "marketing", "question": " _____________ is a natural outcome when combining demographic and geographic variables.", "choices": ["Geodemographics", "Product differentiation.", "ANSOFF matrix.", "Brand management."], "options": [" A", " B", " C", " D"], "metadata": {"data_classification_policy": ["public"]}, "answer": 0}, {"topic": "marketing", "question": "In an organization, the group of people tasked with buying decisions is referred to as the _______________.", "choices": ["Outsourcing unit.", "Procurement centre.", "Chief executive unit.", "Decision-making unit."], "options": [" A", " B", " C", " D"], "metadata": {"data_classification_policy": ["public"]}, "answer": 3}]}', "groups": [], "subset": [], "media": {"images": [], "audios": []}, @@ -185,6 +172,127 @@ def test_standard_recipe_production_with_demos(self): self.assertDictEqual(result, target) self.assertDictEqual(target_task_data, result_task_data) + def test_standard_recipe_with_given_demos(self): + recipe = StandardRecipe( + card="cards.wnli", + template_card_index=0, + ) + for_demos = recipe.inference_demos() + for_demos = recipe.processing(for_demos) + for_demos = recursive_copy(list(for_demos["validation"])) + + recipe2 = StandardRecipe( + card="cards.wnli", + template_card_index=0, + given_demos_pool=for_demos, + ) + + trains = list(recipe2()["train"]) + assert "The entailment class is entailment" not in trains[0]["source"] + + recipe3 = StandardRecipe( + card="cards.wnli", + template_card_index=0, + given_demos_pool=for_demos, + num_demos=3, + ) + + trains = list(recipe3()["train"]) + assert "The entailment class is entailment" in trains[0]["source"] + + def test_standard_recipe_not_duplicating_demos_pool(self): + recipe = StandardRecipe( + card="cards.wnli", + template_card_index=0, + ) + for_demos = recipe.inference_demos() + for_demos = recipe.processing(for_demos) + for_demos = recursive_copy(list(for_demos["validation"])) + + recipe3 = StandardRecipe( + card="cards.wnli", + template_card_index=0, + given_demos_pool=for_demos, + num_demos=3, + ) + + ms = recipe3.inference_demos() + ms = recipe3.processing(ms) + # here the ms stopped before verbalizing, after processing so it still has "_demos_pool_" field + trains = list(ms["train"]) + assert "_demos_pool_" in trains[0] + first_demo_of_first_instance = trains[0]["_demos_pool_"][0] + first_demo_of_second_instance = trains[1]["_demos_pool_"][0] + self.assertDictEqual( + first_demo_of_first_instance, first_demo_of_second_instance + ) + self.assertEqual( + first_demo_of_first_instance["input_fields"]["text_a_type"], "premise" + ) + + # change just the demos in the first instance + first_demo_of_first_instance["input_fields"]["text_a_type"] = "hallelujah" + # verify that the demos in the second instance change as well + self.assertEqual( + first_demo_of_second_instance["input_fields"]["text_a_type"], "hallelujah" + ) + + def test_standard_recipe_with_demoed_instances(self): + recipe = StandardRecipe( + card="cards.wnli", + template_card_index=0, + ) + ms = recipe.loading() + ms = recipe.metadata(ms) + ms = recipe.standardization(ms) + a_standardized_input_instance = next(iter(ms["test"])) + self.assertNotIn("demos", a_standardized_input_instance) + + ms = recipe.loading() + ms = recipe.metadata(ms) + ms = recipe.standardization(ms) + ms = recipe.task(ms) + a_tasked_input_instance = next(iter(ms["validation"])) + self.assertIn( + "I took the water bottle out of the backpack ", + a_tasked_input_instance["input_fields"]["text_a"], + ) + + a_standardized_input_instance["demos"] = [a_tasked_input_instance] + demoed_standardized_input_instance = recursive_copy( + a_standardized_input_instance + ) + + recipe2 = StandardRecipe( + card="cards.wnli", + template_card_index=0, + demos_pool_size=3, + num_demos=1, + skip_demoed_instances=True, + ) + + processed_input_instance = recipe2.produce([a_standardized_input_instance])[0] + self.assertIn( + "premise: I took the water bottle out of the backpack ", + processed_input_instance["source"], + ) + + recipe3 = StandardRecipe( + card="cards.wnli", + template_card_index=0, + demos_pool_size=3, + num_demos=1, + skip_demoed_instances=False, + ) + + processed_input_instance = recipe3.produce( + [demoed_standardized_input_instance] + )[0] + self.assertNotIn( + "premise: I took the water bottle out of the backpack ", + processed_input_instance["source"], + ) + def test_standard_recipe_with_indexes_with_catalog(self): recipe = StandardRecipe( card="cards.wnli", @@ -211,8 +319,11 @@ def test_standard_recipe_with_demos_not_removed_from_data(self): ) stream = recipe() - n_trains_remove_demos = len(list(stream["train"])) - n_demos_remove_demos = len(list(stream["demos_pool"])) + trains = list(stream["train"]) + n_trains_remove_demos = len(trains) + n_demos_remove_demos = json.loads(trains[0]["task_data"])["metadata"][ + "demos_pool_size" + ] recipe = StandardRecipeWithIndexes( card="cards.wnli", @@ -223,8 +334,11 @@ def test_standard_recipe_with_demos_not_removed_from_data(self): ) stream = recipe() - n_trains_keep_demos = len(list(stream["train"])) - n_demos_keep_demos = len(list(stream["demos_pool"])) + trains = list(stream["train"]) + n_trains_keep_demos = len(trains) + n_demos_keep_demos = json.loads(trains[0]["task_data"])["metadata"][ + "demos_pool_size" + ] self.assertEqual( n_trains_keep_demos, n_trains_remove_demos + n_demos_remove_demos @@ -249,7 +363,7 @@ def test_empty_template(self): "target": "not entailment", "references": ["not entailment"], "source": "<>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n\n\n\n\nUser: Emma did not pass the ball to Janie although she was open., premise, She saw that Janie was open., hypothesis, entailment, not entailment, entailment\nAgent: not entailment\n\nUser: The foxes are getting in at night and attacking the chickens. I shall have to kill them., premise, I shall have to kill The foxes., hypothesis, entailment, not entailment, entailment\nAgent: not entailment\n\nUser: Fred is the only man alive who still remembers my father as an infant. When Fred first saw my father, he was twelve years old., premise, When Fred first saw my father, My father was twelve years old., hypothesis, entailment, not entailment, entailment\nAgent: entailment\n\n\nUser:Grace was happy to trade me her sweater for my jacket. She thinks it looks dowdy on her., premise, The sweater looks dowdy on her., hypothesis, entailment, not entailment, entailment\nAgent:", - "task_data": '{"text_a": "Grace was happy to trade me her sweater for my jacket. She thinks it looks dowdy on her.", "text_a_type": "premise", "text_b": "The sweater looks dowdy on her.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"], "num_demos": 3, "template": "templates.empty"}, "label": "not entailment", "demos": [{"text_a": "Emma did not pass the ball to Janie although she was open.", "text_a_type": "premise", "text_b": "She saw that Janie was open.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "not entailment"}, {"text_a": "The foxes are getting in at night and attacking the chickens. I shall have to kill them.", "text_a_type": "premise", "text_b": "I shall have to kill The foxes.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "not entailment"}, {"text_a": "Fred is the only man alive who still remembers my father as an infant. When Fred first saw my father, he was twelve years old.", "text_a_type": "premise", "text_b": "When Fred first saw my father, My father was twelve years old.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}]}', + "task_data": '{"text_a": "Grace was happy to trade me her sweater for my jacket. She thinks it looks dowdy on her.", "text_a_type": "premise", "text_b": "The sweater looks dowdy on her.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"], "num_demos": 3, "demos_pool_size": 100, "template": "templates.empty"}, "label": "not entailment", "demos": [{"text_a": "Emma did not pass the ball to Janie although she was open.", "text_a_type": "premise", "text_b": "She saw that Janie was open.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "not entailment"}, {"text_a": "The foxes are getting in at night and attacking the chickens. I shall have to kill them.", "text_a_type": "premise", "text_b": "I shall have to kill The foxes.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "not entailment"}, {"text_a": "Fred is the only man alive who still remembers my father as an infant. When Fred first saw my father, he was twelve years old.", "text_a_type": "premise", "text_b": "When Fred first saw my father, My father was twelve years old.", "text_b_type": "hypothesis", "classes": ["entailment", "not entailment"], "type_of_relation": "entailment", "metadata": {"data_classification_policy": ["public"]}, "label": "entailment"}]}', "groups": [], "subset": [], } @@ -294,6 +408,7 @@ def test_key_val_template(self): "type_of_relation": "entailment", "metadata": { "data_classification_policy": ["public"], + "demos_pool_size": 100, "num_demos": 3, "template": "templates.key_val", }, @@ -512,7 +627,7 @@ def test_standard_recipe_with_no_demos_to_take(self): self.assertTrue( str(cm.exception).startswith( - "Unable to fetch instances from 'demos_pool' to 'demos'" + "Input multi-stream is missing a stream named 'train' to take demo instances from for the demos_pool." ) ) @@ -539,7 +654,7 @@ def test_standard_recipe_with_no_demos_to_take(self): self.assertEqual( str(cm.exception), - "num_demos (got: 30) should not exceed demos_pool_size (got: 10)", + "num_demos (got: 30) should not exceed demos_pool_size - 1 (got: 10), (-1: to always allow filtering of a demo identical to the processed instance).", ) def test_standard_recipe_with_no_test(self): @@ -607,7 +722,8 @@ def test_standard_recipe_with_balancer_and_size_limit(self): stream = recipe() counts = collections.Counter() - for instance in stream["train"]: + trains = list(stream["train"]) + for instance in trains: counts[instance["target"]] += 1 self.assertEqual(counts["entailment"], counts["not entailment"], 10) diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index ec9d704a7a..ccb0e21169 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -151,7 +151,7 @@ "filename": "src/unitxt/loaders.py", "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742", "is_verified": false, - "line_number": 493, + "line_number": 490, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2024-12-09T15:45:50Z" + "generated_at": "2024-12-20T17:49:50Z" }