diff --git a/flair/training_utils.py b/flair/training_utils.py index ea5b576f0c..ce15bdb6e5 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -10,7 +10,6 @@ from numpy import ndarray from scipy.stats import pearsonr, spearmanr -from scipy.stats._stats_py import PearsonRResult, SignificanceResult from sklearn.metrics import mean_absolute_error, mean_squared_error from torch.optim import Optimizer from torch.utils.data import Dataset @@ -61,10 +60,10 @@ def mean_squared_error(self) -> Union[float, ndarray]: def mean_absolute_error(self): return mean_absolute_error(self.true, self.pred) - def pearsonr(self) -> PearsonRResult: + def pearsonr(self): return pearsonr(self.true, self.pred)[0] - def spearmanr(self) -> SignificanceResult: + def spearmanr(self): return spearmanr(self.true, self.pred)[0] def to_tsv(self) -> str: