Skip to content

Commit

Permalink
Refactor F1 metrics to remove 'per_class' average and update related …
Browse files Browse the repository at this point in the history
…configurations for consistency

Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel committed Dec 29, 2024
1 parent 146cb07 commit 7cd647b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 25 deletions.
2 changes: 1 addition & 1 deletion prepare/metrics/f1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
metric = F1Fast(main_score="f1_macro", averages=["macro", "per_class"])
add_to_catalog(metric, "metrics.f1_macro", overwrite=True)

metric = F1Fast(main_score="f1_micro", averages=["micro", "per_class"])
metric = F1Fast(main_score="f1_micro", averages=["micro"])
add_to_catalog(metric, "metrics.f1_micro", overwrite=True)

metric = F1MacroMultiLabel()
Expand Down
3 changes: 1 addition & 2 deletions src/unitxt/catalog/metrics/f1_micro.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"__type__": "f1_fast",
"main_score": "f1_micro",
"averages": [
"micro",
"per_class"
"micro"
]
}
29 changes: 18 additions & 11 deletions src/unitxt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,13 @@ def __new__(


def is_original_key(key):
if key.endswith("_ci_low") or key.endswith("_ci_high") or key == "score":
if (
key.endswith("_ci_low")
or key.endswith("_ci_high")
or key == "score"
or key == "num_of_instances"
or key == "score_name"
):
return False
return True

Expand Down Expand Up @@ -496,18 +502,19 @@ def map_stream(

def process(self, stream: Stream, stream_name: Optional[str] = None):
instances_scores, global_scores = self.compute(stream, stream_name)
for instance, instance_scores in zip(stream, instances_scores):
for i, (instance, instance_scores) in enumerate(zip(stream, instances_scores)):
previous_score = instance.get("score", {"global": {}, "instance": {}})

for key in global_scores:
if is_original_key(key) and key in previous_score["global"]:
UnitxtWarning(
message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
f"which will yield, in this case, a score named: 'my_second_{key}')",
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
)
if i == 0:
for key in global_scores:
if is_original_key(key) and key in previous_score["global"]:
UnitxtWarning(
message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
f"which will yield, in this case, a score named: 'my_second_{key}')",
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
)

global_scores = {**previous_score["global"], **global_scores}
instance_scores = {**previous_score["instance"], **instance_scores}
Expand Down
36 changes: 27 additions & 9 deletions tests/library/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,11 +396,11 @@ def test_end_to_end(self):
"num_of_instances": 13,
},
"wnli": {
"f1_micro": 0.5,
"f1_macro": 0.357,
"f1_entailment": 0.0,
"f1_not entailment": 0.714,
"f1_micro_ci_low": 0.2,
"f1_micro_ci_high": 0.762,
"f1_macro_ci_low": 0.182,
"f1_macro_ci_high": 0.467,
"f1_entailment_ci_low": None,
"f1_entailment_ci_high": None,
"f1_not entailment_ci_low": 0.364,
Expand All @@ -410,14 +410,20 @@ def test_end_to_end(self):
"score_ci_high": 0.762,
"score_ci_low": 0.2,
"num_of_instances": 12,
"accuracy": 0.417,
"accuracy_ci_low": 0.167,
"accuracy_ci_high": 0.667,
"f1_micro": 0.5,
"f1_micro_ci_low": 0.2,
"f1_micro_ci_high": 0.762,
"groups": {
"template": {
"templates.classification.multi_class.relation.default": {
"f1_micro": 0.5,
"f1_macro": 0.357,
"f1_entailment": 0.0,
"f1_not entailment": 0.714,
"f1_micro_ci_low": 0.2,
"f1_micro_ci_high": 0.762,
"f1_macro_ci_low": 0.182,
"f1_macro_ci_high": 0.467,
"f1_entailment_ci_low": None,
"f1_entailment_ci_high": None,
"f1_not entailment_ci_low": 0.364,
Expand All @@ -427,16 +433,22 @@ def test_end_to_end(self):
"score_ci_high": 0.762,
"score_ci_low": 0.2,
"num_of_instances": 12,
"accuracy": 0.417,
"accuracy_ci_low": 0.167,
"accuracy_ci_high": 0.667,
"f1_micro": 0.5,
"f1_micro_ci_low": 0.2,
"f1_micro_ci_high": 0.762,
}
}
},
},
"rte": {
"f1_micro": 0.5,
"f1_macro": 0.333,
"f1_entailment": 0.0,
"f1_not entailment": 0.667,
"f1_micro_ci_low": 0.0,
"f1_micro_ci_high": 0.889,
"f1_macro_ci_low": 0.0,
"f1_macro_ci_high": 0.5,
"f1_entailment_ci_low": None,
"f1_entailment_ci_high": None,
"f1_not entailment_ci_low": 0.0,
Expand All @@ -446,6 +458,12 @@ def test_end_to_end(self):
"score_ci_high": 0.889,
"score_ci_low": 0.0,
"num_of_instances": 5,
"accuracy": 0.4,
"accuracy_ci_low": 0.0,
"accuracy_ci_high": 0.8,
"f1_micro": 0.5,
"f1_micro_ci_low": 0.0,
"f1_micro_ci_high": 0.889,
},
"score": 0.161,
"score_name": "subsets_mean",
Expand Down
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
"filename": "src/unitxt/metrics.py",
"hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
"is_verified": false,
"line_number": 2506,
"line_number": 2513,
"is_secret": false
}
],
Expand All @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2024-12-29T09:44:05Z"
"generated_at": "2024-12-29T10:35:32Z"
}

0 comments on commit 7cd647b

Please sign in to comment.