added new metrics for regression tasks (#364)
* added new metrics for regression tasks

* refactored regression metrics to new dir

---------

Co-authored-by: Ido Amos [email protected] <[email protected]>
3 people authored Sep 10, 2024
1 parent a108f32 commit ea54b8c
Showing 5 changed files with 188 additions and 17 deletions.
2 changes: 1 addition & 1 deletion fuse/eval/examples/examples_stats.py
@@ -17,7 +17,7 @@
"""

-from fuse.eval.metrics.stat.metrics_stat_common import MetricPearsonCorrelation
+from fuse.eval.metrics.regression.metrics import MetricPearsonCorrelation
import numpy as np
import pandas as pd
from collections import OrderedDict
56 changes: 55 additions & 1 deletion fuse/eval/metrics/libs/stat.py
@@ -1,6 +1,6 @@
import numpy as np
from typing import Sequence, Union
-from scipy.stats import pearsonr
+from scipy.stats import pearsonr, spearmanr


class Stat:
@@ -55,3 +55,57 @@ def pearson_correlation(
results["statistic"] = statistic
results["p_value"] = p_value
return results

@staticmethod
def spearman_correlation(
pred: Union[np.ndarray, Sequence],
target: Union[np.ndarray, Sequence],
mask: Union[np.ndarray, Sequence, None] = None,
) -> dict:
"""
Spearman correlation coefficient measuring the monotonic relationship between two datasets/vectors.
:param pred: prediction values
:param target: target values
        :param mask: optional boolean mask. If provided, the metric is applied only to the masked samples
        """
        if len(pred) == 0:
            return dict(statistic=float("nan"), p_value=float("nan"))

if isinstance(pred, Sequence):
if np.isscalar(pred[0]):
pred = np.array(pred)
else:
pred = np.concatenate(pred)
if isinstance(target, Sequence):
if np.isscalar(target[0]):
target = np.array(target)
else:
target = np.concatenate(target)
if isinstance(mask, Sequence):
if np.isscalar(mask[0]):
mask = np.array(mask).astype("bool")
else:
mask = np.concatenate(mask).astype("bool")
if mask is not None:
pred = pred[mask]
target = target[mask]

pred = pred.squeeze()
target = target.squeeze()
if len(pred.shape) > 1 or len(target.shape) > 1:
raise ValueError(
f"expected 1D vectors. got pred shape: {pred.shape}, target shape: {target.shape}"
)

assert len(pred) == len(
target
), f"Spearman corr expected to get pred and target with same length but got pred={len(pred)} - target={len(target)}"

statistic, p_value = spearmanr(
pred, target, nan_policy="propagate"
) # nans will result in nan outputs

results = {}
results["statistic"] = statistic
results["p_value"] = p_value
return results
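Not part of the diff, but as a quick illustration the new static method can also be called directly on plain arrays; a minimal sketch with made-up numbers and an optional boolean mask:

import numpy as np

from fuse.eval.metrics.libs.stat import Stat

pred = np.array([0.2, 0.5, 0.1, 0.9, 0.7])
target = np.array([1.0, 2.2, 0.4, 3.5, 0.0])
mask = np.array([True, True, True, True, False])  # ignore the last sample

res = Stat.spearman_correlation(pred, target, mask=mask)
print(res["statistic"], res["p_value"])  # Spearman rank correlation over the 4 masked-in samples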
Empty file.
131 changes: 131 additions & 0 deletions fuse/eval/metrics/regression/metrics.py
@@ -0,0 +1,131 @@
from typing import List, Optional, Union
from fuse.eval.metrics.libs.stat import Stat
from fuse.eval.metrics.metrics_common import MetricDefault
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error


class MetricPearsonCorrelation(MetricDefault):
def __init__(
self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
) -> None:
super().__init__(
pred=pred,
target=target,
mask=mask,
metric_func=Stat.pearson_correlation,
**kwargs,
)


class MetricSpearmanCorrelation(MetricDefault):
def __init__(
self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
) -> None:
super().__init__(
pred=pred,
target=target,
mask=mask,
metric_func=Stat.spearman_correlation,
**kwargs,
)


class MetricMAE(MetricDefault):
def __init__(
self,
pred: str,
target: str,
**kwargs: dict,
) -> None:
"""
See MetricDefault for the missing params
:param pred: scalar predictions
:param target: ground truth scalar labels
:param threshold: threshold to apply to both pred and target
:param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy.
"""
super().__init__(
pred=pred,
target=target,
metric_func=self.mae,
**kwargs,
)

def mae(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
**kwargs: dict,
) -> float:
return mean_absolute_error(y_true=target, y_pred=pred)


class MetricMSE(MetricDefault):
def __init__(
self,
pred: str,
target: str,
**kwargs: dict,
) -> None:
"""
Our implementation of standard MSE, current version of scikit dones't support it as a metric.
See MetricDefault for the missing params
:param pred: scalar predictions
:param target: ground truth scalar labels
:param threshold: threshold to apply to both pred and target
:param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy.
"""
super().__init__(
pred=pred,
target=target,
metric_func=self.mse,
**kwargs,
)

def mse(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
**kwargs: dict,
) -> float:
return mean_squared_error(y_true=target, y_pred=pred)


class MetricRMSE(MetricDefault):
def __init__(
self,
pred: str,
target: str,
**kwargs: dict,
) -> None:
"""
See MetricDefault for the missing params
:param pred: scalar predictions
:param target: ground truth scalar labels
:param threshold: threshold to apply to both pred and target
:param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy.
"""
super().__init__(
pred=pred,
target=target,
            metric_func=self.rmse,
**kwargs,
)

    def rmse(
        self,
        pred: Union[List, np.ndarray],
        target: Union[List, np.ndarray],
        **kwargs: dict,
    ) -> float:
        pred = np.array(pred).flatten()
        target = np.array(target).flatten()

        assert len(pred) == len(
            target
        ), f"Expected pred and target to have the same length but found: {len(pred)} elements in pred and {len(target)} in target"

        squared_diff = (pred - target) ** 2
        return float(np.sqrt(squared_diff.mean()))
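A usage sketch for the new regression metrics defined above, run through the same evaluator pattern used elsewhere in fuse.eval; the EvaluatorDefault call and the "pred"/"target" column names are assumptions for illustration, not part of this commit:

from collections import OrderedDict

import pandas as pd

from fuse.eval.evaluator import EvaluatorDefault  # assumed entry point, as in the fuse.eval examples
from fuse.eval.metrics.regression.metrics import (
    MetricMAE,
    MetricRMSE,
    MetricSpearmanCorrelation,
)

# toy data with hypothetical column names
data = pd.DataFrame(
    {
        "id": [0, 1, 2, 3],
        "pred": [1.0, 2.5, 3.0, 4.5],
        "target": [1.2, 2.0, 3.3, 4.0],
    }
)

metrics = OrderedDict(
    [
        ("mae", MetricMAE(pred="pred", target="target")),
        ("rmse", MetricRMSE(pred="pred", target="target")),
        ("spearman", MetricSpearmanCorrelation(pred="pred", target="target")),
    ]
)

results = EvaluatorDefault().eval(ids=None, data=data, metrics=metrics)
print(results)  # scalar error values plus the Spearman statistic and p_value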
16 changes: 1 addition & 15 deletions fuse/eval/metrics/stat/metrics_stat_common.py
@@ -1,7 +1,6 @@
from typing import Any, Dict, Hashable, Optional, Sequence
from collections import Counter
-from fuse.eval.metrics.metrics_common import MetricDefault, MetricWithCollectorBase
-from fuse.eval.metrics.libs.stat import Stat
+from fuse.eval.metrics.metrics_common import MetricWithCollectorBase


class MetricUniqueValues(MetricWithCollectorBase):
@@ -20,16 +19,3 @@ def eval(
counter = Counter(values)

return list(counter.items())


class MetricPearsonCorrelation(MetricDefault):
def __init__(
self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
) -> None:
super().__init__(
pred=pred,
target=target,
mask=mask,
metric_func=Stat.pearson_correlation,
**kwargs
)
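Since this hunk removes MetricPearsonCorrelation from metrics_stat_common.py, any code still importing it from the old location needs a one-line migration. A sketch (the "pred"/"target" key names are hypothetical; only the import path changes, the constructor arguments stay the same):

# old import path (removed in this commit)
# from fuse.eval.metrics.stat.metrics_stat_common import MetricPearsonCorrelation

# new import path
from fuse.eval.metrics.regression.metrics import MetricPearsonCorrelation

metric = MetricPearsonCorrelation(pred="pred", target="target")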
