From 7cd647bd6949f1d8b8dbb6c6fd2865572cfac15b Mon Sep 17 00:00:00 2001 From: elronbandel Date: Sun, 29 Dec 2024 12:36:57 +0200 Subject: [PATCH] Refactor F1 metrics to remove 'per_class' average and update related configurations for consistency Signed-off-by: elronbandel --- prepare/metrics/f1.py | 2 +- src/unitxt/catalog/metrics/f1_micro.json | 3 +- src/unitxt/metrics.py | 29 +++++++++++-------- tests/library/test_fusion.py | 36 ++++++++++++++++++------ utils/.secrets.baseline | 4 +-- 5 files changed, 49 insertions(+), 25 deletions(-) diff --git a/prepare/metrics/f1.py b/prepare/metrics/f1.py index 549427fa8..4df38d2cd 100644 --- a/prepare/metrics/f1.py +++ b/prepare/metrics/f1.py @@ -14,7 +14,7 @@ metric = F1Fast(main_score="f1_macro", averages=["macro", "per_class"]) add_to_catalog(metric, "metrics.f1_macro", overwrite=True) -metric = F1Fast(main_score="f1_micro", averages=["micro", "per_class"]) +metric = F1Fast(main_score="f1_micro", averages=["micro"]) add_to_catalog(metric, "metrics.f1_micro", overwrite=True) metric = F1MacroMultiLabel() diff --git a/src/unitxt/catalog/metrics/f1_micro.json b/src/unitxt/catalog/metrics/f1_micro.json index 1e1852f77..74125da8a 100644 --- a/src/unitxt/catalog/metrics/f1_micro.json +++ b/src/unitxt/catalog/metrics/f1_micro.json @@ -2,7 +2,6 @@ "__type__": "f1_fast", "main_score": "f1_micro", "averages": [ - "micro", - "per_class" + "micro" ] } diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index cbb371d35..bd8805318 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -391,7 +391,13 @@ def __new__( def is_original_key(key): - if key.endswith("_ci_low") or key.endswith("_ci_high") or key == "score": + if ( + key.endswith("_ci_low") + or key.endswith("_ci_high") + or key == "score" + or key == "num_of_instances" + or key == "score_name" + ): return False return True @@ -496,18 +502,19 @@ def map_stream( def process(self, stream: Stream, stream_name: Optional[str] = None): instances_scores, global_scores = self.compute(stream, stream_name) - for instance, instance_scores in zip(stream, instances_scores): + for i, (instance, instance_scores) in enumerate(zip(stream, instances_scores)): previous_score = instance.get("score", {"global": {}, "instance": {}}) - for key in global_scores: - if is_original_key(key) and key in previous_score["global"]: - UnitxtWarning( - message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded " - f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. " - f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , " - f"which will yield, in this case, a score named: 'my_second_{key}')", - additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS, - ) + if i == 0: + for key in global_scores: + if is_original_key(key) and key in previous_score["global"]: + UnitxtWarning( + message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded " + f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. " + f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , " + f"which will yield, in this case, a score named: 'my_second_{key}')", + additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS, + ) global_scores = {**previous_score["global"], **global_scores} instance_scores = {**previous_score["instance"], **instance_scores} diff --git a/tests/library/test_fusion.py b/tests/library/test_fusion.py index 26a49d0fb..573917ea1 100644 --- a/tests/library/test_fusion.py +++ b/tests/library/test_fusion.py @@ -396,11 +396,11 @@ def test_end_to_end(self): "num_of_instances": 13, }, "wnli": { - "f1_micro": 0.5, + "f1_macro": 0.357, "f1_entailment": 0.0, "f1_not entailment": 0.714, - "f1_micro_ci_low": 0.2, - "f1_micro_ci_high": 0.762, + "f1_macro_ci_low": 0.182, + "f1_macro_ci_high": 0.467, "f1_entailment_ci_low": None, "f1_entailment_ci_high": None, "f1_not entailment_ci_low": 0.364, @@ -410,14 +410,20 @@ def test_end_to_end(self): "score_ci_high": 0.762, "score_ci_low": 0.2, "num_of_instances": 12, + "accuracy": 0.417, + "accuracy_ci_low": 0.167, + "accuracy_ci_high": 0.667, + "f1_micro": 0.5, + "f1_micro_ci_low": 0.2, + "f1_micro_ci_high": 0.762, "groups": { "template": { "templates.classification.multi_class.relation.default": { - "f1_micro": 0.5, + "f1_macro": 0.357, "f1_entailment": 0.0, "f1_not entailment": 0.714, - "f1_micro_ci_low": 0.2, - "f1_micro_ci_high": 0.762, + "f1_macro_ci_low": 0.182, + "f1_macro_ci_high": 0.467, "f1_entailment_ci_low": None, "f1_entailment_ci_high": None, "f1_not entailment_ci_low": 0.364, @@ -427,16 +433,22 @@ def test_end_to_end(self): "score_ci_high": 0.762, "score_ci_low": 0.2, "num_of_instances": 12, + "accuracy": 0.417, + "accuracy_ci_low": 0.167, + "accuracy_ci_high": 0.667, + "f1_micro": 0.5, + "f1_micro_ci_low": 0.2, + "f1_micro_ci_high": 0.762, } } }, }, "rte": { - "f1_micro": 0.5, + "f1_macro": 0.333, "f1_entailment": 0.0, "f1_not entailment": 0.667, - "f1_micro_ci_low": 0.0, - "f1_micro_ci_high": 0.889, + "f1_macro_ci_low": 0.0, + "f1_macro_ci_high": 0.5, "f1_entailment_ci_low": None, "f1_entailment_ci_high": None, "f1_not entailment_ci_low": 0.0, @@ -446,6 +458,12 @@ def test_end_to_end(self): "score_ci_high": 0.889, "score_ci_low": 0.0, "num_of_instances": 5, + "accuracy": 0.4, + "accuracy_ci_low": 0.0, + "accuracy_ci_high": 0.8, + "f1_micro": 0.5, + "f1_micro_ci_low": 0.0, + "f1_micro_ci_high": 0.889, }, "score": 0.161, "score_name": "subsets_mean", diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index a5b640d5f..cf583a57e 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -161,7 +161,7 @@ "filename": "src/unitxt/metrics.py", "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889", "is_verified": false, - "line_number": 2506, + "line_number": 2513, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2024-12-29T09:44:05Z" + "generated_at": "2024-12-29T10:35:32Z" }