Skip to content

Commit

Permalink
allow to consume a whole stream
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <[email protected]>
  • Loading branch information
dafnapension committed Dec 24, 2024
1 parent 392b5e3 commit 1d6502e
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 16 deletions.
75 changes: 59 additions & 16 deletions src/unitxt/standard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import json
import sys
from typing import Any, Dict, Generator, List, Optional, Union

from .artifact import fetch_artifact
Expand Down Expand Up @@ -39,22 +40,34 @@
class CreateDemosPool(MultiStreamOperator):
from_stream: str = None
demos_pool_size: int = None
remove_targets_from_source_split: bool = None
demos_removed_from_data: bool = None
to_field: str = constants.demos_pool_field

# flake8: noqa: B007
def process(self, multi_stream: MultiStream) -> MultiStream:
# generate the demos_pool as a selection of demos_pool_size distinct instances
# (distinct by their "input_fields" field). The selection is taken from stream named from_stream.
# The selected instances are later treated as ordinary instances or not, depending on parameter
# remove_targets_from_source_split.
# demos_removed_from_data.
# The selection of instances is done from the first instances of the stream named from_stream.
# instances that are not distinct from previously selected demo instances, are kept aside, to be later
# treated like all the remaining instances of stream from_stream.
if self.from_stream not in multi_stream:
raise ValueError(
f"Input multi-stream is missing a stream named '{self.from_stream}' to take demo instances from for the demos_pool."
)
if (
self.demos_removed_from_data is not None
and self.demos_removed_from_data is True
and (self.demos_pool_size == sys.maxsize)
):
# going to consume the whole of input stream named self.from_stream for demo instances,
# and not let demos instances to behave as regular instances. so self.from_stream
# ends here its life as an input stream that is expected to reach the end of the recipe
if len(multi_stream) == 1:
raise ValueError(
f"The single input stream, '{self.from_stream}' is to be wholly consumed for generating demos, and no instance is left to use these demos."
)
from_stream = multi_stream[self.from_stream]
demos_pool = []
input_fields_of_demos_pool = []
Expand All @@ -80,9 +93,31 @@ def process(self, multi_stream: MultiStream) -> MultiStream:
demos_pool = recursive_copy(demos_pool)

set_demos_pool = Set(fields={self.to_field: demos_pool})
if not self.remove_targets_from_source_split:
if (
self.demos_removed_from_data is not None
and self.demos_removed_from_data is False
):
# all input instances go out. No one is "killed" because selected as demo
return set_demos_pool(multi_stream)

if (
self.demos_removed_from_data is not None
and self.demos_removed_from_data is True
):
if self.demos_pool_size == sys.maxsize:
# consume the whole of input stream self.from_stream, just for demos, and do not
# take any of its instances to behave as a non-demo instance, i.e., a regular instance
# that consume the demos
out_ms = MultiStream(
{
stream_name: multi_stream[stream_name]
for stream_name in multi_stream
if stream_name != self.from_stream
}
)
return set_demos_pool(out_ms)

# self.demos_removed_from_data and not consume the whole of self.from_stream just for demos
def from_stream_generator(
first_layer: list, ms: MultiStream, stream_name: str, start: int
) -> Generator:
Expand Down Expand Up @@ -162,22 +197,21 @@ class DatasetRecipe(SourceSequentialOperator):
max_test_instances (int, optional):
Maximum test instances for the refiner.
demos_pool_size (int, optional):
Size of the demos pool.
Size of the demos pool. -1 for taking the whole of stream 'demos_taken_from'.
demos_pool(List[Dict[str, Any]], optional):
a list of instances to make the demos_pool
num_demos (int, optional):
Number of demos to be used.
demos_pool_name (str, optional):
Name of the demos pool. Default is constants.demos_pool_field
Number of demos to add to each instance, to become part of the source to be generated for this instance.
demos_taken_from (str, optional):
Specifies from where the demos are taken. Default is "train".
Specifies the stream from where the demos are taken. Default is "train".
demos_field (str, optional):
Field name for demos. Default is "demos".
The num_demos demos selected for an instance are stored in this field of that instance.
demos_pool_field_name (str, optional):
field name to maintain the demos_pool, until sampled from, to make the demos.
field name to maintain the demos_pool, until sampled from, in order to make the demos.
Defaults to constants.demos_pool_field.
demos_removed_from_data (bool, optional):
whether to remove the demos from the source data, Default is True
whether to remove the demos taken to demos_pool from the source data, Default is True
sampler (Sampler, optional):
The Sampler used to select the demonstrations when num_demos > 0.
skip_demoed_instances (bool, optional):
Expand Down Expand Up @@ -278,7 +312,12 @@ def verify(self):
raise ValueError(
f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size - 1 (got: {self.demos_pool_size}), (-1: to always allow filtering of a demo identical to the processed instance)."
)
if self.loader_limit and self.demos_pool_size > self.loader_limit:
if (
(not self.demos_pool)
and (self.demos_pool_size != sys.maxsize)
and self.loader_limit
and (self.demos_pool_size > self.loader_limit)
):
raise ValueError(
f"demos_pool_size should not exceed loader_limit ({self.loader_limit}), Got demos_pool_size={self.demos_pool_size}"
)
Expand Down Expand Up @@ -405,7 +444,9 @@ def production_preprocess(self, task_instances):

@property
def has_custom_demos_pool(self):
return self.demos_pool_size is not None and self.demos_pool_size > 0
return self.demos_pool_size is not None and (
self.demos_pool_size > 0 or self.demos_pool_size == -1
)

@property
def use_demos(self):
Expand Down Expand Up @@ -512,7 +553,7 @@ def reset_pipeline(self):
demos_pool_size=self.demos_pool_size
if self.demos_pool is None
else None,
remove_targets_from_source_split=self.demos_removed_from_data,
demos_removed_from_data=self.demos_removed_from_data,
to_field=self.demos_pool_field_name,
)
)
Expand Down Expand Up @@ -674,9 +715,11 @@ def prepare(self):
+ "that together direct its production."
)

# now set self.demos_pool_size for the checks of verify
if self.demos_pool:
self.demos_pool_size = len(self.demos_pool)
# now set self.demos_pool_size for the checks done by verify
if self.demos_pool:
self.demos_pool_size = len(self.demos_pool)
if self.demos_pool_size is not None and self.demos_pool_size == -1:
self.demos_pool_size = sys.maxsize

if isinstance(self.template, TemplatesList):
self.template = self.template.items
Expand Down
66 changes: 66 additions & 0 deletions tests/library/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import copy
import json
import re
import sys
from typing import Any, Dict

from unitxt import dataset_file
from unitxt.artifact import fetch_artifact
from unitxt.card import TaskCard
from unitxt.catalog import get_from_catalog
from unitxt.formats import SystemFormat
from unitxt.loaders import LoadFromDictionary
from unitxt.serializers import SingleTypeSerializer, TableSerializer
from unitxt.splitters import SplitRandomMix
from unitxt.standard import DatasetRecipe
from unitxt.task import Task
from unitxt.templates import InputOutputTemplate, TemplatesList
Expand Down Expand Up @@ -305,6 +308,69 @@ def test_dataset_recipe_with_demoed_instances(self):
processed_input_instance["source"],
)

# flake8: noqa: C416
def test_dataset_recipe_with_whole_stream_to_become_demos(self):
# glue.wnli has: train: 635 instances, validation: 71 and test: 146 (together: 852)
# take the whole of validation tobecome demos_pool
recipe = DatasetRecipe(
card="cards.wnli",
system_prompt="system_prompts.models.llama",
template_card_index=0,
format="formats.user_agent",
demos_taken_from="validation",
demos_pool_size=-1,
num_demos=3,
)
ms = recipe()

# assert 'validation' is wholly consumed for demos, not showing at the end of recipe
self.assertSetEqual({stream_name for stream_name in ms}, {"train", "test"})

tests = list(ms["test"])
task_data = json.loads(tests[0]["task_data"])
# assert maxsize is written as demos_pool_size
self.assertEqual(task_data["metadata"]["demos_pool_size"], sys.maxsize)

# flake8: noqa: C400
def test_dataset_recipe_with_whole_stream_to_become_demos_and_no_stream_left(self):
# tweaking wnli to become a card going through preprocess_steps with just one stream: 'validation'
wnli_card = get_from_catalog("cards.wnli")
wnli_card.preprocess_steps[0] = SplitRandomMix({"validation": "train[5%]"})

# now consume that single stream wholly for demos
with self.assertRaises(ValueError) as ve:
recipe = DatasetRecipe(
card=wnli_card,
system_prompt="system_prompts.models.llama",
template_card_index=0,
format="formats.user_agent",
demos_taken_from="validation",
demos_pool_size=-1,
num_demos=3,
)
ms = recipe()
# error: no instance is left to use the demos_pool made of the wholly consumed single input stream
self.assertEqual(
"The single input stream, 'validation' is to be wholly consumed for generating demos, and no instance is left to use these demos.",
str(ve.exception),
)

# but if recipe.demos_removed_from_data is false, that very stream will use the demos_pool
# and reach the end
recipe = DatasetRecipe(
card=wnli_card,
system_prompt="system_prompts.models.llama",
template_card_index=0,
format="formats.user_agent",
demos_taken_from="validation",
demos_pool_size=-1,
num_demos=3,
demos_removed_from_data=False,
)
ms = recipe()
self.assertListEqual(["validation"], [stream_name for stream_name in ms])
self.assertEqual(10, len(list(ms["validation"])))

def test_dataset_recipe_with_catalog_wnli(self):
recipe = DatasetRecipe(
card="cards.wnli",
Expand Down

0 comments on commit 1d6502e

Please sign in to comment.