Skip to content

Commit

Permalink
Merge pull request #388 from epiforecasts/rework-quantile-scores
Browse files Browse the repository at this point in the history
Rework quantile scores
  • Loading branch information
nikosbosse authored Nov 16, 2023
2 parents 9ba205a + eb45cbb commit 4df4576
Show file tree
Hide file tree
Showing 94 changed files with 3,278 additions and 2,129 deletions.
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ export(bias_sample)
export(brier_score)
export(correlation)
export(crps_sample)
export(dispersion)
export(dss_sample)
export(get_duplicate_forecasts)
export(interval_coverage_deviation_quantile)
export(interval_coverage_quantile)
export(interval_coverage_sample)
export(interval_score)
export(log_shift)
export(logs_binary)
Expand All @@ -39,6 +43,7 @@ export(make_NA)
export(make_na)
export(merge_pred_and_obs)
export(new_scoringutils)
export(overprediction)
export(pairwise_comparison)
export(pit)
export(pit_sample)
Expand All @@ -54,6 +59,7 @@ export(plot_ranges)
export(plot_score_table)
export(plot_wis)
export(quantile_score)
export(run_safely)
export(sample_to_quantile)
export(score)
export(se_mean_sample)
Expand All @@ -63,8 +69,10 @@ export(summarise_scores)
export(summarize_scores)
export(theme_scoringutils)
export(transform_forecasts)
export(underprediction)
export(validate)
export(validate_general)
export(wis)
importFrom(Metrics,ae)
importFrom(Metrics,ape)
importFrom(Metrics,se)
Expand All @@ -74,12 +82,15 @@ importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_data_table)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_data_frame)
importFrom(checkmate,check_function)
importFrom(checkmate,check_matrix)
importFrom(checkmate,check_numeric)
importFrom(checkmate,check_vector)
importFrom(checkmate,test_factor)
importFrom(checkmate,test_list)
importFrom(checkmate,test_numeric)
Expand All @@ -98,6 +109,7 @@ importFrom(data.table,nafill)
importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setattr)
importFrom(data.table,setcolorder)
importFrom(data.table,setnames)
importFrom(ggdist,geom_lineribbon)
importFrom(ggplot2,.data)
Expand Down Expand Up @@ -157,5 +169,6 @@ importFrom(stats,rbinom)
importFrom(stats,reorder)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,weighted.mean)
importFrom(stats,wilcox.test)
importFrom(utils,combn)
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The update introduces breaking changes. If you want to keep using the older vers
- `quantile`: numeric, a vector with quantile-levels. Can alternatively be a matrix of the same shape as `predicted`.
- `check_forecasts()` was replaced by a new function `validate()`. `validate()` validates the input and in that sense fulfills the purpose of `check_forecasts()`. It has different methods: `validate.default()` assigns the input a class based on their forecast type. Other methods validate the input specifically for the various forecast types.
- The functionality for computing pairwise comparisons was now split from `summarise_scores()`. Instead of doing pairwise comparisons as part of summarising scores, a new function, `add_pairwise_comparison()`, was introduced that takes summarised scores as an input and adds pairwise comparisons to it.
- `add_coverage()` was reworked completely. It's new purpose is now to add coverage information to the raw forecast data (essentially fulfilling some of the functionality that was previously covered by `score_quantile()`)
- The function `find_duplicates()` was renamed to `get_duplicate_forecasts()`
- Changes to `avail_forecasts()` and `plot_avail_forecasts()`:
- The function `avail_forecasts()` was renamed to `available_forecasts()` for consistency with `available_metrics()`. The old function, `avail_forecasts()` is still available as an alias, but will be removed in the future.
Expand Down
83 changes: 83 additions & 0 deletions R/add_coverage.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#' @title Add Coverage Values to Quantile-Based Forecasts
#'
#' @description Adds interval coverage of central prediction intervals,
#' quantile coverage for predictive quantiles, as well as the deviation between
#' desired and actual coverage to a data.table. Forecasts should be in a
#' quantile format (following the input requirements of `score()`).
#'
#' **Interval coverage**
#'
#' Coverage for a given interval range is defined as the proportion of
#' observations that fall within the corresponding central prediction intervals.
#' Central prediction intervals are symmetric around the median and and formed
#' by two quantiles that denote the lower and upper bound. For example, the 50%
#' central prediction interval is the interval between the 0.25 and 0.75
#' quantiles of the predictive distribution.
#'
#' The function `add_coverage()` computes the coverage per central prediction
#' interval, so the coverage will always be either `TRUE` (observed value falls
#' within the interval) or `FALSE` (observed value falls outside the interval).
#' You can summarise the coverage values to get the proportion of observations
#' that fall within the central prediction intervals.
#'
#' **Quantile coverage**
#'
#' Quantile coverage for a given quantile is defined as the proportion of
#' observed values that are smaller than the corresponding predictive quantile.
#' For example, the 0.5 quantile coverage is the proportion of observed values
#' that are smaller than the 0.5 quantile of the predictive distribution.
#'
#' **Coverage deviation**
#'
#' The coverage deviation is the difference between the desired coverage and the
#' actual coverage. For example, if the desired coverage is 90% and the actual
#' coverage is 80%, the coverage deviation is -0.1.
#'
#' @inheritParams score
#' @return a data.table with the input and columns "interval_coverage",
#' "interval_coverage_deviation", "quantile_coverage",
#' "quantile_coverage_deviation" added.
#' @importFrom data.table setcolorder
#' @examples
#' library(magrittr) # pipe operator
#' example_quantile %>%
#' add_coverage()
#' @export
#' @keywords scoring
#' @export
add_coverage <- function(data) {
stored_attributes <- get_scoringutils_attributes(data)
data <- validate(data)
forecast_unit <- get_forecast_unit(data)
data_cols <- colnames(data) # store so we can reset column order later

# what happens if quantiles are not symmetric around the median?
# should things error? Also write tests for that.
interval_data <- quantile_to_interval(data, format = "wide")
interval_data[, interval_coverage := ifelse(
observed <= upper & observed >= lower,
TRUE,
FALSE)
][, c("lower", "upper", "observed") := NULL]

data[, range := get_range_from_quantile(quantile)]

data <- merge(interval_data, data, by = unique(c(forecast_unit, "range")))
data[, interval_coverage_deviation := interval_coverage - range / 100]
data[, quantile_coverage := observed <= predicted]
data[, quantile_coverage_deviation := quantile_coverage - quantile]

# reset column order
new_metrics <- c("interval_coverage", "interval_coverage_deviation",
"quantile_coverage", "quantile_coverage_deviation")
setcolorder(data, unique(c(data_cols, "range", new_metrics)))

# add coverage "metrics" to list of stored metrics
# this makes it possible to use `summarise_scores()` later on
stored_attributes[["metric_names"]] <- c(
stored_attributes[["metric_names"]],
new_metrics
)
data <- assign_attributes(data, stored_attributes)
return(data[])
}
14 changes: 12 additions & 2 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,22 @@ check_columns_present <- function(data, columns) {
}
assert_character(columns, min.len = 1)
colnames <- colnames(data)
missing <- list()
for (x in columns){
if (!(x %in% colnames)) {
msg <- paste0("Column '", x, "' not found in data")
return(msg)
missing[[x]] <- x
}
}
missing <- unlist(missing)
if (length(missing) > 1) {
msg <- paste0(
"Columns '", paste(missing, collapse = "', '"), "' not found in data"
)
return(msg)
} else if (length(missing) == 1) {
msg <- paste0("Column '", missing, "' not found in data")
return(msg)
}
return(TRUE)
}

Expand Down
73 changes: 70 additions & 3 deletions R/check-inputs-scoring-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ check_input_sample <- function(observed, predicted) {
#' @param quantile Input to be checked. Should be a vector of size N that
#' denotes the quantile levels corresponding to the columns of the prediction
#' matrix.
#' @importFrom checkmate assert assert_numeric check_matrix
#' @param unique_quantiles Input to be checked. Should be TRUE (default) or
#' FALSE. Whether the quantile levels are required to be unique or not.
#' @importFrom checkmate assert assert_numeric check_matrix check_vector
#' @inherit document_assert_functions return
#' @keywords internal
assert_input_quantile <- function(observed, predicted, quantile) {
assert_input_quantile <- function(observed, predicted, quantile,
unique_quantiles = TRUE) {
assert_numeric(observed, min.len = 1)
n_obs <- length(observed)

assert_numeric(quantile, min.len = 1, lower = 0, upper = 1)
assert_numeric(
quantile, min.len = 1, lower = 0, upper = 1,
unique = unique_quantiles
)
n_quantiles <- length(quantile)
if (n_obs == 1) {
assert(
Expand All @@ -66,6 +72,7 @@ assert_input_quantile <- function(observed, predicted, quantile) {
check_matrix(predicted, mode = "numeric",
nrows = n_obs, ncols = n_quantiles)
)
assert(check_vector(quantile, len = length(predicted)))
} else {
assert(
check_matrix(predicted, mode = "numeric",
Expand All @@ -85,6 +92,66 @@ check_input_quantile <- function(observed, predicted, quantile) {
}


#' @title Assert that inputs are correct for interval-based forecast
#' @description Function assesses whether the inputs correspond to the
#' requirements for scoring interval-based forecasts.
#' @param observed Input to be checked. Should be a numeric vector with the
#' observed values of size n
#' @param lower Input to be checked. Should be a numeric vector of size n that
#' holds the predicted value for the lower bounds of the prediction intervals.
#' @param upper Input to be checked. Should be a numeric vector of size n that
#' holds the predicted value for the upper bounds of the prediction intervals.
#' @param range Input to be checked. Should be a vector of size n that
#' denotes the interval range in percent. E.g. a value of 50 denotes a
#' (25%, 75%) prediction interval.
#' @importFrom rlang warn
#' @inherit document_assert_functions return
#' @keywords internal
assert_input_interval <- function(observed, lower, upper, range) {

assert(check_numeric_vector(observed, min.len = 1))
n <- length(observed)
assert(check_numeric_vector(lower, len = n))
assert(check_numeric_vector(upper, len = n))
assert(
check_numeric_vector(range, len = 1, lower = 0, upper = 100),
check_numeric_vector(range, len = n, lower = 0, upper = 100)
)

diff <- upper - lower
diff <- diff[!is.na(diff)]
if (any(diff < 0)) {
stop(
"All values in `upper` need to be greater than or equal to ",
"the corresponding values in `lower`"
)
}
if (any(range > 0 & range < 1, na.rm = TRUE)) {
msg <- paste(
"Found interval ranges between 0 and 1. Are you sure that's right? An",
"interval range of 0.5 e.g. implies a (49.75%, 50.25%) prediction",
"interval. If you want to score a (25%, 75%) prediction interval, set",
"`interval_range = 50`."
)
rlang::warn(
message = msg, .frequency = "once",
.frequency_id = "small_interval_range"
)
}
return(invisible(NULL))
}


#' @title Check that inputs are correct for interval-based forecast
#' @inherit assert_input_interval params description
#' @inherit check_input_sample return description
#' @keywords check-inputs
check_input_interval <- function(observed, lower, upper, range) {
result <- check_try(assert_input_quantile(observed, lower, upper, range))
return(result)
}


#' @title Assert that inputs are correct for binary forecast
#' @description Function assesses whether the inputs correspond to the
#' requirements for scoring binary forecasts.
Expand Down
18 changes: 5 additions & 13 deletions R/convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,13 @@ log_shift <- function(x, offset = 0, base = exp(1)) {
#' example_quantile,
#' c("location", "target_end_date", "target_type", "horizon", "model")
#' )

set_forecast_unit <- function(data, forecast_unit) {

datacols <- colnames(data)
missing <- forecast_unit[!(forecast_unit %in% datacols)]

if (length(missing) > 0) {
warning(
"Column(s) '",
missing,
"' are not columns of the data and will be ignored."
)
forecast_unit <- intersect(forecast_unit, datacols)
data <- ensure_data.table(data)
missing <- check_columns_present(data, forecast_unit)
if (!is.logical(missing)) {
warning(missing)
forecast_unit <- intersect(forecast_unit, colnames(data))
}

keep_cols <- c(get_protected_columns(data), forecast_unit)
out <- unique(data[, .SD, .SDcols = keep_cols])[]
return(out)
Expand Down
27 changes: 21 additions & 6 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_quantile"


Expand All @@ -44,7 +44,7 @@
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_point"


Expand All @@ -69,7 +69,7 @@
#' \item{predicted}{predicted value}
#' \item{sample_id}{id for the corresponding sample}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_continuous"


Expand Down Expand Up @@ -124,7 +124,7 @@
#' \item{horizon}{forecast horizon in weeks}
#' \item{predicted}{predicted value}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_binary"


Expand All @@ -147,7 +147,7 @@
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_quantile_forecasts_only"


Expand All @@ -167,7 +167,7 @@
#' \item{observed}{observed values}
#' \item{location_name}{name of the country for which a prediction was made}
#' }
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} # nolint
#' @source \url{https://github.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
"example_truth_only"

#' Summary information for selected metrics
Expand Down Expand Up @@ -211,3 +211,18 @@
#' - "se_mean" = [se_mean_sample()]
#' @keywords info
"metrics_sample"

#' Default metrics for quantile-based forecasts.
#'
#' A named list with functions:
#' - "wis" = [wis()]
#' - "overprediction" = [overprediction()]
#' - "underprediction" = [underprediction()]
#' - "dispersion" = [dispersion()]
#' - "bias" = [bias_quantile()]
#' - "coverage_50" = \(...) {run_safely(..., range = 50, fun = [interval_coverage_quantile][interval_coverage_quantile()])}
#' - "coverage_90" = \(...) {run_safely(..., range = 90, fun = [interval_coverage_quantile][interval_coverage_quantile()])}
#' - "coverage_deviation" = [interval_coverage_deviation_quantile()],
#' - "ae_median" = [ae_median_quantile()]
#' @keywords info
"metrics_quantile"
2 changes: 2 additions & 0 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ get_protected_columns <- function(data = NULL) {
protected_columns <- c(
"predicted", "observed", "sample_id", "quantile", "upper", "lower",
"pit_value", "range", "boundary", "relative_skill", "scaled_rel_skill",
"interval_coverage", "interval_coverage_deviation",
"quantile_coverage", "quantile_coverage_deviation",
available_metrics(),
grep("coverage_", names(data), fixed = TRUE, value = TRUE)
)
Expand Down
Loading

0 comments on commit 4df4576

Please sign in to comment.