From ab599cc314b15afa287e270bdd0831ec50ccf8ec Mon Sep 17 00:00:00 2001 From: Maxwell-Jia Date: Sun, 22 Dec 2024 19:03:56 +0800 Subject: [PATCH 1/2] Fix size attribute error for scalar outputs in precision/recall/f1 metrics --- metrics/f1/f1.py | 2 +- metrics/precision/precision.py | 2 +- metrics/recall/recall.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index fe7683489..644831b3f 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -127,4 +127,4 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b score = f1_score( references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight ) - return {"f1": float(score) if score.size == 1 else score} + return {"f1": score if getattr(score, 'size', 1) > 1 else float(score)} diff --git a/metrics/precision/precision.py b/metrics/precision/precision.py index 4b35aa7e4..1d077d27c 100644 --- a/metrics/precision/precision.py +++ b/metrics/precision/precision.py @@ -142,4 +142,4 @@ def _compute( sample_weight=sample_weight, zero_division=zero_division, ) - return {"precision": float(score) if score.size == 1 else score} + return {"precision": score if getattr(score, 'size', 1) > 1 else float(score)} diff --git a/metrics/recall/recall.py b/metrics/recall/recall.py index 8522cfcf6..afbcd66cf 100644 --- a/metrics/recall/recall.py +++ b/metrics/recall/recall.py @@ -132,4 +132,4 @@ def _compute( sample_weight=sample_weight, zero_division=zero_division, ) - return {"recall": float(score) if score.size == 1 else score} + return {"recall": score if getattr(score, 'size', 1) > 1 else float(score)} From 278657ca0514955cc292450d406ec77b881759fe Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 10 Jan 2025 15:35:02 +0100 Subject: [PATCH 2/2] Run formatting --- metrics/f1/f1.py | 2 +- metrics/precision/precision.py | 2 +- metrics/recall/recall.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index 644831b3f..05b0baadc 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -127,4 +127,4 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b score = f1_score( references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight ) - return {"f1": score if getattr(score, 'size', 1) > 1 else float(score)} + return {"f1": score if getattr(score, "size", 1) > 1 else float(score)} diff --git a/metrics/precision/precision.py b/metrics/precision/precision.py index 1d077d27c..170d0e5dd 100644 --- a/metrics/precision/precision.py +++ b/metrics/precision/precision.py @@ -142,4 +142,4 @@ def _compute( sample_weight=sample_weight, zero_division=zero_division, ) - return {"precision": score if getattr(score, 'size', 1) > 1 else float(score)} + return {"precision": score if getattr(score, "size", 1) > 1 else float(score)} diff --git a/metrics/recall/recall.py b/metrics/recall/recall.py index afbcd66cf..1c20afc46 100644 --- a/metrics/recall/recall.py +++ b/metrics/recall/recall.py @@ -132,4 +132,4 @@ def _compute( sample_weight=sample_weight, zero_division=zero_division, ) - return {"recall": score if getattr(score, 'size', 1) > 1 else float(score)} + return {"recall": score if getattr(score, "size", 1) > 1 else float(score)}