diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index 12efd99a7..c0c959e39 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -92,8 +92,18 @@ def from_stream_generator( }, ) else: - new_streams[stream_name] = multi_stream[stream_name] + new_streams[stream_name] = ReusableGenerator( + generator=from_stream_generator, + gen_kwargs={ + "first_layer": [], + "ms": multi_stream, + "stream_name": stream_name, + "start": 0, + }, + ) + if self.sample is None: + return MultiStream.from_generators(new_streams) self.sample.demos_pool = demos_pool return self.sample(new_streams) @@ -297,15 +307,6 @@ def production_preprocess(self, task_instances): ms = MultiStream.from_iterables({constants.inference_stream: task_instances}) return list(self.inference_instance(ms)[constants.inference_stream]) - # def production_demos_pool(self): - # if self.use_demos: - # demos_pool = self.__class__._demos_pool_cache.get(str(self), None) - # if demos_pool is None: - # demos_pool = list(self.inference_demos()[self.demos_pool_name]) - # self.__class__._demos_pool_cache[str(self)] = demos_pool - # return demos_pool - # return [] - @property def has_custom_demos_pool(self): return self.demos_pool_size is not None and self.demos_pool_size > 0 @@ -398,9 +399,14 @@ def reset_pipeline(self): self.processing.steps.append(augmentor) # self.prepare_refiners() - # for backward compatibility move to after demos are taken from train - - if self.use_demos: + # for backward compatibility move refiners to after demos are taken from train + + # for backward compatibility, consume the demos instances even if not pushed into demos field of the ordinary instances, + # to use the very same ordinary instances as in back releases. + # one example of consume but not used, and indeed skips over a problematic (json-wise) input: + # prepare/cards/rag/end_to_end/clapnq.py + # if self.use_demos: + if self.has_custom_demos_pool: if self.sampler is None: if self.card.sampler is None: raise ValueError( @@ -414,53 +420,79 @@ def reset_pipeline(self): # filtering is done to not set a demo instance with same input_fields as the processed ordinary instance, # we prepare a demo_pool of at least num_demos +1 instances demos_pool_size = self.demos_pool_size - if demos_pool_size < self.max_demos_size + 1: - demos_pool_size = self.max_demos_size + 1 - - if isinstance(self.num_demos, int): - self.processing.steps.append( - CreateDemosPoolAndSpreadDemos( - from_stream=self.demos_taken_from, - demos_pool_size=demos_pool_size, - remove_targets_from_source_split=self.demos_removed_from_data, - sample=ConstantSizeSample( - to_field=self.demos_field, - sampler=self.sampler, - sample_size=self.num_demos, - ), + if self.use_demos: + if demos_pool_size < self.max_demos_size + 1: + demos_pool_size = self.max_demos_size + 1 + + if self.use_demos: + if isinstance(self.num_demos, int): + sample = ConstantSizeSample( + to_field=self.demos_field, + sampler=self.sampler, + sample_size=self.num_demos, ) - ) + self.verbalization.steps.append( + Set( + fields={ + "recipe_metadata/num_demos": self.num_demos, + "recipe_metadata/demos_pool_size": demos_pool_size, + } + ) + ) + + elif isinstance(self.num_demos, list): + sample = RandomSizeSample( + to_field=self.demos_field, + sampler=self.sampler, + sample_sizes=self.num_demos, + ) + + self.verbalization.steps.append( + GetLength(field="demos", to_field="recipe_metadata/num_demos") + ) + self.verbalization.steps.append( + Set(fields={"recipe_metadata/demos_pool_size": demos_pool_size}) + ) + + else: + raise ValueError("num_demos must be int or List[int]") + + else: + # consume demos pool, but does not want to push anything to demos field of the ordinary distances + sample = None self.verbalization.steps.append( Set( fields={ - "recipe_metadata/num_demos": self.num_demos, + "recipe_metadata/num_demos": 0, "recipe_metadata/demos_pool_size": demos_pool_size, } ) ) - elif isinstance(self.num_demos, list): - self.processing.steps.append( - CreateDemosPoolAndSpreadDemos( - from_stream=self.demos_taken_from, - demos_pool_size=demos_pool_size, - remove_targets_from_source_split=self.demos_removed_from_data, - sample=RandomSizeSample( - to_field=self.demos_field, - sampler=self.sampler, - sample_sizes=self.num_demos, - ), - ) - ) - self.verbalization.steps.append( - GetLength(field="demos", to_field="recipe_metadata/num_demos") + # here we have the sample suiting self.num_demos + self.processing.steps.append( + CreateDemosPoolAndSpreadDemos( + from_stream=self.demos_taken_from, + demos_pool_size=demos_pool_size, + remove_targets_from_source_split=self.demos_removed_from_data, + sample=sample, ) - self.verbalization.steps.append( - Set(fields={"recipe_metadata/demos_pool_size": demos_pool_size}) + ) + + else: + self.verbalization.steps.append( + Set( + fields={ + "recipe_metadata/num_demos": 0, + "recipe_metadata/demos_pool_size": 0, + } ) - else: - raise ValueError("num_demos must be int or List[int]") + ) + + # for backward compatibility, move the refiners here, after demos are taken from train + self.prepare_refiners() + if self.has_custom_demos_pool and self.use_demos: if isinstance(self.template, list): self.verbalization.steps.append( ApplyRandomTemplate( @@ -473,16 +505,8 @@ def reset_pipeline(self): template=self.template, demos_field=self.demos_field ) ) - else: - self.verbalization.steps.append( - Set( - fields={ - "recipe_metadata/num_demos": 0, - "recipe_metadata/demos_pool_size": 0, - } - ) - ) + else: if isinstance(self.template, list): self.verbalization.steps.append( ApplyRandomTemplate(templates=self.template) @@ -492,9 +516,6 @@ def reset_pipeline(self): ApplySingleTemplate(template=self.template) ) - # for backward compatibility, move the refiners here, after demos are taken from train - self.prepare_refiners() - self.verbalization.steps.append(self.system_prompt) self.verbalization.steps.append(self.format)