diff --git a/DESCRIPTION b/DESCRIPTION
index 6b1498bfe..34454da13 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
-Version: 0.5.5
+Version: 0.5.6
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
diff --git a/NEWS.md b/NEWS.md
index d45a64b2e..6900bd9a3 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,9 @@
+# mlr3proba 0.5.6
+
+* Add `extend_quantile` argument to `autoplot.PredictionSurv` for `type = "dcalib"`, which, if `TRUE`, imputes NAs from the predicted quantile function with the maximum observed survival time
+* Fix the default `type` in `autoplot.PredictionSurv`, which is now `"calib"`
+* Update the default of `truncate` in `msr("surv.dcalib")` to `Inf`
+
# mlr3proba 0.5.5
* Add `$reverse()` method to `TaskSurv`, which returns the same task but with 1-status.
diff --git a/R/MeasureSurvDCalibration.R b/R/MeasureSurvDCalibration.R
index e7d2d5c99..1904ed168 100644
--- a/R/MeasureSurvDCalibration.R
+++ b/R/MeasureSurvDCalibration.R
@@ -45,22 +45,22 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
#' If `TRUE` returns the p-value of the corresponding chisq.test instead of the measure.
#' Default is `FALSE` and returns the statistic `s`.
#' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
- #' `p > 0.05` indicates a well-calibrated model.
+ #' The null hypothesis is that the model is D-calibrated.
#' @param truncate (`double(1)`) \cr
#' This parameter controls the upper bound of the output statistic,
- #' when `chisq` is `FALSE`. The default `truncate` value of \eqn{10}
- #' corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets.
- #' Values \eqn{>10} translate to even lower p-values and thus less calibrated
- #' models. If the number of buckets \eqn{B} changes, you probably will want to
+ #' when `chisq` is `FALSE`. The default is `truncate = Inf`, i.e. no truncation; a value of
+ #' \eqn{10} may be sufficient for most purposes and corresponds to a p-value of 0.35 for
+ #' the chisq.test using \eqn{B = 10} buckets. Values \eqn{>10} translate to even lower
+ #' p-values and thus less calibrated models. If the number of buckets \eqn{B} changes, you will probably want to
#' change the `truncate` value as well to correspond to the same p-value significance.
- #' Initialize with `truncate = Inf` if no truncation is desired.
+ #' Note that truncation may severely limit automated tuning with this measure, since all statistics above the bound collapse to the same truncated value.
initialize = function() {
ps = ps(
B = p_int(1, default = 10),
chisq = p_lgl(default = FALSE),
- truncate = p_dbl(lower = 0, upper = Inf, default = 10)
+ truncate = p_dbl(lower = 0, upper = Inf, default = Inf)
)
- ps$values = list(B = 10L, chisq = FALSE, truncate = 10)
+ ps$values = list(B = 10L, chisq = FALSE, truncate = Inf)
super$initialize(
id = "surv.dcalib",
diff --git a/R/autoplot.R b/R/autoplot.R
index fff1eed03..37cd332b2 100644
--- a/R/autoplot.R
+++ b/R/autoplot.R
@@ -173,6 +173,8 @@ plot.TaskDens = function(x, ...) {
#' @param cuts (`integer(1)`) \cr
#' Number of cuts in (0,1) to plot `dcalib` over, default is `11`.
#' @template param_theme
+#' @param extend_quantile (`logical(1)`) \cr
+#' If `TRUE`, then `dcalib` imputes NAs from the predicted quantile function with the maximum observed outcome time. For example, if the last predicted survival probability is greater than 0.1, then the last predicted cdf is smaller than 0.9, so F^-1(0.9) = NA and is then imputed with the maximum observed time. Default is `FALSE`.
#' @param ... (`any`):
#' Additional arguments, currently unused.
#'
@@ -199,14 +201,15 @@ plot.TaskDens = function(x, ...) {
#'
#' # Predictions
#' autoplot(p, type = "preds")
-autoplot.PredictionSurv = function(object, type = "dcalib",
+autoplot.PredictionSurv = function(object, type = "calib",
task = NULL, row_ids = NULL, times = NULL, xyline = TRUE,
- cuts = 11L, theme = theme_minimal(), ...) {
+ cuts = 11L, theme = theme_minimal(), extend_quantile = FALSE, ...) {
assert("distr" %in% object$predict_types)
switch(type,
"calib" = {
+ assert_task(task)
if (is.null(times)) {
times = sort(unique(task$truth()[, 1]))
}
@@ -234,14 +237,20 @@ autoplot.PredictionSurv = function(object, type = "dcalib",
"dcalib" = {
p = seq.int(0, 1, length.out = cuts)
+ true_times = object$truth[, 1L]
q = map_dbl(p, function(.x) {
- sum(object$truth[, 1L] <= as.numeric(object$distr$quantile(.x)), na.rm = TRUE) / length(object$row_ids)
+ qi = as.numeric(object$distr$quantile(.x))
+ if (extend_quantile) {
+ qi[is.na(qi)] = max(true_times)
+ }
+ sum(true_times <= qi) / length(object$row_ids)
})
pl = ggplot(data = data.frame(p, q), aes(x = p, y = q)) +
geom_line()
if (xyline) {
- pl = pl + geom_abline(slope = 1, intercept = 0, color = "lightgray")
+ pl = pl +
+ geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), color = "lightgray")
}
pl +
labs(x = "True", y = "Predicted") +
diff --git a/man/autoplot.PredictionSurv.Rd b/man/autoplot.PredictionSurv.Rd
index 92f82fc06..9db104459 100644
--- a/man/autoplot.PredictionSurv.Rd
+++ b/man/autoplot.PredictionSurv.Rd
@@ -6,13 +6,14 @@
\usage{
\method{autoplot}{PredictionSurv}(
object,
- type = "dcalib",
+ type = "calib",
task = NULL,
row_ids = NULL,
times = NULL,
xyline = TRUE,
cuts = 11L,
theme = theme_minimal(),
+ extend_quantile = FALSE,
...
)
}
@@ -41,6 +42,9 @@ Number of cuts in (0,1) to plot \code{dcalib} over, default is \code{11}.}
\item{theme}{(\code{\link[ggplot2:theme]{ggplot2::theme()}})\cr
The \code{\link[ggplot2:ggtheme]{ggplot2::theme_minimal()}} is applied by default to all plots.}
+\item{extend_quantile}{(\code{logical(1)}) \cr
+If \code{TRUE}, then \code{dcalib} imputes NAs from the predicted quantile function with the maximum observed outcome time. For example, if the last predicted survival probability is greater than 0.1, then the last predicted cdf is smaller than 0.9, so F^-1(0.9) = NA and is then imputed with the maximum observed time. Default is \code{FALSE}.}
+
\item{...}{(\code{any}):
Additional arguments, currently unused.}
}
diff --git a/man/mlr_measures_surv.dcalib.Rd b/man/mlr_measures_surv.dcalib.Rd
index 6694eed71..e4da89b3d 100644
--- a/man/mlr_measures_surv.dcalib.Rd
+++ b/man/mlr_measures_surv.dcalib.Rd
@@ -138,16 +138,16 @@ Changing this parameter affects \code{truncate}.}
If \code{TRUE} returns the p-value of the corresponding chisq.test instead of the measure.
Default is \code{FALSE} and returns the statistic \code{s}.
You can manually get the p-value by executing \code{pchisq(s, B - 1, lower.tail = FALSE)}.
-\code{p > 0.05} indicates a well-calibrated model.}
+The null hypothesis is that the model is D-calibrated.}
\item{\code{truncate}}{(\code{double(1)}) \cr
This parameter controls the upper bound of the output statistic,
-when \code{chisq} is \code{FALSE}. The default \code{truncate} value of \eqn{10}
-corresponds to a p-value of 0.35 for the chisq.test using \eqn{B = 10} buckets.
-Values \eqn{>10} translate to even lower p-values and thus less calibrated
-models. If the number of buckets \eqn{B} changes, you probably will want to
+when \code{chisq} is \code{FALSE}. The default is \code{truncate = Inf}, i.e. no truncation; a value of
+\eqn{10} may be sufficient for most purposes and corresponds to a p-value of 0.35 for
+the chisq.test using \eqn{B = 10} buckets. Values \eqn{>10} translate to even lower
+p-values and thus less calibrated models. If the number of buckets \eqn{B} changes, you will probably want to
change the \code{truncate} value as well to correspond to the same p-value significance.
-Initialize with \code{truncate = Inf} if no truncation is desired.}
+Note that truncation may severely limit automated tuning with this measure, since all statistics above the bound collapse to the same truncated value.}
}
\if{html}{\out{}}
}
diff --git a/tests/testthat/_snaps/autoplot/predictionsurv-dcalib.svg b/tests/testthat/_snaps/autoplot/predictionsurv-dcalib.svg
index ab9682622..86f2a8eef 100644
--- a/tests/testthat/_snaps/autoplot/predictionsurv-dcalib.svg
+++ b/tests/testthat/_snaps/autoplot/predictionsurv-dcalib.svg
@@ -25,31 +25,43 @@
[SVG snapshot body omitted: with `extend_quantile = TRUE` the plotted proportion spans the full (0, 1) range, so the y-axis tick labels of the dcalib snapshot change from 0.00/0.05/0.10/0.15 to 0.00/0.25/0.50/0.75/1.00.]
diff --git a/tests/testthat/test_autoplot.R b/tests/testthat/test_autoplot.R
index 4abccfa92..dff586b6b 100644
--- a/tests/testthat/test_autoplot.R
+++ b/tests/testthat/test_autoplot.R
@@ -11,7 +11,7 @@ test_that("autoplot.PredictionSurv", {
expect_true(is.ggplot(p))
vdiffr::expect_doppelganger("predictionsurv_calib", p)
- p = autoplot(prediction, type = "dcalib")
+ p = autoplot(prediction, type = "dcalib", extend_quantile = TRUE)
expect_true(is.ggplot(p))
vdiffr::expect_doppelganger("predictionsurv_dcalib", p)
diff --git a/tests/testthat/test_mlr_measures.R b/tests/testthat/test_mlr_measures.R
index 5f766256c..cd06b63df 100644
--- a/tests/testthat/test_mlr_measures.R
+++ b/tests/testthat/test_mlr_measures.R
@@ -327,7 +327,7 @@ test_that("dcal works", {
expect_equal(score, score2)
expect_true(score > 10)
- score3 = p$score(msr("surv.dcalib")) # default truncate = 10
+ score3 = p$score(msr("surv.dcalib", truncate = 10))
expect_equal(unname(score3), 10)
score4 = p$score(msr("surv.dcalib", truncate = 5))
expect_equal(unname(score4), 5)