Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brier fix #429

Merged
merged 24 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# mlr3proba 0.7.1

* Removed all `PipeOp`s and pipelines related to survival => regression reduction techniques (see #414)
* Bug fix: `$predict_type` of `survtoclassif_disctime` and `survtoclassif_IPCW` was `prob` (classification type) and not `crank` (survival type)
* cleanup: removed all `PipeOp`s and pipelines related to survival => regression reduction techniques (see #414)
* fix: `$predict_type` of `survtoclassif_disctime` and `survtoclassif_IPCW` was `prob` (classification type) and not `crank` (survival type)
* fix: G(t) is not filtered when `t_max|p_max` is specified in scoring rules (didn't influence evaluation at all)
* docs: Clarified the use and impact of using `t_max` in scoring rules, added examples in scoring rules and AUC scores
* feat: Added new argument `remove_obs` in scoring rules to remove observations with observed time `t > t_max` as a processing step to alleviate IPCW issues.
This was before 'hard-coded' which made the Integrated Brier Score (`msr("surv.graf")`) differ minimally from other implementations and the original definition.

# mlr3proba 0.7.0

Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvChamblessAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
inherit = MeasureSurvAUC,
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvCindex.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#'
#' # Harrell's C-index evaluated up to a specific time horizon
#' p$score(msr("surv.cindex", t_max = 97))
#'
#' # Harrell's C-index evaluated up to the time corresponding to 30% of censoring
#' p$score(msr("surv.cindex", p_max = 0.3))
#'
Expand Down
19 changes: 11 additions & 8 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#' @templateVar fullname MeasureSurvDCalibration
#'
#' @description
#' `r lifecycle::badge("experimental")`
#'
#' This calibration method is defined by calculating the following statistic:
#' \deqn{s = B/n \sum_i (P_i - n/B)^2}
#' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals),
Expand All @@ -12,8 +14,8 @@
#' falls within the corresponding interval.
#' This statistic assumes that censoring time is independent of death time.
#'
#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated).
#' A model is well D-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
#' (\eqn{p > 0.05} if well-calibrated, i.e. higher p-values are preferred).
#' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)},
#' meaning that *lower values* of this measure are preferred.
#'
Expand All @@ -23,7 +25,7 @@
#' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually
#' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`.
#'
#' NOTE: This measure is still experimental both theoretically and in implementation. Results
#' **NOTE**: This measure is still experimental both theoretically and in implementation. Results
#' should therefore only be taken as an indicator of performance and not for
#' conclusive judgements about model calibration.
#'
Expand All @@ -38,11 +40,12 @@
#' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
#' The null hypothesis is that the model is D-calibrated.
#' - `truncate` (`double(1)`) \cr
#' This parameter controls the upper bound of the output statistic,
#' when `chisq` is `FALSE`. We use `truncate = Inf` by default but \eqn{10} may be sufficient
#' for most purposes, which corresponds to a p-value of 0.35 for the chisq.test using
#' \eqn{B = 10} buckets. Values \eqn{>10} translate to even lower p-values and thus
#' less calibrated models. If the number of buckets \eqn{B} changes, you probably will want to
#' This parameter controls the upper bound of the output statistic, when `chisq` is `FALSE`.
#' We use `truncate = Inf` by default but values between \eqn{10-16} are sufficient
#' for most purposes, which correspond to p-values of \eqn{0.35-0.06} for the `chisq.test` using
#' the default \eqn{B = 10} buckets.
#' Values \eqn{B > 10} translate to even lower p-values and thus less D-calibrated models.
#' If the number of buckets \eqn{B} changes, you probably will want to
#' change the `truncate` value as well to correspond to the same p-value significance.
#' Note that truncation may severely limit automated tuning with this measure.
#'
Expand Down
15 changes: 9 additions & 6 deletions R/MeasureSurvGraf.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @templateVar eps 1e-3
#' @template param_eps
#' @template param_erv
#' @template param_remove_obs
#'
#' @aliases MeasureSurvBrier mlr_measures_surv.brier
#'
Expand All @@ -25,13 +26,13 @@
#' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
#' the time dimension up to the time cutoff \eqn{\tau^*}, is:
#'
#' \deqn{L_{ISBS}(S_i, t_i, \delta_i) = \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#' \deqn{L_{ISBS}(S_i, t_i, \delta_i) = \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{(1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#'
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The **re-weighted ISBS** (RISBS) is:
#'
#' \deqn{L_{RISBS}(S_i, t_i, \delta_i) = \delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
#' \deqn{L_{RISBS}(S_i, t_i, \delta_i) = \delta_i \frac{\int^{\tau^*}_0 S_i^2(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau))^2 \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
#'
#' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
#'
Expand All @@ -48,10 +49,11 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("graf_1999")`
#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @template example_scoring_rules
#' @export
MeasureSurvGraf = R6Class("MeasureSurvGraf",
inherit = MeasureSurv,
Expand All @@ -73,11 +75,12 @@ MeasureSurvGraf = R6Class("MeasureSurvGraf",
se = p_lgl(default = FALSE),
proper = p_lgl(default = FALSE),
eps = p_dbl(0, 1, default = 1e-3),
ERV = p_lgl(default = FALSE)
ERV = p_lgl(default = FALSE),
remove_obs = p_lgl(default = FALSE)
)
ps$set_values(
integrated = TRUE, method = 2L, se = FALSE,
proper = FALSE, eps = 1e-3, ERV = ERV
proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
)

range = if (ERV) c(-Inf, 1) else c(0, Inf)
Expand Down Expand Up @@ -132,7 +135,7 @@ MeasureSurvGraf = R6Class("MeasureSurvGraf",
truth = prediction$truth,
distribution = prediction$data$distr, times = times,
t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
eps = ps$eps
eps = ps$eps, remove_obs = ps$remove_obs
)

if (ps$se) {
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvHungAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvHungAUC = R6Class("MeasureSurvHungAUC",
inherit = MeasureSurvAUC,
Expand Down
15 changes: 9 additions & 6 deletions R/MeasureSurvIntLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @templateVar eps 1e-3
#' @template param_eps
#' @template param_erv
#' @template param_remove_obs
#'
#' @description
#' Calculates the **Integrated Survival Log-Likelihood** (ISLL) or Integrated
Expand All @@ -23,13 +24,13 @@
#' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
#' the time dimension up to the time cutoff \eqn{\tau^*}, is:
#'
#' \deqn{L_{ISLL}(S_i, t_i, \delta_i) = -\text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{log[1-S_i(\tau)] \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{\log[S_i(\tau)] \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#' \deqn{L_{ISLL}(S_i, t_i, \delta_i) = - \int^{\tau^*}_0 \frac{log[1-S_i(\tau)] \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{\log[S_i(\tau)] \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#'
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The **re-weighted ISLL** (RISLL) is:
#'
#' \deqn{L_{RISLL}(S_i, t_i, \delta_i) = -\delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{\log[1-S_i(\tau)]) \text{I}(t_i \leq \tau) + \log[S_i(\tau)] \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
#' \deqn{L_{RISLL}(S_i, t_i, \delta_i) = -\delta_i \frac{\int^{\tau^*}_0 \log[1-S_i(\tau)]) \text{I}(t_i \leq \tau) + \log[S_i(\tau)] \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
#'
#' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
#'
Expand All @@ -46,10 +47,11 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("graf_1999")`
#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @template example_scoring_rules
#' @export
MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
inherit = MeasureSurv,
Expand All @@ -71,11 +73,12 @@ MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
se = p_lgl(default = FALSE),
proper = p_lgl(default = FALSE),
eps = p_dbl(0, 1, default = 1e-3),
ERV = p_lgl(default = FALSE)
ERV = p_lgl(default = FALSE),
remove_obs = p_lgl(default = FALSE)
)
ps$set_values(
integrated = TRUE, method = 2L, se = FALSE,
proper = FALSE, eps = 1e-3, ERV = ERV
proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
)

range = if (ERV) c(-Inf, 1) else c(0, Inf)
Expand Down Expand Up @@ -130,7 +133,7 @@ MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
truth = prediction$truth,
distribution = prediction$data$distr, times = times,
t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
eps = ps$eps
eps = ps$eps, remove_obs = ps$remove_obs
)

if (ps$se) {
Expand Down
5 changes: 4 additions & 1 deletion R/MeasureSurvLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' Calculates the cross-entropy, or negative log-likelihood (NLL) or logarithmic (log), loss.
#' @section Parameter details:
#' - `IPCW` (`logical(1)`)\cr
#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper).
#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper). See Sonabend et al. (2024) for more details.
#'
#' @details
#' The Log Loss, in the context of probabilistic predictions, is defined as the
Expand All @@ -33,6 +33,9 @@
#'
#' @template details_trainG
#'
#' @references
#' `r format_bib("sonabend2024")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
Expand Down
22 changes: 9 additions & 13 deletions R/MeasureSurvSchmid.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#' @templateVar eps 1e-3
#' @template param_eps
#' @template param_erv
#' @template param_remove_obs
#'
#' @description
#' Calculates the **Integrated Schmid Score** (ISS), aka integrated absolute loss.
Expand All @@ -22,27 +23,20 @@
#' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
#' the time dimension up to the time cutoff \eqn{\tau^*}, is:
#'
#' \deqn{L_{ISS}(S_i, t_i, \delta_i) = \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau)) \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#' \deqn{L_{ISS}(S_i, t_i, \delta_i) = \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau)) \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
#'
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The **re-weighted ISS** (RISS) is:
#'
#' \deqn{L_{RISS}(S_i, t_i, \delta_i) = \delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau)) \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
#' \deqn{L_{RISS}(S_i, t_i, \delta_i) = \delta_i \frac{\int^{\tau^*}_0 S_i(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau)) \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
#'
#' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
#'
#' To get a single score across all \eqn{N} observations of the test set, we
#' return the average of the time-integrated observation-wise scores:
#' \deqn{\sum_{i=1}^N L(S_i, t_i, \delta_i) / N}
#'
#'
#' \deqn{L_{ISS}(S,t|t^*) = [(S(t^*))I(t \le t^*, \delta = 1)(1/G(t))] + [((1 - S(t^*)))I(t > t^*)(1/G(t^*))]}
#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
#'
#' The re-weighted ISS, RISS is given by
#' \deqn{L_{RISS}(S,t|t^*) = [(S(t^*))I(t \le t^*, \delta = 1)(1/G(t))] + [((1 - S(t^*)))I(t > t^*)(1/G(t))]}
#'
#' @template properness
#' @templateVar improper_id ISS
#' @templateVar proper_id RISS
Expand All @@ -52,10 +46,11 @@
#' @template details_tmax
#'
#' @references
#' `r format_bib("schemper_2000", "schmid_2011")`
#' `r format_bib("schemper_2000", "schmid_2011", "sonabend2024", "kvamme2023")`
#'
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @template example_scoring_rules
#' @export
MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
inherit = MeasureSurv,
Expand All @@ -77,11 +72,12 @@ MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
se = p_lgl(default = FALSE),
proper = p_lgl(default = FALSE),
eps = p_dbl(0, 1, default = 1e-3),
ERV = p_lgl(default = FALSE)
ERV = p_lgl(default = FALSE),
remove_obs = p_lgl(default = FALSE)
)
ps$set_values(
integrated = TRUE, method = 2L, se = FALSE,
proper = FALSE, eps = 1e-3, ERV = ERV
proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
)

range = if (ERV) c(-Inf, 1) else c(0, Inf)
Expand Down Expand Up @@ -135,7 +131,7 @@ MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
truth = prediction$truth,
distribution = prediction$data$distr, times = times,
t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
eps = ps$eps
eps = ps$eps, remove_obs = ps$remove_obs
)

if (ps$se) {
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvSongAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvSongAUC = R6Class("MeasureSurvSongAUC",
inherit = MeasureSurvAUC,
Expand Down
1 change: 1 addition & 0 deletions R/MeasureSurvUnoAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#'
#' @family AUC survival measures
#' @family lp survival measures
#' @template example_auc_measures
#' @export
MeasureSurvUnoAUC = R6Class("MeasureSurvUnoAUC",
inherit = MeasureSurvAUC,
Expand Down
20 changes: 20 additions & 0 deletions R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -741,5 +741,25 @@ bibentries = c( # nolint start
title = "Simulating Survival Data Using the simsurv R Package",
volume = "97",
year = "2021"
),
sonabend2024 = bibentry("misc",
archivePrefix = "arXiv",
arxivId = "2212.05260",
author = "Sonabend, Raphael and Zobolas, John and Kopper, Philipp and Burk, Lukas and Bender, Andreas",
month = "dec",
title = "Examining properness in the external validation of survival models with squared and logarithmic losses",
url = "https://arxiv.org/abs/2212.05260v2",
year = "2024"
),
kvamme2023 = bibentry("article",
author = "Kvamme, Havard and Borgan, Ornulf",
issn = "1533-7928",
journal = "Journal of Machine Learning Research",
number = "2",
pages = "1--26",
title = "The Brier Score under Administrative Censoring: Problems and a Solution",
url = "http://jmlr.org/papers/v24/19-1030.html",
volume = "24",
year = "2023"
)
)
12 changes: 4 additions & 8 deletions R/integrated_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ score_graf_schmid = function(true_times, unique_times, cdf, power = 2) {
# - `t_max` > 0
# - `p_max` in [0,1]
weighted_survival_score = function(loss, truth, distribution, times = NULL,
t_max = NULL, p_max = NULL, proper, train = NULL, eps, ...) {
t_max = NULL, p_max = NULL, proper, train = NULL, eps, remove_obs = FALSE) {
assert_surv(truth)
# test set's (times, status)
test_times = truth[, "time"]
Expand Down Expand Up @@ -90,8 +90,8 @@ weighted_survival_score = function(loss, truth, distribution, times = NULL,
rownames(cdf) = unique_times # times x obs
}

# apply `t_max` cutoff to remove observations
if (tmax_apply) {
# apply `t_max` cutoff to remove observations as a preprocessing step to alleviate inflation
if (tmax_apply && remove_obs) {
true_times = test_times[test_times <= t_max]
true_status = test_status[test_times <= t_max]
cdf = cdf[, test_times <= t_max, drop = FALSE]
Expand All @@ -118,6 +118,7 @@ weighted_survival_score = function(loss, truth, distribution, times = NULL,

# use the `truth` (time, status) information from the train or test set
if (is.null(train)) {
# no filtering of observations from test data: use ALL
cens = survival::survfit(Surv(test_times, 1 - test_status) ~ 1)
} else {
# no filtering of observations from train data: use ALL
Expand All @@ -128,11 +129,6 @@ weighted_survival_score = function(loss, truth, distribution, times = NULL,
# G(t): KM estimate of the censoring distribution
cens = matrix(c(cens$time, cens$surv), ncol = 2L)

# filter G(t) time points based on `t_max` cutoff
if (tmax_apply) {
cens = cens[cens[, 1L] <= t_max, , drop = FALSE]
}

score = .c_weight_survival_score(score, true_truth, unique_times, cens, proper, eps)
colnames(score) = unique_times

Expand Down
Loading
Loading