
Commit

Refactor metrics to use fast variants and update configurations for accuracy and F1 metrics

Signed-off-by: elronbandel <[email protected]>
elronbandel committed Dec 26, 2024
1 parent addf0ce commit 5f2efae
Showing 11 changed files with 204 additions and 58 deletions.
4 changes: 2 additions & 2 deletions prepare/metrics/accuracy.py
@@ -1,8 +1,8 @@
from unitxt import add_to_catalog
from unitxt.metrics import Accuracy, BinaryAccuracy, BinaryMaxAccuracy
from unitxt.metrics import AccuracyFast, BinaryAccuracy, BinaryMaxAccuracy
from unitxt.test_utils.metrics import test_metric

metric = Accuracy()
metric = AccuracyFast()

predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]
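For orientation, a minimal standalone sketch of the exact-match logic a fast accuracy variant is expected to apply to the example above (an illustration, not the unitxt implementation):

```python
# Hedged sketch: per-instance exact match against any reference, then a mean.
predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]

instance_scores = [
    float(str(pred) in [str(ref) for ref in refs])
    for pred, refs in zip(predictions, references)
]
global_accuracy = sum(instance_scores) / len(instance_scores)
print(instance_scores)   # [0.0, 0.0, 1.0]
print(global_accuracy)   # 0.333...
```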
7 changes: 3 additions & 4 deletions prepare/metrics/f1.py
@@ -2,20 +2,19 @@
from unitxt.metrics import (
BinaryMaxF1,
F1Binary,
F1Macro,
F1Fast,
F1MacroMultiLabel,
F1Micro,
F1MicroMultiLabel,
F1Strings,
F1Weighted,
PrecisionBinary,
RecallBinary,
)

metric = F1Macro()
metric = F1Fast(main_score="f1_macro", averages=["macro", "per_class"])
add_to_catalog(metric, "metrics.f1_macro", overwrite=True)

metric = F1Micro()
metric = F1Fast(main_score="f1_micro", averages=["micro", "per_class"])
add_to_catalog(metric, "metrics.f1_micro", overwrite=True)

metric = F1MacroMultiLabel()
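The metrics.py hunk further down suggests F1Fast delegates to scikit-learn's f1_score. Assuming that, the requested averages map onto f1_score roughly as in this sketch (labels and data are illustrative):

```python
from sklearn.metrics import f1_score

y_true = [0, 1, 2, 2]  # hypothetical label indices
y_pred = [0, 2, 2, 2]

f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)  # "macro"
f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)  # "micro"
per_class = f1_score(y_true, y_pred, average=None, zero_division=0)    # "per_class": one value per label
print(f1_macro, f1_micro, per_class)
```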
44 changes: 38 additions & 6 deletions prepare/metrics/meteor.py
@@ -1,8 +1,13 @@
from unitxt import add_to_catalog
from unitxt.metrics import HuggingfaceMetric, Meteor
from unitxt.metrics import HuggingfaceMetric, MeteorFast
from unitxt.test_utils.metrics import test_metric

metric = Meteor(n_resamples=3)
metric = MeteorFast(
__description__="""METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a machine translation evaluation metric, which is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision.
METEOR is based on a generalized concept of unigram matching between the machine-produced translation and human-produced reference translations. Unigrams can be matched based on their surface forms, stemmed forms, and meanings. Once all generalized unigram matches between the two strings have been found, METEOR computes a score for this matching using a combination of unigram-precision, unigram-recall, and a measure of fragmentation that is designed to directly capture how well-ordered the matched words in the machine translation are in relation to the reference.
"""
)

predictions = [
"It is a guide to action which ensures that the military always obeys the commands of the party",
@@ -30,11 +35,11 @@

global_target = {
"meteor": 0.58,
"meteor_ci_high": 0.59,
"meteor_ci_low": 0.58,
"meteor_ci_high": 0.67,
"meteor_ci_low": 0.48,
"score": 0.58,
"score_ci_high": 0.59,
"score_ci_low": 0.58,
"score_ci_high": 0.67,
"score_ci_low": 0.48,
"score_name": "meteor",
"num_of_instances": 4,
}
@@ -49,6 +54,32 @@
global_target=global_target,
)

metric_hf = MeteorFast(
n_resamples=3,
__description__="""Huggingface version with bad confidence interval calculation of METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a machine translation evaluation metric, which is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision.
METEOR is based on a generalized concept of unigram matching between the machine-produced translation and human-produced reference translations. Unigrams can be matched based on their surface forms, stemmed forms, and meanings. Once all generalized unigram matches between the two strings have been found, METEOR computes a score for this matching using a combination of unigram-precision, unigram-recall, and a measure of fragmentation that is designed to directly capture how well-ordered the matched words in the machine translation are in relation to the reference.
""",
)
global_target = {
"meteor": 0.58,
"meteor_ci_high": 0.59,
"meteor_ci_low": 0.58,
"num_of_instances": 4,
"score": 0.58,
"score_ci_high": 0.59,
"score_ci_low": 0.58,
"score_name": "meteor",
}

outputs = test_metric(
metric=metric_hf,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

# compare results with the HF version of meteor
metric2 = HuggingfaceMetric(
hf_metric_name="meteor", main_score="meteor", prediction_type=str
@@ -63,3 +94,4 @@
)

add_to_catalog(metric, "metrics.meteor", overwrite=True)
add_to_catalog(metric_hf, "metrics.meteor_hf", overwrite=True)
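The widened confidence-interval targets above reflect the default resampling budget versus the n_resamples=3 variant. A rough sketch of a percentile bootstrap over per-instance scores, to show why so few resamples give an unreliable interval (unitxt's exact resampling procedure may differ; the instance scores are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
instance_scores = np.array([0.69, 0.64, 0.50, 0.47])  # hypothetical per-instance METEOR values

def bootstrap_ci(scores, n_resamples, alpha=0.05):
    # Resample instances with replacement, recompute the mean each time,
    # and take the central (1 - alpha) percentile interval of those means.
    means = [rng.choice(scores, size=len(scores), replace=True).mean()
             for _ in range(n_resamples)]
    return np.percentile(means, [100 * alpha / 2, 100 * (1 - alpha / 2)])

print(bootstrap_ci(instance_scores, n_resamples=1000))  # stable interval
print(bootstrap_ci(instance_scores, n_resamples=3))     # unreliable with only 3 resamples
```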
2 changes: 1 addition & 1 deletion src/unitxt/catalog/metrics/accuracy.json
@@ -1,3 +1,3 @@
{
"__type__": "accuracy"
"__type__": "accuracy_fast"
}
7 changes: 6 additions & 1 deletion src/unitxt/catalog/metrics/f1_macro.json
@@ -1,3 +1,8 @@
{
"__type__": "f1_macro"
"__type__": "f1_fast",
"main_score": "f1_macro",
"averages": [
"macro",
"per_class"
]
}
7 changes: 6 additions & 1 deletion src/unitxt/catalog/metrics/f1_micro.json
@@ -1,3 +1,8 @@
{
"__type__": "f1_micro"
"__type__": "f1_fast",
"main_score": "f1_micro",
"averages": [
"micro",
"per_class"
]
}
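These catalog entries appear to be the serialized form of the objects registered in prepare/metrics/f1.py: "__type__" selects the registered class and the remaining keys become constructor arguments, e.g.:

```python
from unitxt.metrics import F1Fast

# What src/unitxt/catalog/metrics/f1_micro.json corresponds to:
metric = F1Fast(main_score="f1_micro", averages=["micro", "per_class"])
```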
4 changes: 2 additions & 2 deletions src/unitxt/catalog/metrics/meteor.json
@@ -1,4 +1,4 @@
{
"__type__": "meteor",
"n_resamples": 3
"__type__": "meteor_fast",
"__description__": "METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a machine translation evaluation metric, which is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision.\n\nMETEOR is based on a generalized concept of unigram matching between the machine-produced translation and human-produced reference translations. Unigrams can be matched based on their surface forms, stemmed forms, and meanings. Once all generalized unigram matches between the two strings have been found, METEOR computes a score for this matching using a combination of unigram-precision, unigram-recall, and a measure of fragmentation that is designed to directly capture how well-ordered the matched words in the machine translation are in relation to the reference.\n"
}
5 changes: 5 additions & 0 deletions src/unitxt/catalog/metrics/meteor_hf.json
@@ -0,0 +1,5 @@
{
"__type__": "meteor_fast",
"n_resamples": 3,
"__description__": "Huggingface version with bad confidence interval calculation of METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a machine translation evaluation metric, which is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision.\n\nMETEOR is based on a generalized concept of unigram matching between the machine-produced translation and human-produced reference translations. Unigrams can be matched based on their surface forms, stemmed forms, and meanings. Once all generalized unigram matches between the two strings have been found, METEOR computes a score for this matching using a combination of unigram-precision, unigram-recall, and a measure of fragmentation that is designed to directly capture how well-ordered the matched words in the machine translation are in relation to the reference.\n"
}
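Likewise, the two METEOR entries presumably deserialize to the same class and differ only in the resampling budget (and description); the tight 0.58–0.59 interval in the meteor.py test targets above comes from the 3-resample setting:

```python
from unitxt.metrics import MeteorFast

meteor = MeteorFast()                  # metrics.meteor: default bootstrap settings
meteor_hf = MeteorFast(n_resamples=3)  # metrics.meteor_hf: only 3 bootstrap resamples
```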
129 changes: 91 additions & 38 deletions src/unitxt/metrics.py
Expand Up @@ -10,7 +10,7 @@
from collections import Counter, defaultdict, namedtuple
from dataclasses import field
from functools import lru_cache
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union

import numpy
import numpy as np
@@ -396,6 +396,7 @@ class MapReduceMetric(
ConfidenceIntervalMixin,
Generic[PredictionType, IntermediateType],
):
score_prefix = ""
reference_field: str = NonPositionalField(default="references")
prediction_field: str = NonPositionalField(default="prediction")

@@ -417,9 +418,12 @@ def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
def disable_confidence_interval_calculation(self):
self.n_resamples = None

def annotate_main_score(self, scores):
scores["score_name"] = self.main_score
scores["score"] = scores[self.main_score]
def annotate_scores(self, scores):
scores = {
**{self.score_prefix + key: val for key, val in scores.items()},
"score_name": self.score_prefix + self.main_score,
"score": scores[self.main_score],
}
for level in ["high", "low"]:
if f"{self.main_score}_ci_{level}" in scores:
scores[f"score_ci_{level}"] = scores[f"{self.main_score}_ci_{level}"]
@@ -497,11 +501,13 @@ def compute(self, stream: Stream, stream_name: Optional[str] = None):
instances_scores = []
for intermediate in intermediates_list:
instance_score = self.reduce_one(intermediate)
instance_score = self.annotate_main_score(instance_score)
instance_score = self.annotate_scores(instance_score)
instances_scores.append(instance_score)

global_scores = self.reduce_and_bootstrap(intermediates_list)
global_scores = self.annotate_main_score(global_scores)
global_scores = self.annotate_scores(global_scores)

global_scores["num_of_instances"] = len(intermediates_list)

return instances_scores, global_scores

@@ -575,6 +581,12 @@ def map(

class F1Fast(MapReduceMetric[str, Tuple[int, int, List[str]]]):
main_score = "f1"
averages: List[Literal["f1", "macro", "micro", "per_class"]] = [
"f1",
"micro",
"macro",
"per_class",
]

def prepare(self):
super().prepare()
@@ -603,28 +615,47 @@ def reduce(self, intermediates: List[Tuple[int, int, List[str]]]) -> Dict[str, Any]:
y_pred.append(pred_idx)
y_true.append(ref_idx)

# Compute F1 scores
f1_macro = self._metric(
y_true, y_pred, average="macro", labels=range(num_classes), zero_division=0
)
f1_micro = self._metric(
y_true, y_pred, average="micro", labels=range(num_classes), zero_division=0
)
# For per-class F1, average=None returns an array of F1 for each label
f1_per_class = self._metric(
y_true, y_pred, average=None, labels=range(num_classes), zero_division=0
)
result = {}

# Create a flat dict of all metrics
result = {
"f1": float(f1_macro), # Use macro-F1 as the "main_score"
"f1_macro": float(f1_macro),
"f1_micro": float(f1_micro),
}
if "f1" in self.averages:
result["f1"] = float(
self._metric(
y_true,
y_pred,
average="macro",
labels=range(num_classes),
zero_division=0,
)
)

# Add class-wise F1 using "f1_class_<class_name>" keys
for class_name, score in zip(all_classes, f1_per_class):
result[f"f1_{class_name}"] = float(score)
if "micro" in self.averages:
result["f1_micro"] = float(
self._metric(
y_true,
y_pred,
average="micro",
labels=range(num_classes),
zero_division=0,
)
)

if "macro" in self.averages:
result["f1_macro"] = float(
self._metric(
y_true,
y_pred,
average="macro",
labels=range(num_classes),
zero_division=0,
)
)

if "per_class" in self.averages:
f1_per_class = self._metric(
y_true, y_pred, average=None, labels=range(num_classes), zero_division=0
)
for class_name, score in zip(all_classes, f1_per_class):
result[f"f1_{class_name}"] = float(score)

return result
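To make the per_class branch concrete: with average=None scikit-learn returns one F1 value per label, which reduce() flattens into "f1_<class_name>" keys (assuming self._metric is sklearn.metrics.f1_score; the data here is illustrative):

```python
from sklearn.metrics import f1_score

all_classes = ["cat", "dog", "bird"]
y_true = [0, 0, 1, 2]
y_pred = [0, 1, 1, 2]

per_class = f1_score(y_true, y_pred, average=None, labels=range(3), zero_division=0)
result = {f"f1_{name}": float(s) for name, s in zip(all_classes, per_class)}
print(result)  # ≈ {'f1_cat': 0.67, 'f1_dog': 0.67, 'f1_bird': 1.0}
```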

@@ -2149,17 +2180,49 @@ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
return score


class MeteorFast(ReductionInstanceMetric[str, Dict[str, float]]):
main_score = "meteor"
reduction = MeanReduction()
_requirements_list: List[str] = ["nltk>=3.6.6"]
alpha: float = 0.9
beta: int = 3
gamma: float = 0.5

def prepare(self):
super().prepare()
import nltk

nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)
from nltk import word_tokenize
from nltk.translate import meteor_score

self.word_tokenize = word_tokenize
self.meteor_score = meteor_score

def map(
self, prediction: str, references: List[str], task_data: Dict[str, Any]
) -> Dict[str, float]:
score = self.meteor_score.meteor_score(
[self.word_tokenize(ref) for ref in references],
self.word_tokenize(prediction),
alpha=self.alpha,
beta=self.beta,
gamma=self.gamma,
)
return {self.main_score: score}


class Meteor(InstanceMetric):
main_score = "meteor"
ci_scores = ["meteor"]
reduction_map = {"mean": ["meteor"]}
prediction_type = str

_requirements_list: List[str] = ["nltk"]
_requirements_list: List[str] = ["nltk>=3.6.6"]
alpha: float = 0.9
beta: int = 3
gamma: float = 0.5
# unitxt uses nltk version >= 3.8

def prepare(self):
super().prepare()
@@ -2173,16 +2236,6 @@ def prepare(self):
self.word_tokenize = word_tokenize
self.meteor_score = meteor_score

def verify(self):
import importlib.metadata as importlib_metadata

from datasets.config import version

nltk_version = version.parse(importlib_metadata.version("nltk"))
assert nltk_version >= version.Version(
"3.6.6"
), "nltk version must be at least 3.6.6"

def compute(self, references, prediction, task_data):
score = self.meteor_score.meteor_score(
[self.word_tokenize(ref) for ref in references],
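MeteorFast.map above boils down to nltk's sentence-level METEOR. A self-contained sketch with an illustrative prediction/reference pair (requires nltk >= 3.6.6; the punkt download is an assumption, since word_tokenize needs tokenizer data that the class itself does not fetch):

```python
import nltk
from nltk import word_tokenize
from nltk.translate import meteor_score

nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)
nltk.download("punkt", quiet=True)  # assumption: tokenizer data for word_tokenize

prediction = "It is a guide to action which ensures that the military always obeys the commands of the party"
references = ["It is a guide to action that ensures the military heeds the commands of the party"]  # illustrative reference

score = meteor_score.meteor_score(
    [word_tokenize(ref) for ref in references],  # references must be pre-tokenized
    word_tokenize(prediction),
    alpha=0.9, beta=3, gamma=0.5,  # the defaults declared on MeteorFast
)
print(round(score, 2))
```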