Add AccuracyFast metric and associated reduction classes for improved accuracy calculations

Signed-off-by: elronbandel <[email protected]>
elronbandel committed Dec 26, 2024
1 parent eae7334 commit addf0ce
Showing 3 changed files with 104 additions and 2 deletions.
60 changes: 60 additions & 0 deletions src/unitxt/metrics.py
@@ -513,6 +513,66 @@ def get_index_or_default(lst, item, default=-1):
        return default


class AggregationReduction(Artifact, Generic[IntermediateType]):
    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
        pass


class DictReduction(AggregationReduction[Dict[str, float]]):
    def reduce_list(self, lst: List[float]):
        pass

    def reduce(self, intermediates: List[Dict[str, float]]):
        # Group the per-instance values by score name.
        lists = {}
        for intermediate in intermediates:
            for key, val in intermediate.items():
                if key not in lists:
                    lists[key] = []
                lists[key].append(val)

        # Reduce each list of values with the subclass's reduce_list.
        result = {}
        for key, val_list in lists.items():
            result[key] = self.reduce_list(val_list)
        return result


class MeanReduction(DictReduction):
    def reduce_list(self, lst: List[float]):
        return float(nan_mean(lst))


class MaxReduction(DictReduction):
    def reduce_list(self, lst: List[float]):
        return float(nan_max(lst))


class ReductionInstanceMetric(
    MapReduceMetric[PredictionType, IntermediateType],
    Generic[PredictionType, IntermediateType],
):
    reduction: AggregationReduction[IntermediateType]

    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
        return self.reduction.reduce(intermediates)

    def reduce_one(self, intermediate: IntermediateType):
        return recursive_copy(intermediate)


class AccuracyFast(ReductionInstanceMetric[str, Dict[str, float]]):
    main_score = "accuracy"
    reduction = MeanReduction()

    def map(
        self, prediction: str, references: List[str], task_data: Dict[str, Any]
    ) -> Dict[str, float]:
        return {
            self.main_score: float(
                str(prediction) in [str(reference) for reference in references]
            )
        }


class F1Fast(MapReduceMetric[str, Tuple[int, int, List[str]]]):
    main_score = "f1"

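As a quick illustration of how the new pieces compose, here is a minimal standalone sketch of the reduction path: per-instance score dicts, as AccuracyFast.map emits them, are grouped by key and each list is reduced to a mean. The nan_mean stand-in below is an assumption for self-containment; unitxt imports its own helper, and DictReduction runs the equivalent grouping logic internally.

import math
from typing import Dict, List

def nan_mean(lst: List[float]) -> float:
    # Stand-in for unitxt's helper: mean that ignores NaN entries.
    vals = [v for v in lst if not math.isnan(v)]
    return sum(vals) / len(vals) if vals else float("nan")

# Per-instance intermediates, as AccuracyFast.map would emit them:
intermediates: List[Dict[str, float]] = [
    {"accuracy": 0.0},
    {"accuracy": 0.0},
    {"accuracy": 1.0},
]

# Equivalent of DictReduction.reduce with a MeanReduction:
lists: Dict[str, List[float]] = {}
for intermediate in intermediates:
    for key, val in intermediate.items():
        lists.setdefault(key, []).append(val)

result = {key: float(nan_mean(vals)) for key, vals in lists.items()}
print(result)  # {'accuracy': 0.3333333333333333}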
42 changes: 42 additions & 0 deletions tests/library/test_metrics.py
@@ -7,6 +7,7 @@
from unitxt.metrics import (
    NER,
    Accuracy,
    AccuracyFast,
    BinaryAccuracy,
    BinaryMaxAccuracy,
    BinaryMaxF1,
@@ -215,6 +216,39 @@ def test_accuracy(self):
        for output, target in zip(outputs, instance_targets):
            self.assertDictEqual(output["score"]["instance"], target)

    def test_accuracy_fast(self):
        metric = AccuracyFast()

        predictions = ["A", "B", "C"]
        references = [["B", "C"], ["A"], ["B", "C"]]

        outputs = apply_metric(
            metric=metric, predictions=predictions, references=references
        )

        expected_global_result = {
            "accuracy": 1 / 3,
            "score": 1 / 3,
            "score_name": "accuracy",
        }

        global_result = outputs[0]["score"]["global"].copy()
        # Only check the keys that are expected, i.e. exist in expected_global_result
        global_result = {
            key: value
            for key, value in global_result.items()
            if key in expected_global_result
        }
        self.assertDictEqual(global_result, expected_global_result)

        instance_targets = [
            {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
            {"accuracy": 0.0, "score": 0.0, "score_name": "accuracy"},
            {"accuracy": 1.0, "score": 1.0, "score_name": "accuracy"},
        ]
        for output, target in zip(outputs, instance_targets):
            self.assertDictEqual(output["score"]["instance"], target)

    def test_accuracy_with_prefix(self):
        metric = Accuracy(score_prefix="my_")

@@ -1335,6 +1369,14 @@ def test_instance_metric_confidence_interval(self):
expected_ci_high=0.87,
)

    def test_map_reduce_metric_confidence_interval(self):
        """Test the calculation of confidence intervals for a map-reduce metric (AccuracyFast is used as an instance of a MapReduceMetric)."""
        self._test_confidence_interval(
            metric=AccuracyFast(),
            expected_ci_low=0.71,
            expected_ci_high=0.87,
        )

    def test_instance_metric_with_multiple_scores_confidence_interval(self):
        self._test_confidence_interval(
            metric=TokenOverlap(),
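The _test_confidence_interval helper exercised above relies on unitxt's internal resampling, and the expected 0.71 and 0.87 bounds come from that fixture. As a rough sketch of the general idea only, not the library's actual CI machinery, a percentile bootstrap over per-instance 0/1 accuracy scores looks like the following; the scores, resample count, and seed are invented for illustration.

import random
from typing import List, Tuple

def bootstrap_ci(
    scores: List[float], n_resamples: int = 1000, alpha: float = 0.05, seed: int = 0
) -> Tuple[float, float]:
    rng = random.Random(seed)
    means = []
    for _ in range(n_resamples):
        # Resample the per-instance scores with replacement and record the mean.
        resample = [rng.choice(scores) for _ in scores]
        means.append(sum(resample) / len(resample))
    means.sort()
    low = means[int((alpha / 2) * n_resamples)]
    high = means[int((1 - alpha / 2) * n_resamples) - 1]
    return low, high

scores = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0]  # hypothetical 0/1 accuracies
ci_low, ci_high = bootstrap_ci(scores)
print(f"mean={sum(scores) / len(scores):.2f}, ci=({ci_low:.2f}, {ci_high:.2f})")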
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -161,7 +161,7 @@
        "filename": "src/unitxt/metrics.py",
        "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
        "is_verified": false,
-       "line_number": 2370,
+       "line_number": 2430,
        "is_secret": false
      }
    ],
@@ -184,5 +184,5 @@
      }
    ]
  },
-  "generated_at": "2024-12-26T13:52:54Z"
+  "generated_at": "2024-12-26T14:29:31Z"
}
