Skip to content

Commit

Permalink
Merge pull request #341 from mlr-org/extend_dcalib
Browse files Browse the repository at this point in the history
impute NAs in dcalib autoplot
  • Loading branch information
bblodfon authored Dec 21, 2023
2 parents 3bbfb89 + bbd8a1f commit af918e9
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 34 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.5
Version: 0.5.6
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# mlr3proba 0.5.6

* Add `extend_quantile` to `autoplot.PredictionSurv` for `type = "dcalib"`, which imputes NAs with the maximum observed survival time
* Fixes default in `autoplot.PredictionSurv`, now `"calib"`
* Update `msr("surv.dcalib")` default for `truncate` to `Inf`

# mlr3proba 0.5.5

* Add `$reverse()` method to `TaskSurv`, which returns the same task but with 1-status.
Expand Down
16 changes: 8 additions & 8 deletions R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
#' If `TRUE` returns the p-value of the corresponding chisq.test instead of the measure.
#' Default is `FALSE` and returns the statistic `s`.
#' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
#' `p > 0.05` indicates a well-calibrated model.
#' The null hypothesis is that the model is D-calibrated.
#' @param truncate (`double(1)`) \cr
#' This parameter controls the upper bound of the output statistic,
#' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10}
#' corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets.
#' Values \eqn{>10} translate to even lower p-values and thus less calibrated
#' models. If the number of buckets \eqn{B} changes, you probably will want to
#' when `chisq` is `FALSE`. We use `truncate = Inf` by default but \eqn{10} may be sufficient
#' for most purposes, which corresponds to a p-value of 0.35 for the chisq.test using
#' \eqn{B = 10} buckets. Values \eqn{>10} translate to even lower p-values and thus
#' less calibrated models. If the number of buckets \eqn{B} changes, you probably will want to
#' change the `truncate` value as well to correspond to the same p-value significance.
#' Initialize with `truncate = Inf` if no truncation is desired.
#' Note that truncation may severely limit automated tuning with this measure.
initialize = function() {
ps = ps(
B = p_int(1, default = 10),
chisq = p_lgl(default = FALSE),
truncate = p_dbl(lower = 0, upper = Inf, default = 10)
truncate = p_dbl(lower = 0, upper = Inf, default = Inf)
)
ps$values = list(B = 10L, chisq = FALSE, truncate = 10)
ps$values = list(B = 10L, chisq = FALSE, truncate = Inf)

super$initialize(
id = "surv.dcalib",
Expand Down
17 changes: 13 additions & 4 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ plot.TaskDens = function(x, ...) {
#' @param cuts (`integer(1)`) \cr
#' Number of cuts in (0,1) to plot `dcalib` over, default is `11`.
#' @template param_theme
#' @param extend_quantile `(logical(1))` \cr
#' If `TRUE` then `dcalib` will impute NAs from predicted quantile function with the maximum observed outcome time, e.g. if the last predicted survival probability is greater than 0.1, then the last predicted cdf is smaller than 0.9 so F^1(0.9) = NA, this would be imputed with max(times). Default is `FALSE`.
#' @param ... (`any`):
#' Additional arguments, currently unused.
#'
Expand All @@ -199,14 +201,15 @@ plot.TaskDens = function(x, ...) {
#'
#' # Predictions
#' autoplot(p, type = "preds")
autoplot.PredictionSurv = function(object, type = "dcalib",
autoplot.PredictionSurv = function(object, type = "calib",
task = NULL, row_ids = NULL, times = NULL, xyline = TRUE,
cuts = 11L, theme = theme_minimal(), ...) {
cuts = 11L, theme = theme_minimal(), extend_quantile = FALSE, ...) {

assert("distr" %in% object$predict_types)

switch(type,
"calib" = {
assert_task(task)
if (is.null(times)) {
times = sort(unique(task$truth()[, 1]))
}
Expand Down Expand Up @@ -234,14 +237,20 @@ autoplot.PredictionSurv = function(object, type = "dcalib",

"dcalib" = {
p = seq.int(0, 1, length.out = cuts)
true_times = object$truth[, 1L]
q = map_dbl(p, function(.x) {
sum(object$truth[, 1L] <= as.numeric(object$distr$quantile(.x)), na.rm = TRUE) / length(object$row_ids)
qi = as.numeric(object$distr$quantile(.x))
if (extend_quantile) {
qi[is.na(qi)] = max(true_times)
}
sum(true_times <= qi) / length(object$row_ids)
})
pl = ggplot(data = data.frame(p, q), aes(x = p, y = q)) +
geom_line()

if (xyline) {
pl = pl + geom_abline(slope = 1, intercept = 0, color = "lightgray")
pl = pl +
geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), color = "lightgray")
}
pl +
labs(x = "True", y = "Predicted") +
Expand Down
6 changes: 5 additions & 1 deletion man/autoplot.PredictionSurv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions man/mlr_measures_surv.dcalib.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 24 additions & 12 deletions tests/testthat/_snaps/autoplot/predictionsurv-dcalib.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/testthat/test_autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ test_that("autoplot.PredictionSurv", {
expect_true(is.ggplot(p))
vdiffr::expect_doppelganger("predictionsurv_calib", p)

p = autoplot(prediction, type = "dcalib")
p = autoplot(prediction, type = "dcalib", extend_quantile = TRUE)
expect_true(is.ggplot(p))
vdiffr::expect_doppelganger("predictionsurv_dcalib", p)

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_mlr_measures.R
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ test_that("dcal works", {
expect_equal(score, score2)
expect_true(score > 10)

score3 = p$score(msr("surv.dcalib")) # default truncate = 10
score3 = p$score(msr("surv.dcalib", truncate = 10))
expect_equal(unname(score3), 10)
score4 = p$score(msr("surv.dcalib", truncate = 5))
expect_equal(unname(score4), 5)
Expand Down

0 comments on commit af918e9

Please sign in to comment.