diff --git a/sksurv/svm/minlip.py b/sksurv/svm/minlip.py index 2722a98a..53e177d7 100644 --- a/sksurv/svm/minlip.py +++ b/sksurv/svm/minlip.py @@ -6,6 +6,7 @@ import warnings from ..base import SurvivalAnalysisMixin +from ..exceptions import NoComparablePairException from ..util import check_arrays_survival from ._minlip import create_difference_matrix @@ -135,6 +136,9 @@ def _get_kernel(self, X, Y=None): def _fit(self, x, event, time): D = create_difference_matrix(event.astype(numpy.uint8), time, kind=self.pairs) + if D.shape[0] == 0: + raise NoComparablePairException("Data has no comparable pairs, cannot fit model.") + K = self._get_kernel(x) if self.solver == "cvxpy": diff --git a/sksurv/svm/naive_survival_svm.py b/sksurv/svm/naive_survival_svm.py index 6439fad0..d6a573f7 100644 --- a/sksurv/svm/naive_survival_svm.py +++ b/sksurv/svm/naive_survival_svm.py @@ -18,6 +18,7 @@ from sklearn.utils import check_random_state from ..base import SurvivalAnalysisMixin +from ..exceptions import NoComparablePairException from ..util import check_arrays_survival @@ -154,6 +155,8 @@ def fit(self, X, y, sample_weight=None): random_state = check_random_state(self.random_state) x_pairs, y_pairs = self._get_survival_pairs(X, y, random_state) + if x_pairs.shape[0] == 0: + raise NoComparablePairException("Data has no comparable pairs, cannot fit model.") self.C = self.alpha return super().fit(x_pairs, y_pairs, sample_weight=sample_weight) diff --git a/sksurv/svm/survival_svm.py b/sksurv/svm/survival_svm.py index 083f4d4a..2b87ac94 100644 --- a/sksurv/svm/survival_svm.py +++ b/sksurv/svm/survival_svm.py @@ -11,22 +11,22 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. 
from abc import ABCMeta, abstractmethod +import warnings + +import numpy +import numexpr +from scipy.optimize import minimize from sklearn.base import BaseEstimator from sklearn.exceptions import ConvergenceWarning from sklearn.metrics.pairwise import pairwise_kernels from sklearn.utils import check_X_y, check_array, check_consistent_length, check_random_state from sklearn.utils.extmath import safe_sparse_dot, squared_norm -from scipy.optimize import minimize - -import numpy -import numexpr -import warnings - -from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors from ..base import SurvivalAnalysisMixin from ..bintrees import AVLTree, RBTree +from ..exceptions import NoComparablePairException from ..util import check_arrays_survival +from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors class Counter(object, metaclass=ABCMeta): @@ -263,6 +263,9 @@ def __init__(self, x, y, alpha, rank_ratio, timeit=False): self.data_x = x self.constraints = survival_constraints_simple(numpy.asarray(y, dtype=numpy.uint8)) + if self.constraints.shape[0] == 0: + raise NoComparablePairException("Data has no comparable pairs, cannot fit model.") + self.L = numpy.ones(self.constraints.shape[0]) @property @@ -303,6 +306,10 @@ def __init__(self, x, y, alpha, rank_ratio, timeit=False): self.data_y = numpy.asarray(y, dtype=numpy.uint8) self._constraints = lambda w: survival_constraints_with_support_vectors(self.data_y, w) + Aw = self._constraints(numpy.zeros(x.shape[1])) + if Aw.shape[0] == 0: + raise NoComparablePairException("Data has no comparable pairs, cannot fit model.") + @property def n_coefficients(self): return self.data_x.shape[1] @@ -372,8 +379,15 @@ def n_coefficients(self): def _init_coefficients(self): w = super()._init_coefficients() + n = w.shape[0] if self._fit_intercept: w[0] = self._counter.time.mean() + n -= 1 + + l_plus, _, l_minus, _ = self._counter.calculate(numpy.zeros(n)) + if numpy.all(l_plus 
== 0) and numpy.all(l_minus == 0): + raise NoComparablePairException("Data has no comparable pairs, cannot fit model.") + return w def _split_coefficents(self, w): @@ -498,8 +512,15 @@ def n_coefficients(self): def _init_coefficients(self): w = super()._init_coefficients() + n = w.shape[0] if self._fit_intercept: w[0] = self._counter.time.mean() + n -= 1 + + l_plus, _, l_minus, _ = self._counter.calculate(numpy.zeros(n)) + if numpy.all(l_plus == 0) and numpy.all(l_minus == 0): + raise NoComparablePairException("Data has no comparable pairs, cannot fit model.") + return w def _split_coefficents(self, w): diff --git a/tests/conftest.py b/tests/conftest.py index 7270fe5b..a21e35dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,6 +67,15 @@ def whas500_sparse_data(): return SparseDataSet(x_dense=x_dense, x_sparse=x_sparse, y=y) +@pytest.fixture +def whas500_uncomparable(make_whas500): + whas500 = make_whas500(to_numeric=True) + i = numpy.argmax(whas500.y["lenfol"]) + whas500.y["fstat"][:] = False + whas500.y["fstat"][i] = True + return whas500 + + @pytest.fixture def rossi(): """Load rossi.csv""" diff --git a/tests/test_minlip.py b/tests/test_minlip.py index dd80c0ec..31d11ab8 100644 --- a/tests/test_minlip.py +++ b/tests/test_minlip.py @@ -1,14 +1,16 @@ +from itertools import product + import numpy from numpy.testing import assert_array_almost_equal, assert_array_equal import pytest - from sklearn.exceptions import ConvergenceWarning from sklearn.preprocessing import scale -from sksurv.svm.minlip import MinlipSurvivalAnalysis, HingeLossSurvivalSVM from sksurv.datasets import load_gbsg2 +from sksurv.exceptions import NoComparablePairException from sksurv.column import encode_categorical from sksurv.svm._minlip import create_difference_matrix +from sksurv.svm.minlip import MinlipSurvivalAnalysis, HingeLossSurvivalSVM from sksurv.testing import assert_cindex_almost_equal from sksurv.util import Surv @@ -693,3 +695,13 @@ def test_max_iter(gbsg2): with 
pytest.warns(ConvergenceWarning, match=r"cvxopt solver did not converge: unknown \(duality gap = [.0-9]+\)"): m.fit(x, y) + + +@pytest.mark.parametrize(["model_cls", "solver", "pairs"], + list(product((MinlipSurvivalAnalysis, HingeLossSurvivalSVM), + ("cvxpy", "cvxopt", "osqp"), + ("all", "nearest", "next")))) +def test_fit_uncomparable(whas500_uncomparable, model_cls, solver, pairs): + ssvm = model_cls(solver=solver, pairs=pairs) + with pytest.raises(NoComparablePairException): + ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y) diff --git a/tests/test_survival_svm.py b/tests/test_survival_svm.py index 40fbccab..49f02625 100644 --- a/tests/test_survival_svm.py +++ b/tests/test_survival_svm.py @@ -13,6 +13,7 @@ from sksurv.bintrees import AVLTree, RBTree from sksurv.column import encode_categorical from sksurv.datasets import load_whas500, get_x_y +from sksurv.exceptions import NoComparablePairException from sksurv.io import loadarff from sksurv.kernels import ClinicalKernelTransform from sksurv.metrics import concordance_index_censored @@ -177,6 +178,13 @@ def test_ranking_with_fit_intercept(): match="fit_intercept=True is only meaningful if rank_ratio < 1.0"): ssvm.fit(x, y) + @staticmethod + @pytest.mark.parametrize("optimizer", ("simple", "avltree", "direct-count", "PRSVM", "rbtree")) + def test_fit_uncomparable(whas500_uncomparable, optimizer): + ssvm = FastSurvivalSVM(optimizer=optimizer) + with pytest.raises(NoComparablePairException): + ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y) + @staticmethod def test_survial_constraints_no_ties(): y = numpy.array([True, True, False, True, False, False, False, False]) @@ -629,6 +637,13 @@ def test_predict_precomputed_kernel_invalid_shape(make_whas500): r"Got \(100, 14\) for 500 indexed\."): ssvm.predict(x_new) + @staticmethod + @pytest.mark.parametrize("optimizer", ("avltree", "rbtree")) + def test_fit_uncomparable(whas500_uncomparable, optimizer): + ssvm = FastKernelSurvivalSVM(optimizer=optimizer) + 
with pytest.raises(NoComparablePairException): + ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y) + @pytest.fixture(params=[ SurvivalCounter, @@ -766,3 +781,9 @@ def test_fit_with_ties(whas500_with_ties): cindex = nrsvm.score(x, y) assert round(abs(cindex - 0.7760582309811175), 7) == 0 + + @staticmethod + def test_fit_uncomparable(whas500_uncomparable): + ssvm = NaiveSurvivalSVM(loss='squared_hinge', dual=False, tol=1e-8, max_iter=1000, random_state=0) + with pytest.raises(NoComparablePairException): + ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y)