From ffe79c9044059fed487eb05d3e3c51d4058a3383 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= <sebp@k-d-w.org>
Date: Sat, 27 Jun 2020 17:44:20 +0200
Subject: [PATCH] Throw exception when trying to fit model to data with
 uncomparable pairs

---
 sksurv/svm/minlip.py             |  4 ++++
 sksurv/svm/naive_survival_svm.py |  3 +++
 sksurv/svm/survival_svm.py       | 35 +++++++++++++++++++++++++-------
 tests/conftest.py                |  9 ++++++++
 tests/test_minlip.py             | 16 +++++++++++++--
 tests/test_survival_svm.py       | 21 +++++++++++++++++++
 6 files changed, 79 insertions(+), 9 deletions(-)

diff --git a/sksurv/svm/minlip.py b/sksurv/svm/minlip.py
index 2722a98a..53e177d7 100644
--- a/sksurv/svm/minlip.py
+++ b/sksurv/svm/minlip.py
@@ -6,6 +6,7 @@
 import warnings
 
 from ..base import SurvivalAnalysisMixin
+from ..exceptions import NoComparablePairException
 from ..util import check_arrays_survival
 from ._minlip import create_difference_matrix
 
@@ -135,6 +136,9 @@ def _get_kernel(self, X, Y=None):
 
     def _fit(self, x, event, time):
         D = create_difference_matrix(event.astype(numpy.uint8), time, kind=self.pairs)
+        if D.shape[0] == 0:
+            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
+
         K = self._get_kernel(x)
 
         if self.solver == "cvxpy":
diff --git a/sksurv/svm/naive_survival_svm.py b/sksurv/svm/naive_survival_svm.py
index 6439fad0..d6a573f7 100644
--- a/sksurv/svm/naive_survival_svm.py
+++ b/sksurv/svm/naive_survival_svm.py
@@ -18,6 +18,7 @@
 from sklearn.utils import check_random_state
 
 from ..base import SurvivalAnalysisMixin
+from ..exceptions import NoComparablePairException
 from ..util import check_arrays_survival
 
 
@@ -154,6 +155,8 @@ def fit(self, X, y, sample_weight=None):
         random_state = check_random_state(self.random_state)
 
         x_pairs, y_pairs = self._get_survival_pairs(X, y, random_state)
+        if x_pairs.shape[0] == 0:
+            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
 
         self.C = self.alpha
         return super().fit(x_pairs, y_pairs, sample_weight=sample_weight)
diff --git a/sksurv/svm/survival_svm.py b/sksurv/svm/survival_svm.py
index 083f4d4a..2b87ac94 100644
--- a/sksurv/svm/survival_svm.py
+++ b/sksurv/svm/survival_svm.py
@@ -11,22 +11,22 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 from abc import ABCMeta, abstractmethod
+import warnings
+
+import numpy
+import numexpr
+from scipy.optimize import minimize
 from sklearn.base import BaseEstimator
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.metrics.pairwise import pairwise_kernels
 from sklearn.utils import check_X_y, check_array, check_consistent_length, check_random_state
 from sklearn.utils.extmath import safe_sparse_dot, squared_norm
 
-from scipy.optimize import minimize
-
-import numpy
-import numexpr
-import warnings
-
-from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors
 from ..base import SurvivalAnalysisMixin
 from ..bintrees import AVLTree, RBTree
+from ..exceptions import NoComparablePairException
 from ..util import check_arrays_survival
+from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors
 
 
 class Counter(object, metaclass=ABCMeta):
@@ -263,6 +263,9 @@ def __init__(self, x, y, alpha, rank_ratio, timeit=False):
         self.data_x = x
         self.constraints = survival_constraints_simple(numpy.asarray(y, dtype=numpy.uint8))
 
+        if self.constraints.shape[0] == 0:
+            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
+
         self.L = numpy.ones(self.constraints.shape[0])
 
     @property
@@ -303,6 +306,10 @@ def __init__(self, x, y, alpha, rank_ratio, timeit=False):
         self.data_y = numpy.asarray(y, dtype=numpy.uint8)
         self._constraints = lambda w: survival_constraints_with_support_vectors(self.data_y, w)
 
+        Aw = self._constraints(numpy.zeros(x.shape[1]))
+        if Aw.shape[0] == 0:
+            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
+
     @property
     def n_coefficients(self):
         return self.data_x.shape[1]
@@ -372,8 +379,15 @@ def n_coefficients(self):
 
     def _init_coefficients(self):
         w = super()._init_coefficients()
+        n = w.shape[0]
         if self._fit_intercept:
             w[0] = self._counter.time.mean()
+            n -= 1  # w[0] is set separately above, so the probe vector is one element shorter
+        # Probe the counter at zero coefficients; all-zero l_plus and l_minus mean no comparable pairs.
+        l_plus, _, l_minus, _ = self._counter.calculate(numpy.zeros(n))
+        if numpy.all(l_plus == 0) and numpy.all(l_minus == 0):
+            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
+
         return w
 
     def _split_coefficents(self, w):
@@ -498,8 +512,15 @@ def n_coefficients(self):
 
     def _init_coefficients(self):
         w = super()._init_coefficients()
+        n = w.shape[0]
         if self._fit_intercept:
             w[0] = self._counter.time.mean()
+            n -= 1  # w[0] is set separately above, so the probe vector is one element shorter
+        # Probe the counter at zero coefficients; all-zero l_plus and l_minus mean no comparable pairs.
+        l_plus, _, l_minus, _ = self._counter.calculate(numpy.zeros(n))
+        if numpy.all(l_plus == 0) and numpy.all(l_minus == 0):
+            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
+
         return w
 
     def _split_coefficents(self, w):
diff --git a/tests/conftest.py b/tests/conftest.py
index 7270fe5b..a21e35dd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -67,6 +67,15 @@ def whas500_sparse_data():
     return SparseDataSet(x_dense=x_dense, x_sparse=x_sparse, y=y)
 
 
+@pytest.fixture
+def whas500_uncomparable(make_whas500):
+    dataset = make_whas500(to_numeric=True)
+    status, latest = dataset.y["fstat"], numpy.argmax(dataset.y["lenfol"])
+    status[:] = False
+    status[latest] = True  # only the sample with the largest "lenfol" keeps fstat=True
+    return dataset
+
+
 @pytest.fixture
 def rossi():
     """Load rossi.csv"""
diff --git a/tests/test_minlip.py b/tests/test_minlip.py
index dd80c0ec..31d11ab8 100644
--- a/tests/test_minlip.py
+++ b/tests/test_minlip.py
@@ -1,14 +1,16 @@
+from itertools import product
+
 import numpy
 from numpy.testing import assert_array_almost_equal, assert_array_equal
 import pytest
-
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.preprocessing import scale
 
-from sksurv.svm.minlip import MinlipSurvivalAnalysis, HingeLossSurvivalSVM
 from sksurv.datasets import load_gbsg2
+from sksurv.exceptions import NoComparablePairException
 from sksurv.column import encode_categorical
 from sksurv.svm._minlip import create_difference_matrix
+from sksurv.svm.minlip import MinlipSurvivalAnalysis, HingeLossSurvivalSVM
 from sksurv.testing import assert_cindex_almost_equal
 from sksurv.util import Surv
 
@@ -693,3 +695,13 @@ def test_max_iter(gbsg2):
         with pytest.warns(ConvergenceWarning,
                           match=r"cvxopt solver did not converge: unknown \(duality gap = [.0-9]+\)"):
             m.fit(x, y)
+
+
+@pytest.mark.parametrize(
+    "model_cls,solver,pairs",
+    list(product((MinlipSurvivalAnalysis, HingeLossSurvivalSVM),
+                 ("cvxpy", "cvxopt", "osqp"), ("all", "nearest", "next"))))
+def test_fit_uncomparable(whas500_uncomparable, model_cls, solver, pairs):
+    model, data = model_cls(solver=solver, pairs=pairs), whas500_uncomparable
+    with pytest.raises(NoComparablePairException):  # fit() must reject data without comparable pairs
+        model.fit(data.x, data.y)
diff --git a/tests/test_survival_svm.py b/tests/test_survival_svm.py
index 40fbccab..49f02625 100644
--- a/tests/test_survival_svm.py
+++ b/tests/test_survival_svm.py
@@ -13,6 +13,7 @@
 from sksurv.bintrees import AVLTree, RBTree
 from sksurv.column import encode_categorical
 from sksurv.datasets import load_whas500, get_x_y
+from sksurv.exceptions import NoComparablePairException
 from sksurv.io import loadarff
 from sksurv.kernels import ClinicalKernelTransform
 from sksurv.metrics import concordance_index_censored
@@ -177,6 +178,13 @@ def test_ranking_with_fit_intercept():
                            match="fit_intercept=True is only meaningful if rank_ratio < 1.0"):
             ssvm.fit(x, y)
 
+    @staticmethod
+    @pytest.mark.parametrize("optimizer", ("simple", "avltree", "direct-count", "PRSVM", "rbtree"))
+    def test_fit_uncomparable(whas500_uncomparable, optimizer):
+        model, data = FastSurvivalSVM(optimizer=optimizer), whas500_uncomparable
+        with pytest.raises(NoComparablePairException):  # fit() must reject data without comparable pairs
+            model.fit(data.x, data.y)
+
     @staticmethod
     def test_survial_constraints_no_ties():
         y = numpy.array([True, True, False, True, False, False, False, False])
@@ -629,6 +637,13 @@ def test_predict_precomputed_kernel_invalid_shape(make_whas500):
                                  r"Got \(100, 14\) for 500 indexed\."):
             ssvm.predict(x_new)
 
+    @staticmethod
+    @pytest.mark.parametrize("optimizer", ("avltree", "rbtree"))
+    def test_fit_uncomparable(whas500_uncomparable, optimizer):
+        model, data = FastKernelSurvivalSVM(optimizer=optimizer), whas500_uncomparable
+        with pytest.raises(NoComparablePairException):  # fit() must reject data without comparable pairs
+            model.fit(data.x, data.y)
+
 
 @pytest.fixture(params=[
     SurvivalCounter,
@@ -766,3 +781,9 @@ def test_fit_with_ties(whas500_with_ties):
 
         cindex = nrsvm.score(x, y)
         assert round(abs(cindex - 0.7760582309811175), 7) == 0
+
+    @staticmethod
+    def test_fit_uncomparable(whas500_uncomparable):
+        model = NaiveSurvivalSVM(loss='squared_hinge', dual=False, tol=1e-8, max_iter=1000, random_state=0)
+        with pytest.raises(NoComparablePairException):  # fit() must reject data without comparable pairs
+            model.fit(X=whas500_uncomparable.x, y=whas500_uncomparable.y)