From 9842108405b1f919ed9170c4402d7aae57726e94 Mon Sep 17 00:00:00 2001 From: Elron Bandel Date: Sun, 19 Jan 2025 14:35:40 +0200 Subject: [PATCH] Ensure fusion does not call streams before use (#1518) * Ensure fusion does not call streams before use Signed-off-by: elronbandel * Fix fusion_generator to skip unavailable splits in named subsets Signed-off-by: elronbandel * Add splits and subset handling to BaseBenchmark class Signed-off-by: elronbandel * Fix iterator assignment in WeightedFusion to use multi_stream for split access Signed-off-by: elronbandel * Refactor subset initialization in BaseBenchmark to use NonPositionalField Signed-off-by: elronbandel --------- Signed-off-by: elronbandel --- src/unitxt/benchmark.py | 19 +++++++++++++------ src/unitxt/fusion.py | 29 ++++++++++++++--------------- tests/library/test_fusion.py | 5 +++-- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/unitxt/benchmark.py b/src/unitxt/benchmark.py index a3f4562d8..b5d8c97c1 100644 --- a/src/unitxt/benchmark.py +++ b/src/unitxt/benchmark.py @@ -1,9 +1,9 @@ from abc import abstractmethod -from typing import Dict, Union +from typing import Dict, List, Optional, Union from .dataclass import NonPositionalField from .formats import Format -from .fusion import FixedFusion, WeightedFusion +from .fusion import FixedFusion from .operator import SourceOperator from .standard import DatasetRecipe from .stream import MultiStream @@ -15,6 +15,10 @@ class BaseBenchmark(SourceOperator): num_demos: int = NonPositionalField(default=None) system_prompt: SystemPrompt = NonPositionalField(default=None) loader_limit: int = NonPositionalField(default=None) + splits: List[str] = NonPositionalField( + default_factory=lambda: ["train", "validation", "test"] + ) + subset: Optional[str] = NonPositionalField(default=None) @abstractmethod def reset(self): @@ -65,14 +69,17 @@ def prepare(self): def process( self, ) -> MultiStream: + if self.subset is not None: + subsets = {self.subset: 
self.subsets[self.subset]} + else: + subsets = self.subsets if self.max_total_samples is None: operator = FixedFusion( - subsets=self.subsets, + subsets=subsets, max_instances_per_subset=self.max_samples_per_subset, + include_splits=self.splits, ) else: - operator = WeightedFusion( - subsets=self.subsets, max_total_samples=self.max_total_samples - ) + raise NotImplementedError() return operator() diff --git a/src/unitxt/fusion.py b/src/unitxt/fusion.py index cfa00c1fe..c37d3b035 100644 --- a/src/unitxt/fusion.py +++ b/src/unitxt/fusion.py @@ -32,23 +32,19 @@ def prepare_subsets(self): self.named_subsets = {} if isinstance(self.subsets, list): for i in range(len(self.subsets)): - self.named_subsets[i] = self.subsets[i]() + self.named_subsets[i] = self.subsets[i] else: for name, origin in self.subsets.items(): try: - self.named_subsets[name] = origin() + self.named_subsets[name] = origin except Exception as e: raise RuntimeError(f"Exception in subset: {name}") from e def splits(self) -> List[str]: self.prepare_subsets() - splits = [] - for _, origin in self.named_subsets.items(): - for s in origin.keys(): - if s not in splits: - if self.include_splits is None or s in self.include_splits: - splits.append(s) - return splits + if self.include_splits is not None: + return self.include_splits + return ["train", "test", "validation"] def process( self, @@ -80,11 +76,12 @@ def prepare(self): # flake8: noqa: C901 def fusion_generator(self, split) -> Generator: for origin_name, origin in self.named_subsets.items(): - if split not in origin: + multi_stream = origin() + if split not in multi_stream: continue emitted_from_this_split = 0 try: - for instance in origin[split]: + for instance in multi_stream[split]: if ( self.max_instances_per_subset is not None and emitted_from_this_split >= self.max_instances_per_subset @@ -138,10 +135,12 @@ def prepare(self): ) def fusion_generator(self, split) -> Generator: - iterators = { - named_origin: iter(origin[split]) - for named_origin, 
origin in self.named_subsets.items() - } + iterators = {} + for origin_name, origin in self.named_subsets.items(): + multi_stream = origin() + if split not in multi_stream: + continue + iterators[origin_name] = iter(multi_stream[split]) total_examples = 0 random_generator = new_random_generator(sub_seed="weighted_fusion_" + split) while ( diff --git a/tests/library/test_fusion.py b/tests/library/test_fusion.py index 081554bf1..ed1ed5bf4 100644 --- a/tests/library/test_fusion.py +++ b/tests/library/test_fusion.py @@ -60,6 +60,7 @@ def compare_stream(self, stream, expected_stream): def test_nonoverlapping_splits_fusion(self): operator = FixedFusion( + include_splits=["train", "test"], subsets={ "origin_train": IterableSource( {"train": [{"x": "x1"}, {"x": "x2"}, {"x": "x3"}]} @@ -217,7 +218,7 @@ def test_over_bounded_weighted_fusion(self): {"b": "y2", "subset": ["origin2"]}, ], } - for key in res: + for key in ["train", "test"]: self.compare_stream(targets[key], list(res[key])) operator = WeightedFusion( @@ -290,7 +291,7 @@ def test_over_bounded_weighted_fusion(self): {"b": "y5", "subset": ["origin2"]}, ], } - for key in res: + for key in ["train", "test"]: self.compare_stream(targets[key], list(res[key])) targets = [