Skip to content

Commit

Permalink
Throw exception when trying to fit model to data with uncomparable pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Jun 27, 2020
1 parent dafe20a commit ffe79c9
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 9 deletions.
4 changes: 4 additions & 0 deletions sksurv/svm/minlip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings

from ..base import SurvivalAnalysisMixin
from ..exceptions import NoComparablePairException
from ..util import check_arrays_survival
from ._minlip import create_difference_matrix

Expand Down Expand Up @@ -135,6 +136,9 @@ def _get_kernel(self, X, Y=None):

def _fit(self, x, event, time):
D = create_difference_matrix(event.astype(numpy.uint8), time, kind=self.pairs)
if D.shape[0] == 0:
raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

K = self._get_kernel(x)

if self.solver == "cvxpy":
Expand Down
3 changes: 3 additions & 0 deletions sksurv/svm/naive_survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sklearn.utils import check_random_state

from ..base import SurvivalAnalysisMixin
from ..exceptions import NoComparablePairException
from ..util import check_arrays_survival


Expand Down Expand Up @@ -154,6 +155,8 @@ def fit(self, X, y, sample_weight=None):
random_state = check_random_state(self.random_state)

x_pairs, y_pairs = self._get_survival_pairs(X, y, random_state)
if x_pairs.shape[0] == 0:
raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

self.C = self.alpha
return super().fit(x_pairs, y_pairs, sample_weight=sample_weight)
Expand Down
35 changes: 28 additions & 7 deletions sksurv/svm/survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from abc import ABCMeta, abstractmethod
import warnings

import numpy
import numexpr
from scipy.optimize import minimize
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.utils import check_X_y, check_array, check_consistent_length, check_random_state
from sklearn.utils.extmath import safe_sparse_dot, squared_norm

from scipy.optimize import minimize

import numpy
import numexpr
import warnings

from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors
from ..base import SurvivalAnalysisMixin
from ..bintrees import AVLTree, RBTree
from ..exceptions import NoComparablePairException
from ..util import check_arrays_survival
from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors


class Counter(object, metaclass=ABCMeta):
Expand Down Expand Up @@ -263,6 +263,9 @@ def __init__(self, x, y, alpha, rank_ratio, timeit=False):
self.data_x = x
self.constraints = survival_constraints_simple(numpy.asarray(y, dtype=numpy.uint8))

if self.constraints.shape[0] == 0:
raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

self.L = numpy.ones(self.constraints.shape[0])

@property
Expand Down Expand Up @@ -303,6 +306,10 @@ def __init__(self, x, y, alpha, rank_ratio, timeit=False):
self.data_y = numpy.asarray(y, dtype=numpy.uint8)
self._constraints = lambda w: survival_constraints_with_support_vectors(self.data_y, w)

Aw = self._constraints(numpy.zeros(x.shape[1]))
if Aw.shape[0] == 0:
raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

@property
def n_coefficients(self):
return self.data_x.shape[1]
Expand Down Expand Up @@ -372,8 +379,15 @@ def n_coefficients(self):

def _init_coefficients(self):
w = super()._init_coefficients()
n = w.shape[0]
if self._fit_intercept:
w[0] = self._counter.time.mean()
n -= 1

l_plus, _, l_minus, _ = self._counter.calculate(numpy.zeros(n))
if numpy.all(l_plus == 0) and numpy.all(l_minus == 0):
raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

return w

def _split_coefficents(self, w):
Expand Down Expand Up @@ -498,8 +512,15 @@ def n_coefficients(self):

def _init_coefficients(self):
w = super()._init_coefficients()
n = w.shape[0]
if self._fit_intercept:
w[0] = self._counter.time.mean()
n -= 1

l_plus, _, l_minus, _ = self._counter.calculate(numpy.zeros(n))
if numpy.all(l_plus == 0) and numpy.all(l_minus == 0):
raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

return w

def _split_coefficents(self, w):
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ def whas500_sparse_data():
return SparseDataSet(x_dense=x_dense, x_sparse=x_sparse, y=y)


@pytest.fixture
def whas500_uncomparable(make_whas500):
whas500 = make_whas500(to_numeric=True)
i = numpy.argmax(whas500.y["lenfol"])
whas500.y["fstat"][:] = False
whas500.y["fstat"][i] = True
return whas500


@pytest.fixture
def rossi():
"""Load rossi.csv"""
Expand Down
16 changes: 14 additions & 2 deletions tests/test_minlip.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from itertools import product

import numpy
from numpy.testing import assert_array_almost_equal, assert_array_equal
import pytest

from sklearn.exceptions import ConvergenceWarning
from sklearn.preprocessing import scale

from sksurv.svm.minlip import MinlipSurvivalAnalysis, HingeLossSurvivalSVM
from sksurv.datasets import load_gbsg2
from sksurv.exceptions import NoComparablePairException
from sksurv.column import encode_categorical
from sksurv.svm._minlip import create_difference_matrix
from sksurv.svm.minlip import MinlipSurvivalAnalysis, HingeLossSurvivalSVM
from sksurv.testing import assert_cindex_almost_equal
from sksurv.util import Surv

Expand Down Expand Up @@ -693,3 +695,13 @@ def test_max_iter(gbsg2):
with pytest.warns(ConvergenceWarning,
match=r"cvxopt solver did not converge: unknown \(duality gap = [.0-9]+\)"):
m.fit(x, y)


@pytest.mark.parametrize(["model_cls", "solver", "pairs"],
list(product((MinlipSurvivalAnalysis, HingeLossSurvivalSVM),
("cvxpy", "cvxopt", "osqp"),
("all", "nearest", "next"))))
def test_fit_uncomparable(whas500_uncomparable, model_cls, solver, pairs):
ssvm = model_cls(solver=solver, pairs=pairs)
with pytest.raises(NoComparablePairException):
ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y)
21 changes: 21 additions & 0 deletions tests/test_survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sksurv.bintrees import AVLTree, RBTree
from sksurv.column import encode_categorical
from sksurv.datasets import load_whas500, get_x_y
from sksurv.exceptions import NoComparablePairException
from sksurv.io import loadarff
from sksurv.kernels import ClinicalKernelTransform
from sksurv.metrics import concordance_index_censored
Expand Down Expand Up @@ -177,6 +178,13 @@ def test_ranking_with_fit_intercept():
match="fit_intercept=True is only meaningful if rank_ratio < 1.0"):
ssvm.fit(x, y)

@staticmethod
@pytest.mark.parametrize("optimizer", ("simple", "avltree", "direct-count", "PRSVM", "rbtree"))
def test_fit_uncomparable(whas500_uncomparable, optimizer):
ssvm = FastSurvivalSVM(optimizer=optimizer)
with pytest.raises(NoComparablePairException):
ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y)

@staticmethod
def test_survial_constraints_no_ties():
y = numpy.array([True, True, False, True, False, False, False, False])
Expand Down Expand Up @@ -629,6 +637,13 @@ def test_predict_precomputed_kernel_invalid_shape(make_whas500):
r"Got \(100, 14\) for 500 indexed\."):
ssvm.predict(x_new)

@staticmethod
@pytest.mark.parametrize("optimizer", ("avltree", "rbtree"))
def test_fit_uncomparable(whas500_uncomparable, optimizer):
ssvm = FastKernelSurvivalSVM(optimizer=optimizer)
with pytest.raises(NoComparablePairException):
ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y)


@pytest.fixture(params=[
SurvivalCounter,
Expand Down Expand Up @@ -766,3 +781,9 @@ def test_fit_with_ties(whas500_with_ties):

cindex = nrsvm.score(x, y)
assert round(abs(cindex - 0.7760582309811175), 7) == 0

@staticmethod
def test_fit_uncomparable(whas500_uncomparable):
ssvm = NaiveSurvivalSVM(loss='squared_hinge', dual=False, tol=1e-8, max_iter=1000, random_state=0)
with pytest.raises(NoComparablePairException):
ssvm.fit(whas500_uncomparable.x, whas500_uncomparable.y)

0 comments on commit ffe79c9

Please sign in to comment.