From 36b46e38face0cb0a0152e1774522a65b0782c10 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Wed, 22 Jan 2025 12:33:49 +0200 Subject: [PATCH] set trust remote Signed-off-by: dafnapension --- performance/bluebench_profiler.py | 26 +++++++++++++------------- src/unitxt/test_utils/card.py | 31 +++++++++++++++++-------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/performance/bluebench_profiler.py b/performance/bluebench_profiler.py index 0e4634c6e..55adee8fe 100644 --- a/performance/bluebench_profiler.py +++ b/performance/bluebench_profiler.py @@ -59,25 +59,25 @@ def profiler_instantiate_benchmark_recipe( def profiler_generate_benchmark_dataset( self, benchmark_recipe: Benchmark, split: str, **kwargs ) -> List[Dict[str, Any]]: + stream = benchmark_recipe()[split] + + # to charge here for the time of generating all instances of the split + return list(stream) + + def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs): with settings.context( disable_hf_datasets_cache=False, allow_unverified_code=True, ): - stream = benchmark_recipe()[split] - - # to charge here for the time of generating all instances - return list(stream) + benchmark_recipe = self.profiler_instantiate_benchmark_recipe( + dataset_query=dataset_query, **kwargs + ) - def profiler_do_the_profiling(self, dataset_query: str, split: str, **kwargs): - benchmark_recipe = self.profiler_instantiate_benchmark_recipe( - dataset_query=dataset_query, **kwargs - ) - - dataset = self.profiler_generate_benchmark_dataset( - benchmark_recipe=benchmark_recipe, split=split, **kwargs - ) + dataset = self.profiler_generate_benchmark_dataset( + benchmark_recipe=benchmark_recipe, split=split, **kwargs + ) - logger.critical(f"length of evaluation_result: {len(dataset)}") + logger.critical(f"length of bluebench generated dataset: {len(dataset)}") dataset_query = "benchmarks.bluebench[loader_limit=30,max_samples_per_subset=30]" diff --git a/src/unitxt/test_utils/card.py 
b/src/unitxt/test_utils/card.py index 1d4edd56d..fb60bb19c 100644 --- a/src/unitxt/test_utils/card.py +++ b/src/unitxt/test_utils/card.py @@ -291,18 +291,21 @@ def test_card( else: template_card_indices = range(len(card.templates)) - for template_card_index in template_card_indices: - examples = load_examples_from_dataset_recipe( - card, template_card_index=template_card_index, debug=debug, **kwargs - ) - if test_exact_match_score_when_predictions_equal_references: - test_correct_predictions( - examples=examples, strict=strict, exact_match_score=exact_match_score - ) - if test_full_mismatch_score_with_full_mismatch_prediction_values: - test_wrong_predictions( - examples=examples, - strict=strict, - maximum_full_mismatch_score=maximum_full_mismatch_score, - full_mismatch_prediction_values=full_mismatch_prediction_values, + with settings.context(allow_unverified_code=True): + for template_card_index in template_card_indices: + examples = load_examples_from_dataset_recipe( + card, template_card_index=template_card_index, debug=debug, **kwargs ) + if test_exact_match_score_when_predictions_equal_references: + test_correct_predictions( + examples=examples, + strict=strict, + exact_match_score=exact_match_score, + ) + if test_full_mismatch_score_with_full_mismatch_prediction_values: + test_wrong_predictions( + examples=examples, + strict=strict, + maximum_full_mismatch_score=maximum_full_mismatch_score, + full_mismatch_prediction_values=full_mismatch_prediction_values, + )