Skip to content

Commit

Permalink
Format with black
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 committed Nov 27, 2024
1 parent dda9e1c commit c401745
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions hiclass/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,12 @@ def _recall_macro(y_true: np.ndarray, y_pred: np.ndarray):
return _compute_macro(y_true, y_pred, _recall_micro)


def f1(y_true: np.ndarray, y_pred: np.ndarray, average: str = "micro", zero_division: str = "warn"):
def f1(
y_true: np.ndarray,
y_pred: np.ndarray,
average: str = "micro",
zero_division: str = "warn",
):
r"""
Compute hierarchical f-score.
Expand Down Expand Up @@ -311,13 +316,19 @@ def _f_score_macro(y_true: np.ndarray, y_pred: np.ndarray, zero_division):
return _compute_macro(y_true, y_pred, _f_score_micro, zero_division)


def _compute_macro(y_true: np.ndarray, y_pred: np.ndarray, _micro_function, zero_division=None):
def _compute_macro(
y_true: np.ndarray, y_pred: np.ndarray, _micro_function, zero_division=None
):
overall_sum = 0
for ground_truth, prediction in zip(y_true, y_pred):
if zero_division:
sample_score = _micro_function(np.array([ground_truth]), np.array([prediction]), zero_division)
sample_score = _micro_function(
np.array([ground_truth]), np.array([prediction]), zero_division
)
else:
sample_score = _micro_function(np.array([ground_truth]), np.array([prediction]))
sample_score = _micro_function(
np.array([ground_truth]), np.array([prediction])
)
overall_sum = overall_sum + sample_score
return overall_sum / len(y_true)

Expand Down

0 comments on commit c401745

Please sign in to comment.