diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py
index 0aae6fef7..01269d457 100644
--- a/src/unitxt/metrics.py
+++ b/src/unitxt/metrics.py
@@ -513,6 +513,66 @@ def get_index_or_default(lst, item, default=-1):
     return default
 
 
+class AggregationReduction(Artifact, Generic[IntermediateType]):
+    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
+        pass
+
+
+class DictReduction(AggregationReduction[Dict[str, float]]):
+    def reduce_list(self, lst: List[float]):
+        pass
+
+    def reduce(self, intermediates: List[Dict[str, float]]):
+        lists = {}
+        for intermediate in intermediates:
+            for key, val in intermediate.items():
+                if key not in lists:
+                    lists[key] = []
+                lists[key].append(val)
+
+        result = {}
+        for key, val_list in lists.items():
+            result[key] = self.reduce_list(val_list)
+        return result
+
+
+class MeanReduction(DictReduction):
+    def reduce_list(self, lst: List[float]):
+        return float(nan_mean(lst))
+
+
+class MaxReduction(DictReduction):
+    def reduce_list(self, lst: List[float]):
+        return float(nan_max(lst))
+
+
+class ReductionInstanceMetric(
+    MapReduceMetric[PredictionType, IntermediateType],
+    Generic[PredictionType, IntermediateType],
+):
+    reduction: AggregationReduction[IntermediateType]
+
+    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
+        return self.reduction.reduce(intermediates)
+
+    def reduce_one(self, intermediate: IntermediateType):
+        return recursive_copy(intermediate)
+
+
+class AccuracyFast(ReductionInstanceMetric[str, Dict[str, float]]):
+    main_score = "accuracy"
+    reduction = MeanReduction()
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Dict[str, float]:
+        return {
+            self.main_score: float(
+                str(prediction) in [str(reference) for reference in references]
+            )
+        }
+
+
 class F1Fast(MapReduceMetric[str, Tuple[int, int, List[str]]]):
     main_score = "f1"
diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py
index 244865bfc..275a56578 100644
--- a/tests/library/test_metrics.py
+++ b/tests/library/test_metrics.py
@@ -7,6 +7,7 @@
 from unitxt.metrics import (
     NER,
     Accuracy,
+    AccuracyFast,
     BinaryAccuracy,
     BinaryMaxAccuracy,
     BinaryMaxF1,
@@ -215,6 +216,39 @@ def test_accuracy(self):
         for output, target in zip(outputs, instance_targets):
             self.assertDictEqual(output["score"]["instance"], target)
 
+    def test_accuracy_fast(self):
+        metric = AccuracyFast()
+
+        predictions = ["A", "B", "C"]
+        references = [["B", "C"], ["A"], ["B", "C"]]
+
+        outputs = apply_metric(
+            metric=metric, predictions=predictions, references=references
+        )
+
+        expected_global_result = {
+            "accuracy": 1 / 3,
+            "score": 1 / 3,
+            "score_name": "accuracy",
+        }
+
+        global_result = outputs[0]["score"]["global"].copy()
+        # Only check the keys that are expected, i.e. those that exist in expected_global_result
+        global_result = {
+            key: value
+            for key, value in global_result.items()
+            if key in expected_global_result
+        }
+        self.assertDictEqual(global_result, expected_global_result)
+
+        instance_targets = [
+            {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
+            {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
+            {"accuracy": 1.0, "score": 1.0, "score_name": "accuracy"},
+        ]
+        for output, target in zip(outputs, instance_targets):
+            self.assertDictEqual(output["score"]["instance"], target)
+
     def test_accuracy_with_prefix(self):
         metric = Accuracy(score_prefix="my_")
 
@@ -1335,6 +1369,14 @@ def test_instance_metric_confidence_interval(self):
             expected_ci_high=0.87,
         )
 
+    def test_map_reduce_metric_confidence_interval(self):
+        """Test the calculation of confidence intervals for a MapReduceMetric (AccuracyFast is used as an instance of a MapReduceMetric)."""
+        self._test_confidence_interval(
+            metric=AccuracyFast(),
+            expected_ci_low=0.71,
+            expected_ci_high=0.87,
+        )
+
     def test_instance_metric_with_multiple_scores_confidence_interval(self):
         self._test_confidence_interval(
             metric=TokenOverlap(),
diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline
index 03496ede1..9d171af9f 100644
--- a/utils/.secrets.baseline
+++ b/utils/.secrets.baseline
@@ -161,7 +161,7 @@
         "filename": "src/unitxt/metrics.py",
         "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
         "is_verified": false,
-        "line_number": 2370,
+        "line_number": 2430,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2024-12-26T13:52:54Z"
+  "generated_at": "2024-12-26T14:29:31Z"
 }
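
Note (not part of the patch): a minimal usage sketch of the new classes, assuming `apply_metric` is importable from `unitxt.test_utils.metrics` as the existing tests do. The `MaxAccuracy` class below is hypothetical, shown only to illustrate how swapping the `reduction` field (here to the new `MaxReduction`) changes the global aggregation while the per-instance `map` stays identical to `AccuracyFast`.

    from typing import Any, Dict, List

    from unitxt.metrics import AccuracyFast, MaxReduction, ReductionInstanceMetric
    from unitxt.test_utils.metrics import apply_metric  # assumed import path, mirrors the tests

    # Score a batch with the new fast accuracy metric.
    outputs = apply_metric(
        metric=AccuracyFast(),
        predictions=["A", "B", "C"],
        references=[["B", "C"], ["A"], ["B", "C"]],
    )
    print(outputs[0]["score"]["global"]["accuracy"])  # 1/3: only "C" matches its references


    # Hypothetical subclass: same per-instance map as AccuracyFast, but the
    # global score is the max over instances instead of the mean.
    class MaxAccuracy(ReductionInstanceMetric[str, Dict[str, float]]):
        main_score = "accuracy"
        reduction = MaxReduction()

        def map(
            self, prediction: str, references: List[str], task_data: Dict[str, Any]
        ) -> Dict[str, float]:
            return {
                self.main_score: float(
                    str(prediction) in [str(reference) for reference in references]
                )
            }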