Skip to content

Commit

Permalink
Throw exception when trying to estimate c-index from uncomparable data
Browse files Browse the repository at this point in the history
Fixes #117
  • Loading branch information
sebp committed Jun 27, 2020
1 parent 07237ea commit dafe20a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
18 changes: 18 additions & 0 deletions sksurv/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


class NoComparablePairException(ValueError):
    """Raised when censored event-time data contains no comparable pair.

    Without at least one comparable pair, a pair-based statistic such as
    the concordance index cannot be estimated.

    The class derives from :class:`ValueError`, so callers that already
    catch ``ValueError`` also catch this exception.
    """
5 changes: 5 additions & 0 deletions sksurv/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scipy.integrate import trapz
from sklearn.utils import check_consistent_length, check_array

from .exceptions import NoComparablePairException
from .nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
from .util import check_y_survival

Expand Down Expand Up @@ -101,6 +102,10 @@ def _estimate_concordance_index(event_indicator, event_time, estimate, weights,

comparable, tied_time = _get_comparable(event_indicator, event_time, order)

if len(comparable) == 0:
raise NoComparablePairException(
"Data has no comparable pairs, cannot estimate concordance index.")

concordant = 0
discordant = 0
tied_risk = 0
Expand Down
41 changes: 35 additions & 6 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
cumulative_dynamic_auc,
integrated_brier_score,
)
from sksurv.exceptions import NoComparablePairException
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.preprocessing import OneHotEncoder
from sksurv.util import Surv
Expand All @@ -30,6 +31,20 @@ def whas500_pred():
return event, time, risk


@pytest.fixture
def no_comparable_pairs():
    """Return survival data expected to have no comparable pairs, plus scores.

    The structured array holds 30 records with fields ``event`` (bool) and
    ``time`` (float). Exactly one record is an event, and it carries the
    largest time (3563); every other record is censored.

    Returns
    -------
    y : numpy.ndarray
        Structured array with fields ``event`` and ``time``.
    scores : numpy.ndarray
        One normally distributed risk score per record in ``y``.
    """
    y = numpy.array(
        [
            (False, 849.), (False, 28.), (False, 55.), (False, 727.),
            (False, 505.), (False, 1558.), (False, 1292.), (False, 1737.),
            (False, 944.), (False, 750.), (False, 2513.), (False, 472.),
            (False, 2417.), (False, 538.), (False, 49.), (False, 723.),
            (True, 3563.), (False, 1090.), (False, 1167.), (False, 587.),
            (False, 1354.), (False, 910.), (False, 398.), (False, 854.),
            (False, 3534.), (False, 280.), (False, 183.), (False, 883.),
            (False, 32.), (False, 144.),
        ],
        dtype=[("event", bool), ("time", float)],
    )
    # Use a locally seeded generator so the fixture is reproducible and does
    # not advance numpy's global random state shared with other tests.
    scores = numpy.random.RandomState(0).randn(y.shape[0])
    return y, scores


def test_concordance_index_no_censoring_all_correct():
time = [1, 5, 6, 11, 34, 45, 46, 50]
event = numpy.repeat(True, len(time))
Expand Down Expand Up @@ -130,7 +145,7 @@ def test_concordance_index_with_tied_risk():
def test_concordance_index_with_almost_tied_risk():
event = [False, True, True, False, True, True, False, False]
time = [1, 5, 6, 11, 34, 45, 46, 50]
estimate = [5, 15, 11, 34, 12+4.5e-9, 3, 9, 12-4.5e-9]
estimate = [5, 15, 11, 34, 12 + 4.5e-9, 3, 9, 12 - 4.5e-9]

c, con, dis, tie_r, tie_t = concordance_index_censored(event, time, estimate)

Expand Down Expand Up @@ -239,6 +254,13 @@ def test_concordance_index_all_finite():
concordance_index_censored(event, time, estimate)


def test_concordance_index_no_comparable(no_comparable_pairs):
    """concordance_index_censored must raise NoComparablePairException
    for the data supplied by the ``no_comparable_pairs`` fixture.
    """
    surv, risk = no_comparable_pairs
    event_indicator = surv["event"]
    event_time = surv["time"]

    with pytest.raises(NoComparablePairException):
        concordance_index_censored(event_indicator, event_time, risk)


def assert_uno_c_almost_equal(y_train, y_test, estimate, expected, tau=None):
result = concordance_index_ipcw(y_train, y_test, estimate, tau=tau)
assert_array_equal(result[1:], expected[1:])
Expand Down Expand Up @@ -356,7 +378,7 @@ def uno_c_failure_data(request):
time=(1, 3, 5, 7, 12, 13, 20),
event=(True, False, False, True, True, False, True))
estimate = (5, 8, 13, 11, 9, 7, 4)
match = "time must be smaller than largest "\
match = "time must be smaller than largest " \
"observed time point:"
elif p == 'last_time_uncensored_2':
y_train = Surv.from_arrays(
Expand All @@ -366,7 +388,7 @@ def uno_c_failure_data(request):
time=(1, 23, 5, 27, 12),
event=(True, False, True, True, False))
estimate = (5, 13, 11, 9, 4)
match = "time must be smaller than largest "\
match = "time must be smaller than largest " \
"observed time point:"
elif p == 'zero_prob_1':
y_train = Surv.from_arrays(
Expand All @@ -376,7 +398,7 @@ def uno_c_failure_data(request):
time=(1, 3, 5, 7, 12, 13, 19),
event=(True, False, False, True, True, False, True))
estimate = (5, 8, 13, 11, 9, 7, 4)
match = "censoring survival function is zero "\
match = "censoring survival function is zero " \
"at one or more time points"
elif p == 'zero_prob_2':
y_train = Surv.from_arrays(
Expand All @@ -386,7 +408,7 @@ def uno_c_failure_data(request):
time=(1, 3, 5, 7, 12, 13, 19),
event=(True, False, False, True, True, False, True))
estimate = (5, 8, 13, 11, 9, 7, 4)
match = "censoring survival function is zero "\
match = "censoring survival function is zero " \
"at one or more time points"
elif p == 'zero_prob_3':
y_train = Surv.from_arrays(
Expand All @@ -396,7 +418,7 @@ def uno_c_failure_data(request):
time=(1, 3, 5, 19, 12, 13, 7),
event=(True, False, False, True, True, False, True))
estimate = (5, 8, 13, 11, 9, 7, 4)
match = "censoring survival function is zero "\
match = "censoring survival function is zero " \
"at one or more time points"
else:
assert False
Expand Down Expand Up @@ -425,6 +447,13 @@ def test_uno_c_all_censored():
assert ret_uno == ret_harrell


def test_uno_c_no_comparable(no_comparable_pairs):
    """concordance_index_ipcw must raise NoComparablePairException
    for the data supplied by the ``no_comparable_pairs`` fixture.
    """
    surv, risk = no_comparable_pairs

    with pytest.raises(NoComparablePairException):
        # The same structured array is passed as both the first and the
        # second positional argument.
        concordance_index_ipcw(surv, surv, risk)


@pytest.fixture()
def uno_auc_data_15():
y = Surv.from_arrays(
Expand Down

0 comments on commit dafe20a

Please sign in to comment.