From fbd7cbaf1e2a50b77f382580ce6f25c4b1d8d68e Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 17 Sep 2024 00:04:15 +0200 Subject: [PATCH] Simplify score_model_out --- .Rbuildignore | 1 + .gitignore | 2 + DESCRIPTION | 2 +- R/score_model_out.R | 233 +++++++++----------------- man/score_model_out.Rd | 80 ++++----- tests/testthat/test-score_model_out.R | 34 ++-- 6 files changed, 144 insertions(+), 208 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index e57306d..8ca8d33 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -15,3 +15,4 @@ ^\.Rdata$ ^\.httr-oauth$ ^\.secrets$ +^.vscode diff --git a/.gitignore b/.gitignore index 440e2e6..27440c7 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ docs .Rdata .secrets .quarto + +.vscode/ diff --git a/DESCRIPTION b/DESCRIPTION index c711112..a832dd0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -51,7 +51,7 @@ Remotes: hubverse-org/hubExamples, hubverse-org/hubUtils Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 URL: https://hubverse-org.github.io/hubEvals/ Depends: R (>= 2.10) diff --git a/R/score_model_out.R b/R/score_model_out.R index 0fd3a29..f74350b 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -1,9 +1,13 @@ -#' Score model output predictions with a single `output_type` against observed data +#' Score model output predictions +#' +#' Scores model outputs with a single `output_type` against observed data. #' #' @param model_out_tbl Model output tibble with predictions #' @param target_observations Observed 'ground truth' data to be compared to #' predictions -#' @param metrics Optional character vector of scoring metrics to compute. See details for more. +#' @param metrics Character vector of scoring metrics to compute. If `NULL` +#' (the default), appropriate metrics are chosen automatically. See details +#' for more. #' @param summarize Boolean indicator of whether summaries of forecast scores #' should be computed. Defaults to `TRUE`. #' @param by Character vector naming columns to summarize by. For example, @@ -13,43 +17,40 @@ #' vector of levels for pmf forecasts, in increasing order of the levels. For #' all other output types, this is ignored. #' -#' @details If `metrics` is `NULL` (the default), this function chooses -#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`: -#' -#' \itemize{ -#' \item For `output_type == "quantile"`, we use the default metrics provided by -#' `scoringutils`: -#' `r names(scoringutils::get_metrics(scoringutils::example_quantile))` -#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating -#' that the predicted variable is a nominal variable), we use the default metric -#' provided by `scoringutils`: -#' `r names(scoringutils::get_metrics(scoringutils::example_nominal))` -#' \item For `output_type == "median"`, we use "ae_point" -#' \item For `output_type == "mean"`, we use "se_point" -#' } -#' -#' Alternatively, a character vector of scoring metrics can be provided. 
In this
-#' case, the following options are supported:
-#' - `output_type == "median"` and `output_type == "mean"`:
-#' - "ae_point": absolute error of a point prediction (generally recommended for the median)
-#' - "se_point": squared error of a point prediction (generally recommended for the mean)
-#' - `output_type == "quantile"`:
-#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
-#' - "wis": weighted interval score (WIS) of a collection of quantile predictions
-#' - "overprediction": The component of WIS measuring the extent to which
-#' predictions fell above the observation.
-#' - "underprediction": The component of WIS measuring the extent to which
-#' predictions fell below the observation.
-#' - "dispersion": The component of WIS measuring the dispersion of forecast
-#' distributions.
-#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
+#' @details
+#' Default metrics are provided by the `scoringutils` package. You can select
+#' metrics by passing in a character vector of metric names to the `metrics`
+#' argument.
+#'
+#' The following metrics can be selected (all are used by default) for the
+#' different `output_type`s:
+#'
+#' **Quantile forecasts:** (`output_type == "quantile"`)
+#' `r exclude <- c("interval_coverage_50", "interval_coverage_90")`
+#' `r metrics <- scoringutils::get_metrics(scoringutils::example_quantile, exclude = exclude)`
+#' `r paste("- ", names(metrics), collapse = "\n")`
+#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
 #' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
 #' based on quantiles at the probability levels 0.025 and 0.975.
-#' - `output_type == "pmf"`:
-#' - "log_score": log score
 #'
-#' See [scoringutils::get_metrics()] for more details on the default metrics
-#' used by `scoringutils`.
+#' See [scoringutils::get_metrics.forecast_quantile] for details.
+#'
+#' **Nominal forecasts:** (`output_type == "pmf"` and `output_type_id_order` is `NULL`)
+#'
+#' `r paste("- ", names(scoringutils::get_metrics(scoringutils::example_nominal)), collapse = "\n")`
+#'
+#' (Scoring for ordinal forecasts will be added in the future.)
+#'
+#' See [scoringutils::get_metrics.forecast_nominal] for details.
+#'
+#' **Median forecasts:** (`output_type == "median"`)
+#'
+#' - `ae_point`: absolute error of the point forecast (recommended for the median, see Gneiting (2011))
+#'
+#' See [scoringutils::get_metrics.forecast_point] for details.
+#'
+#' **Mean forecasts:** (`output_type == "mean"`)
+#' - `se_point`: squared error of the point forecast (recommended for the mean, see Gneiting (2011))
 #'
 #' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
 #' # compute WIS and interval coverage rates at 80% and 90% levels based on
@@ -77,7 +78,11 @@
 #' )
 #' head(pmf_scores)
 #'
-#' @return forecast_quantile
+#' @return A data.table with scores
+#'
+#' @references
+#' Gneiting, Tilmann (2011). Making and Evaluating Point Forecasts.
+#' Journal of the American Statistical Association.
#' #' @export score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, @@ -87,9 +92,6 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, # also, retrieve that output_type output_type <- validate_output_type(model_out_tbl) - # get/validate the scoring metrics - metrics <- get_metrics(metrics, output_type, output_type_id_order) - # assemble data for scoringutils su_data <- switch(output_type, quantile = transform_quantile_model_out(model_out_tbl, target_observations), @@ -99,16 +101,15 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, NULL # default, should not happen because of the validation above ) + # get/validate the scoring metrics + metrics <- get_metrics(forecast = su_data, output_type = output_type, select = metrics) + # compute scores scores <- scoringutils::score(su_data, metrics) # switch back to hubverse naming conventions for model name scores <- dplyr::rename(scores, model_id = "model") - # if present, drop predicted and observed columns - drop_cols <- c("predicted", "observed") - scores <- scores[!colnames(scores) %in% drop_cols] - # if requested, summarize scores if (summarize) { scores <- scoringutils::summarize_scores(scores = scores, by = by) @@ -118,138 +119,56 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, } -#' Get metrics if user didn't specify anything; otherwise, process -#' and validate user inputs -#' -#' @inheritParams score_model_out -#' -#' @return a list of metric functions as required by scoringutils::score() -#' -#' @noRd -get_metrics <- function(metrics, output_type, output_type_id_order) { - if (is.null(metrics)) { - return(get_metrics_default(output_type, output_type_id_order)) - } else if (is.character(metrics)) { - return(get_metrics_character(metrics, output_type)) - } else { - cli::cli_abort( - "{.arg metrics} must be either `NULL` or a character vector of supported metrics." - ) - } -} - - -#' Default metrics if user didn't specify anything +#' Get scoring metrics #' +#' @param forecast A scoringutils `forecast` object (see +#' [scoringutils::as_forecast()] for details). #' @inheritParams score_model_out #' #' @return a list of metric functions as required by scoringutils::score() #' #' @noRd -get_metrics_default <- function(output_type, output_type_id_order) { - metrics <- switch(output_type, - quantile = scoringutils::get_metrics(scoringutils::example_quantile), - pmf = scoringutils::get_metrics(scoringutils::example_nominal), - mean = scoringutils::get_metrics(scoringutils::example_point, select = "se_point"), - median = scoringutils::get_metrics(scoringutils::example_point, select = "ae_point"), - NULL # default - ) - if (is.null(metrics)) { - # we have already validated `output_type`, so this case should not be - # triggered; this case is just double checking in case we add something new - # later, to ensure we update this function. 
- supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter - cli::cli_abort( - "Provided `model_out_tbl` contains `output_type` {.val {output_type}}; - hubEvals currently only supports the following types: - {.val {supported_types}}" - ) - } - - return(metrics) -} +get_metrics <- function(forecast, output_type, select = NULL) { + forecast_type <- class(forecast)[1] - -#' Convert character vector of metrics to list of functions -#' -#' @inheritParams score_model_out -#' -#' @return a list of metric functions as required by scoringutils::score() -#' -#' @noRd -get_metrics_character <- function(metrics, output_type) { - if (output_type == "quantile") { + # process quantile metrics separately to allow better selection of interval + # coverage metrics + if (forecast_type == "forecast_quantile") { # split into metrics for interval coverage and others - interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics) - interval_metrics <- metrics[interval_metric_inds] - other_metrics <- metrics[!interval_metric_inds] + interval_metric_inds <- grepl(pattern = "interval_coverage_", select) + interval_metrics <- select[interval_metric_inds] + other_metrics <- select[!interval_metric_inds] - # validate metrics - valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion") - invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics] - error_if_invalid_metrics( - valid_metrics = c(valid_metrics, "interval_coverage_XY"), - invalid_metrics = invalid_metrics, - output_type = output_type, - comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.") - ) + other_metric_fns <- scoringutils::get_metrics(forecast, select = other_metrics) - # assemble metric functions + # assemble interval coverage functions interval_metric_fns <- lapply( interval_metrics, function(metric) { - level <- as.integer(substr(metric, 19, 20)) + level_str <- substr(metric, 19, nchar(metric)) + level <- suppressWarnings(as.numeric(level_str)) + if (is.na(level) || level <= 0 || level >= 100) { + stop(paste( + "Invalid interval coverage level:", level_str, + "- must be a number between 0 and 100 (exclusive)" + )) + } return(purrr::partial(scoringutils::interval_coverage, interval_range = level)) } ) names(interval_metric_fns) <- interval_metrics - other_metric_fns <- scoringutils::get_metrics( - scoringutils::example_quantile, - select = other_metrics - ) - - metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics] - metrics <- metric_fns - } else if (output_type == "pmf") { - valid_metrics <- c("log_score") - invalid_metrics <- metrics[!metrics %in% valid_metrics] - error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type) - - metrics <- scoringutils::get_metrics( - scoringutils::example_nominal, - select = metrics - ) - } else if (output_type %in% c("median", "mean")) { - valid_metrics <- c("ae_point", "se_point") - invalid_metrics <- metrics[!metrics %in% valid_metrics] - error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type) - - metrics <- scoringutils::get_metrics( - scoringutils::example_point, - select = metrics - ) - } else { - # we have already validated `output_type`, so this case should not be - # triggered; this case is just double checking in case we add something new - # later, to ensure we update this function. 
- error_if_invalid_output_type(output_type) + metric_fns <- c(other_metric_fns, interval_metric_fns) + return(metric_fns) } - return(metrics) -} - - -error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) { - n <- length(invalid_metrics) - if (n > 0) { - cli::cli_abort( - c( - "`metrics` had {n} unsupported metric{?s} for - {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}}; - supported metrics include {.val {valid_metrics}}.", - comment - ) - ) + # leave validation of user selection to scoringutils + metric_fns <- scoringutils::get_metrics(forecast, select = select) + if (output_type == "mean") { + metric_fns <- scoringutils::select_metrics(metric_fns, "se_point") + } else if (output_type == "median") { + metric_fns <- scoringutils::select_metrics(metric_fns, "ae_point") } + + return(metric_fns) } diff --git a/man/score_model_out.Rd b/man/score_model_out.Rd index 64b1209..9af9861 100644 --- a/man/score_model_out.Rd +++ b/man/score_model_out.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/score_model_out.R \name{score_model_out} \alias{score_model_out} -\title{Score model output predictions with a single \code{output_type} against observed data} +\title{Score model output predictions} \usage{ score_model_out( model_out_tbl, @@ -19,7 +19,9 @@ score_model_out( \item{target_observations}{Observed 'ground truth' data to be compared to predictions} -\item{metrics}{Optional character vector of scoring metrics to compute. See details for more.} +\item{metrics}{Character vector of scoring metrics to compute. If \code{NULL} +(the default), appropriate metrics are chosen automatically. See details +for more.} \item{summarize}{Boolean indicator of whether summaries of forecast scores should be computed. Defaults to \code{TRUE}.} @@ -33,57 +35,55 @@ vector of levels for pmf forecasts, in increasing order of the levels. For all other output types, this is ignored.} } \value{ -forecast_quantile +A data.table with scores } \description{ -Score model output predictions with a single \code{output_type} against observed data +Scores model outputs with a single \code{output_type} against observed data. } \details{ -If \code{metrics} is \code{NULL} (the default), this function chooses -appropriate metrics based on the \code{output_type} contained in the \code{model_out_tbl}: +Default metrics are provided by the \code{scoringutils} package. You can select +metrics by passing in a character vector of metric names to the \code{metrics} +argument. -\itemize{ -\item For \code{output_type == "quantile"}, we use the default metrics provided by -\code{scoringutils}: -wis, overprediction, underprediction, dispersion, bias, interval_coverage_50, interval_coverage_90, interval_coverage_deviation, ae_median -\item For \code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL} (indicating -that the predicted variable is a nominal variable), we use the default metric -provided by \code{scoringutils}: -log_score -\item For \code{output_type == "median"}, we use "ae_point" -\item For \code{output_type == "mean"}, we use "se_point" -} +The following metrics can be selected (all are used by default) for the +different \code{output_type}s: -Alternatively, a character vector of scoring metrics can be provided. 
In this
-case, the following options are supported:
-\itemize{
-\item \code{output_type == "median"} and \code{output_type == "mean"}:
-\itemize{
-\item "ae_point": absolute error of a point prediction (generally recommended for the median)
-\item "se_point": squared error of a point prediction (generally recommended for the mean)
-}
-\item \code{output_type == "quantile"}:
+\strong{Quantile forecasts:} (\code{output_type == "quantile"})
 \itemize{
-\item "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
-\item "wis": weighted interval score (WIS) of a collection of quantile predictions
-\item "overprediction": The component of WIS measuring the extent to which
-predictions fell above the observation.
-\item "underprediction": The component of WIS measuring the extent to which
-predictions fell below the observation.
-\item "dispersion": The component of WIS measuring the dispersion of forecast
-distributions.
+\item wis
+\item overprediction
+\item underprediction
+\item dispersion
+\item bias
+\item interval_coverage_deviation
+\item ae_median
 \item "interval_coverage_XX": interval coverage at the "XX" level. For example,
 "interval_coverage_95" is the 95\% interval coverage rate, which would be calculated
 based on quantiles at the probability levels 0.025 and 0.975.
 }
-\item \code{output_type == "pmf"}:
+
+See \link[scoringutils:get_metrics.forecast_quantile]{scoringutils::get_metrics.forecast_quantile} for details.
+
+\strong{Nominal forecasts:} (\code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL})
 \itemize{
-\item "log_score": log score
+\item log_score
 }
+
+(Scoring for ordinal forecasts will be added in the future.)
+
+See \link[scoringutils:get_metrics.forecast_nominal]{scoringutils::get_metrics.forecast_nominal} for details.
+
+\strong{Median forecasts:} (\code{output_type == "median"})
+\itemize{
+\item \code{ae_point}: absolute error of the point forecast (recommended for the median, see Gneiting (2011))
 }
-See \code{\link[scoringutils:get_metrics]{scoringutils::get_metrics()}} for more details on the default metrics
-used by \code{scoringutils}.
+See \link[scoringutils:get_metrics.forecast_point]{scoringutils::get_metrics.forecast_point} for details.
+
+\strong{Mean forecasts:} (\code{output_type == "mean"})
+\itemize{
+\item \code{se_point}: squared error of the point forecast (recommended for the mean, see Gneiting (2011))
+}
 }
 \examples{
 \dontshow{if (requireNamespace("hubExamples", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}
@@ -113,3 +113,7 @@ pmf_scores <- score_model_out(
 head(pmf_scores)
 \dontshow{\}) # examplesIf}
 }
+\references{
+Gneiting, Tilmann (2011). Making and Evaluating Point Forecasts.
+Journal of the American Statistical Association.
+}
diff --git a/tests/testthat/test-score_model_out.R b/tests/testthat/test-score_model_out.R
index 6d02eeb..6437827 100644
--- a/tests/testthat/test-score_model_out.R
+++ b/tests/testthat/test-score_model_out.R
@@ -107,7 +107,6 @@ test_that("score_model_out succeeds with valid inputs: mean output_type, charact
       c("model_id", "location")
     ))) |>
     dplyr::summarize(
-      ae_point = mean(.data[["ae"]]),
       se_point = mean(.data[["se"]]),
       .groups = "drop"
     )
@@ -380,7 +379,7 @@ test_that("score_model_out errors when model_out_tbl has multiple output_types",
 })


-test_that("score_model_out errors when invalid interval levels are requested", {
+test_that("score_model_out handles all kinds of interval coverage levels", {
   # Forecast data from HubExamples:
   load(test_path("testdata/forecast_outputs.rda")) # sets forecast_outputs
   load(test_path("testdata/forecast_target_observations.rda")) # sets forecast_target_observations
@@ -389,27 +388,36 @@ test_that("score_model_out errors when invalid interval levels are requested", {
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
       target_observations = forecast_target_observations,
-      metrics = "interval_level_5"
+      metrics = "interval_coverage_5d2a"
     ),
-    regexp = "unsupported metric"
+    regexp = "must be a number between 0 and 100"
   )

-  expect_error(
+  expect_warning(
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
       target_observations = forecast_target_observations,
-      metrics = "interval_level_100"
+      metrics = "interval_coverage_55"
     ),
-    regexp = "unsupported metric"
+    "To compute the interval coverage for an interval range of" # scoringutils warning
   )

   expect_error(
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
       target_observations = forecast_target_observations,
-      metrics = "interval_level_XY"
+      metrics = "interval_coverage_100"
+    ),
+    regexp = "must be a number between 0 and 100"
+  )
+
+  expect_warning(
+    score_model_out(
+      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
+      target_observations = forecast_target_observations,
+      metrics = "interval_coverage_5.3"
     ),
-    regexp = "unsupported metric"
+    "To compute the interval coverage for an interval range of" # scoringutils warning
   )
 })

@@ -425,7 +433,7 @@ test_that("score_model_out errors when invalid metrics are requested", {
   expect_error(
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "pmf"),
       target_observations = forecast_target_observations,
       metrics = "log_score"
     ),
-    regexp = "unsupported metric"
+    regexp = "has additional elements"
   )

   expect_error(
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
       target_observations = forecast_target_observations,
       metrics = list(5, 6, "asdf")
     ),
-    regexp = "`metrics` must be either `NULL` or a character vector of supported metrics."
+    regexp =
+      "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$"
   )

   expect_error(
     score_model_out(
       model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"),
       target_observations = forecast_target_observations,
       metrics = scoringutils::get_metrics(scoringutils::example_point),
       by = c("model_id", "location")
     ),
-    regexp = "`metrics` must be either `NULL` or a character vector of supported metrics."
+    regexp =
+      "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$"
   )
 })
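
Note for reviewers (not part of the patch): the new interval-coverage handling in `get_metrics()` is easy to exercise on its own. The sketch below mirrors the parsing logic added in R/score_model_out.R above; the helper name `make_interval_coverage_fn` is illustrative only, and it assumes the scoringutils version targeted by this PR (which provides `interval_coverage()` and the `get_metrics()` generics) plus purrr are installed.

```r
# Sketch only: turn an "interval_coverage_XX" metric name into a metric
# function, mirroring what get_metrics() does for quantile forecasts.
make_interval_coverage_fn <- function(metric) {
  # "interval_coverage_" is 18 characters, so the level starts at position 19
  level_str <- substr(metric, 19, nchar(metric))
  level <- suppressWarnings(as.numeric(level_str))
  if (is.na(level) || level <= 0 || level >= 100) {
    stop("Invalid interval coverage level: ", level_str)
  }
  # fix the interval_range argument so the resulting function can be passed
  # to scoringutils::score() alongside the other metric functions
  purrr::partial(scoringutils::interval_coverage, interval_range = level)
}

# e.g. "interval_coverage_80" -> interval_coverage() with interval_range = 80
fn80 <- make_interval_coverage_fn("interval_coverage_80")
```

In package use the same effect should be reached by simply passing `metrics = "interval_coverage_80"` to `score_model_out()`, as shown in the roxygen examples above.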