Skip to content

Commit

Permalink
Merge branch 'main' into llmjudge-changes
Browse files Browse the repository at this point in the history
  • Loading branch information
elronbandel authored Jan 19, 2025
2 parents abe7447 + 9842108 commit 760a0ef
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
19 changes: 13 additions & 6 deletions src/unitxt/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import abstractmethod
from typing import Dict, Union
from typing import Dict, List, Optional, Union

from .dataclass import NonPositionalField
from .formats import Format
from .fusion import FixedFusion, WeightedFusion
from .fusion import FixedFusion
from .operator import SourceOperator
from .standard import DatasetRecipe
from .stream import MultiStream
Expand All @@ -15,6 +15,10 @@ class BaseBenchmark(SourceOperator):
num_demos: int = NonPositionalField(default=None)
system_prompt: SystemPrompt = NonPositionalField(default=None)
loader_limit: int = NonPositionalField(default=None)
splits: List[str] = NonPositionalField(
default_factory=lambda: ["train", "validation", "test"]
)
subset: Optional[str] = NonPositionalField(default=None)

@abstractmethod
def reset(self):
Expand Down Expand Up @@ -65,14 +69,17 @@ def prepare(self):
def process(
self,
) -> MultiStream:
if self.subset is not None:
subsets = {self.subset: self.subsets[self.subset]}
else:
subsets = self.subsets
if self.max_total_samples is None:
operator = FixedFusion(
subsets=self.subsets,
subsets=subsets,
max_instances_per_subset=self.max_samples_per_subset,
include_splits=self.splits,
)
else:
operator = WeightedFusion(
subsets=self.subsets, max_total_samples=self.max_total_samples
)
raise NotImplementedError()

return operator()
29 changes: 14 additions & 15 deletions src/unitxt/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,19 @@ def prepare_subsets(self):
self.named_subsets = {}
if isinstance(self.subsets, list):
for i in range(len(self.subsets)):
self.named_subsets[i] = self.subsets[i]()
self.named_subsets[i] = self.subsets[i]
else:
for name, origin in self.subsets.items():
try:
self.named_subsets[name] = origin()
self.named_subsets[name] = origin
except Exception as e:
raise RuntimeError(f"Exception in subset: {name}") from e

def splits(self) -> List[str]:
self.prepare_subsets()
splits = []
for _, origin in self.named_subsets.items():
for s in origin.keys():
if s not in splits:
if self.include_splits is None or s in self.include_splits:
splits.append(s)
return splits
if self.include_splits is not None:
return self.include_splits
return ["train", "test", "validation"]

def process(
self,
Expand Down Expand Up @@ -80,11 +76,12 @@ def prepare(self):
# flake8: noqa: C901
def fusion_generator(self, split) -> Generator:
for origin_name, origin in self.named_subsets.items():
if split not in origin:
multi_stream = origin()
if split not in multi_stream:
continue
emitted_from_this_split = 0
try:
for instance in origin[split]:
for instance in multi_stream[split]:
if (
self.max_instances_per_subset is not None
and emitted_from_this_split >= self.max_instances_per_subset
Expand Down Expand Up @@ -138,10 +135,12 @@ def prepare(self):
)

def fusion_generator(self, split) -> Generator:
iterators = {
named_origin: iter(origin[split])
for named_origin, origin in self.named_subsets.items()
}
iterators = {}
for origin_name, origin in self.named_subsets.items():
multi_stream = origin()
if split not in multi_stream:
continue
iterators[origin_name] = iter(multi_stream[split])
total_examples = 0
random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
while (
Expand Down
5 changes: 3 additions & 2 deletions tests/library/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def compare_stream(self, stream, expected_stream):

def test_nonoverlapping_splits_fusion(self):
operator = FixedFusion(
include_splits=["train", "test"],
subsets={
"origin_train": IterableSource(
{"train": [{"x": "x1"}, {"x": "x2"}, {"x": "x3"}]}
Expand Down Expand Up @@ -217,7 +218,7 @@ def test_over_bounded_weighted_fusion(self):
{"b": "y2", "subset": ["origin2"]},
],
}
for key in res:
for key in ["train", "test"]:
self.compare_stream(targets[key], list(res[key]))

operator = WeightedFusion(
Expand Down Expand Up @@ -290,7 +291,7 @@ def test_over_bounded_weighted_fusion(self):
{"b": "y5", "subset": ["origin2"]},
],
}
for key in res:
for key in ["train", "test"]:
self.compare_stream(targets[key], list(res[key]))

targets = [
Expand Down

0 comments on commit 760a0ef

Please sign in to comment.