From 1dcec1d81ac2268302421c82a120b8fffe330b3e Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Mon, 16 Sep 2024 18:04:52 +0200 Subject: [PATCH 01/10] Fix code after changes in scoringutils --- R/score_model_out.R | 46 ++++++++++++------- tests/testthat/test-score_model_out.R | 3 +- .../testthat/test-transform_point_model_out.R | 9 +++- .../test-transform_quantile_model_out.R | 7 ++- 4 files changed, 43 insertions(+), 22 deletions(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index a2c8525..1b35f16 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -15,16 +15,16 @@ #' #' @details If `metrics` is `NULL` (the default), this function chooses #' appropriate metrics based on the `output_type` contained in the `model_out_tbl`: -#' \itemize{ -#' \item For `output_type == "quantile"`, we use the default metrics provided by -#' `scoringutils::metrics_quantile()`: `r names(scoringutils::metrics_quantile())` -#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating +#' +#' - For `output_type == "quantile"`, we use the default metrics provided by +#' `scoringutils`: +#' `r names(scoringutils::get_metrics(scoringutils::example_quantile))` +#' - For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating #' that the predicted variable is a nominal variable), we use the default metric -#' provided by `scoringutils::metrics_nominal()`, -#' `r names(scoringutils::metrics_nominal())` -#' \item For `output_type == "median"`, we use "ae_point" -#' \item For `output_type == "mean"`, we use "se_point" -#' } +#' provided by `scoringutils`:, +#' `r names(scoringutils::get_metrics(scoringutils::example_nominal))` +#' - For `output_type == "median"`, we use "ae_point" +#' - For `output_type == "mean"`, we use "se_point" #' #' Alternatively, a character vector of scoring metrics can be provided. In this #' case, the following options are supported: @@ -46,6 +46,9 @@ #' - `output_type == "pmf"`: #' - "log_score": log score #' +#' See [scoringutils::get_metrics()] for more details on the default meterics +#' used by `scoringutils`. 
+#' #' @examplesIf requireNamespace("hubExamples", quietly = TRUE) #' # compute WIS and interval coverage rates at 80% and 90% levels based on #' # quantile forecasts, summarized by the mean score for each model @@ -143,10 +146,10 @@ get_metrics <- function(metrics, output_type, output_type_id_order) { #' @noRd get_metrics_default <- function(output_type, output_type_id_order) { metrics <- switch(output_type, - quantile = scoringutils::metrics_quantile(), - pmf = scoringutils::metrics_nominal(), - mean = scoringutils::metrics_point(select = "se_point"), - median = scoringutils::metrics_point(select = "ae_point"), + quantile = scoringutils::get_metrics(scoringutils::example_quantile), + pmf = scoringutils::get_metrics(scoringutils::example_nominal), + mean = scoringutils::get_metrics(scoringutils::example_point, select = "se_point"), + median = scoringutils::get_metrics(scoringutils::example_point, select = "ae_point"), NULL # default ) if (is.null(metrics)) { @@ -199,7 +202,10 @@ get_metrics_character <- function(metrics, output_type) { ) names(interval_metric_fns) <- interval_metrics - other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics) + other_metric_fns <- scoringutils::get_metrics( + scoringutils::example_quantile, + select = other_metrics + ) metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics] metrics <- metric_fns @@ -208,13 +214,19 @@ get_metrics_character <- function(metrics, output_type) { invalid_metrics <- metrics[!metrics %in% valid_metrics] error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type) - metrics <- scoringutils::metrics_nominal(select = metrics) + metrics <- scoringutils::get_metrics( + scoringutils::example_nominal, + select = metrics + ) } else if (output_type %in% c("median", "mean")) { valid_metrics <- c("ae_point", "se_point") invalid_metrics <- metrics[!metrics %in% valid_metrics] error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type) - metrics <- scoringutils::metrics_point(select = metrics) + metrics <- scoringutils::get_metrics( + scoringutils::example_point, + select = metrics + ) } else { # we have already validated `output_type`, so this case should not be # triggered; this case is just double checking in case we add something new @@ -231,7 +243,7 @@ error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type if (n > 0) { cli::cli_abort( c( - "`metrics` had {n} unsupported metric{?s} for + "`metrics` had {n} unsupported metric{?s} for {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}}; supported metrics include {.val {valid_metrics}}.", comment diff --git a/tests/testthat/test-score_model_out.R b/tests/testthat/test-score_model_out.R index 8987ee9..8dc8efe 100644 --- a/tests/testthat/test-score_model_out.R +++ b/tests/testthat/test-score_model_out.R @@ -441,7 +441,7 @@ test_that("score_model_out errors when invalid metrics are requested", { score_model_out( model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"), target_observations = forecast_target_observations, - metrics = scoringutils::metrics_point(), + metrics = scoringutils::get_metrics(scoringutils::example_point), by = c("model_id", "location") ), regexp = "`metrics` must be either `NULL` or a character vector of supported metrics." 
@@ -464,3 +464,4 @@ test_that("score_model_out errors when an unsupported output_type is provided", regexp = "only supports the following types" ) }) + diff --git a/tests/testthat/test-transform_point_model_out.R b/tests/testthat/test-transform_point_model_out.R index 503c448..b55bc4e 100644 --- a/tests/testthat/test-transform_point_model_out.R +++ b/tests/testthat/test-transform_point_model_out.R @@ -140,6 +140,11 @@ test_that("hubExamples data set is transformed correctly", { reference_date = as.Date(reference_date, "%Y-%m-%d"), target_end_date = as.Date(target_end_date, "%Y-%m-%d") ) - class(exp_forecast) <- c("forecast_point", "forecast", "data.table", "data.frame") - expect_equal(act_forecast, exp_forecast) + expect_s3_class( + act_forecast, + c("forecast_point", "forecast", "data.table", "data.frame") + ) + expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast)) }) + + diff --git a/tests/testthat/test-transform_quantile_model_out.R b/tests/testthat/test-transform_quantile_model_out.R index bdc01a0..0e9f14b 100644 --- a/tests/testthat/test-transform_quantile_model_out.R +++ b/tests/testthat/test-transform_quantile_model_out.R @@ -89,6 +89,9 @@ test_that("hubExamples data set is transformed correctly", { reference_date = as.Date(reference_date, "%Y-%m-%d"), target_end_date = as.Date(target_end_date, "%Y-%m-%d") ) - class(exp_forecast) <- c("forecast", "forecast_quantile", "data.table", "data.frame") - expect_equal(act_forecast, exp_forecast, ignore_attr = "class") + expect_s3_class( + act_forecast, + c("forecast_quantile", "forecast", "data.table", "data.frame") + ) + expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast)) }) From 937187aeb24bf00a1ec6791aab7daa0a4bc0106e Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Mon, 16 Sep 2024 18:09:36 +0200 Subject: [PATCH 02/10] Update docs --- man/score_model_out.Rd | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/man/score_model_out.Rd b/man/score_model_out.Rd index 8c42a92..94fd2d0 100644 --- a/man/score_model_out.Rd +++ b/man/score_model_out.Rd @@ -43,14 +43,17 @@ If \code{metrics} is \code{NULL} (the default), this function chooses appropriate metrics based on the \code{output_type} contained in the \code{model_out_tbl}: \itemize{ \item For \code{output_type == "quantile"}, we use the default metrics provided by -\code{scoringutils::metrics_quantile()}: wis, overprediction, underprediction, dispersion, bias, interval_coverage_50, interval_coverage_90, interval_coverage_deviation, ae_median +\code{scoringutils}: +\verb{r names(scoringutils::get_metrics(scoringutils::example_quantile))} \item For \code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL} (indicating that the predicted variable is a nominal variable), we use the default metric -provided by \code{scoringutils::metrics_nominal()}, -log_score +provided by \code{scoringutils}:, +\verb{r names(scoringutils::get_metrics(scoringutils::example_nominal))} +\itemize{ \item For \code{output_type == "median"}, we use "ae_point" \item For \code{output_type == "mean"}, we use "se_point" } +} Alternatively, a character vector of scoring metrics can be provided. In this case, the following options are supported: @@ -79,6 +82,9 @@ based on quantiles at the probability levels 0.025 and 0.975. \item "log_score": log score } } + +See \code{\link[scoringutils:get_metrics]{scoringutils::get_metrics()}} for more details on the default meterics +used by \code{scoringutils}. 
} \examples{ \dontshow{if (requireNamespace("hubExamples", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} From 31993ef756e42aa9fac2a45344736341f2df5079 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Mon, 16 Sep 2024 18:12:43 +0200 Subject: [PATCH 03/10] fix linting issues --- tests/testthat/test-score_model_out.R | 1 - tests/testthat/test-transform_point_model_out.R | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/testthat/test-score_model_out.R b/tests/testthat/test-score_model_out.R index 8dc8efe..6d02eeb 100644 --- a/tests/testthat/test-score_model_out.R +++ b/tests/testthat/test-score_model_out.R @@ -464,4 +464,3 @@ test_that("score_model_out errors when an unsupported output_type is provided", regexp = "only supports the following types" ) }) - diff --git a/tests/testthat/test-transform_point_model_out.R b/tests/testthat/test-transform_point_model_out.R index b55bc4e..1bc6aa2 100644 --- a/tests/testthat/test-transform_point_model_out.R +++ b/tests/testthat/test-transform_point_model_out.R @@ -146,5 +146,3 @@ test_that("hubExamples data set is transformed correctly", { ) expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast)) }) - - From 0d12262957b1c3d3657c1919a3addda42cb3dcf5 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Mon, 16 Sep 2024 18:30:38 +0200 Subject: [PATCH 04/10] fix docs --- R/score_model_out.R | 14 ++++++++------ man/score_model_out.Rd | 11 +++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index 1b35f16..0fd3a29 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -16,15 +16,17 @@ #' @details If `metrics` is `NULL` (the default), this function chooses #' appropriate metrics based on the `output_type` contained in the `model_out_tbl`: #' -#' - For `output_type == "quantile"`, we use the default metrics provided by +#' \itemize{ +#' \item For `output_type == "quantile"`, we use the default metrics provided by #' `scoringutils`: #' `r names(scoringutils::get_metrics(scoringutils::example_quantile))` -#' - For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating +#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating #' that the predicted variable is a nominal variable), we use the default metric -#' provided by `scoringutils`:, +#' provided by `scoringutils`: #' `r names(scoringutils::get_metrics(scoringutils::example_nominal))` -#' - For `output_type == "median"`, we use "ae_point" -#' - For `output_type == "mean"`, we use "se_point" +#' \item For `output_type == "median"`, we use "ae_point" +#' \item For `output_type == "mean"`, we use "se_point" +#' } #' #' Alternatively, a character vector of scoring metrics can be provided. In this #' case, the following options are supported: @@ -46,7 +48,7 @@ #' - `output_type == "pmf"`: #' - "log_score": log score #' -#' See [scoringutils::get_metrics()] for more details on the default meterics +#' See [scoringutils::get_metrics()] for more details on the default metrics #' used by `scoringutils`. 
#' #' @examplesIf requireNamespace("hubExamples", quietly = TRUE) diff --git a/man/score_model_out.Rd b/man/score_model_out.Rd index 94fd2d0..64b1209 100644 --- a/man/score_model_out.Rd +++ b/man/score_model_out.Rd @@ -41,19 +41,18 @@ Score model output predictions with a single \code{output_type} against observed \details{ If \code{metrics} is \code{NULL} (the default), this function chooses appropriate metrics based on the \code{output_type} contained in the \code{model_out_tbl}: + \itemize{ \item For \code{output_type == "quantile"}, we use the default metrics provided by \code{scoringutils}: -\verb{r names(scoringutils::get_metrics(scoringutils::example_quantile))} +wis, overprediction, underprediction, dispersion, bias, interval_coverage_50, interval_coverage_90, interval_coverage_deviation, ae_median \item For \code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL} (indicating that the predicted variable is a nominal variable), we use the default metric -provided by \code{scoringutils}:, -\verb{r names(scoringutils::get_metrics(scoringutils::example_nominal))} -\itemize{ +provided by \code{scoringutils}: +log_score \item For \code{output_type == "median"}, we use "ae_point" \item For \code{output_type == "mean"}, we use "se_point" } -} Alternatively, a character vector of scoring metrics can be provided. In this case, the following options are supported: @@ -83,7 +82,7 @@ based on quantiles at the probability levels 0.025 and 0.975. } } -See \code{\link[scoringutils:get_metrics]{scoringutils::get_metrics()}} for more details on the default meterics +See \code{\link[scoringutils:get_metrics]{scoringutils::get_metrics()}} for more details on the default metrics used by \code{scoringutils}. } \examples{ From fbd7cbaf1e2a50b77f382580ce6f25c4b1d8d68e Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 17 Sep 2024 00:04:15 +0200 Subject: [PATCH 05/10] Simplify score_model_out --- .Rbuildignore | 1 + .gitignore | 2 + DESCRIPTION | 2 +- R/score_model_out.R | 233 +++++++++----------------- man/score_model_out.Rd | 80 ++++----- tests/testthat/test-score_model_out.R | 34 ++-- 6 files changed, 144 insertions(+), 208 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index e57306d..8ca8d33 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -15,3 +15,4 @@ ^\.Rdata$ ^\.httr-oauth$ ^\.secrets$ +^.vscode diff --git a/.gitignore b/.gitignore index 440e2e6..27440c7 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ docs .Rdata .secrets .quarto + +.vscode/ diff --git a/DESCRIPTION b/DESCRIPTION index c711112..a832dd0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -51,7 +51,7 @@ Remotes: hubverse-org/hubExamples, hubverse-org/hubUtils Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 URL: https://hubverse-org.github.io/hubEvals/ Depends: R (>= 2.10) diff --git a/R/score_model_out.R b/R/score_model_out.R index 0fd3a29..f74350b 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -1,9 +1,13 @@ -#' Score model output predictions with a single `output_type` against observed data +#' Score model output predictions +#' +#' Scores model outputs with a single `output_type` against observed data. #' #' @param model_out_tbl Model output tibble with predictions #' @param target_observations Observed 'ground truth' data to be compared to #' predictions -#' @param metrics Optional character vector of scoring metrics to compute. See details for more. +#' @param metrics Character vector of scoring metrics to compute. 
If `NULL` +#' (the default), appropriate metrics are chosen automatically. See details +#' for more. #' @param summarize Boolean indicator of whether summaries of forecast scores #' should be computed. Defaults to `TRUE`. #' @param by Character vector naming columns to summarize by. For example, @@ -13,43 +17,40 @@ #' vector of levels for pmf forecasts, in increasing order of the levels. For #' all other output types, this is ignored. #' -#' @details If `metrics` is `NULL` (the default), this function chooses -#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`: -#' -#' \itemize{ -#' \item For `output_type == "quantile"`, we use the default metrics provided by -#' `scoringutils`: -#' `r names(scoringutils::get_metrics(scoringutils::example_quantile))` -#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating -#' that the predicted variable is a nominal variable), we use the default metric -#' provided by `scoringutils`: -#' `r names(scoringutils::get_metrics(scoringutils::example_nominal))` -#' \item For `output_type == "median"`, we use "ae_point" -#' \item For `output_type == "mean"`, we use "se_point" -#' } -#' -#' Alternatively, a character vector of scoring metrics can be provided. In this -#' case, the following options are supported: -#' - `output_type == "median"` and `output_type == "mean"`: -#' - "ae_point": absolute error of a point prediction (generally recommended for the median) -#' - "se_point": squared error of a point prediction (generally recommended for the mean) -#' - `output_type == "quantile"`: -#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5) -#' - "wis": weighted interval score (WIS) of a collection of quantile predictions -#' - "overprediction": The component of WIS measuring the extent to which -#' predictions fell above the observation. -#' - "underprediction": The component of WIS measuring the extent to which -#' predictions fell below the observation. -#' - "dispersion": The component of WIS measuring the dispersion of forecast -#' distributions. -#' - "interval_coverage_XX": interval coverage at the "XX" level. For example, +#' @details +#' Default metrics are provided by the `scoringutils` package. You can select +#' metrics by passing in a character vector of metric names to the `metrics` +#' argument. +#' +#' The following metrics can be selected (all are used by default) for the +#' different `output_type`s: +#' +#' **Quantile forecasts:** (`output_type == "quantile"`) +#' `r exclude <- c("interval_coverage_50", "interval_coverage_90")` +#' `r metrics <- scoringutils::get_metrics(scoringutils::example_quantile, exclude = exclude)` +#' `r paste("- ", names(metrics), collapse = "\n")` +#' - "interval_coverage_XX": interval coverage at the "XX" level. For example, #' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated #' based on quantiles at the probability levels 0.025 and 0.975. -#' - `output_type == "pmf"`: -#' - "log_score": log score #' -#' See [scoringutils::get_metrics()] for more details on the default metrics -#' used by `scoringutils`. +#' See [scoringutils::get_metrics.forecast_quantile] for details. +#' +#' **Nominal forecasts:** (`output_type == "pmf"` and `output_type_id_order` is `NULL`) +#' +#' `r paste("- ", names(scoringutils::get_metrics(scoringutils::example_nominal)), collapse = "\n")` +#' +#' (scoring for ordinal forecasts will be added in the future). 
+#' +#' See [scoringutils::get_metrics.forecast_nominal] for details. +#' +#' **Median forecasts:** (`output_type == "median"`) +#' +#' - ae_point: absolute error of the point forecast (recommended for the median, see Gneiting (2011)) +#' +#' See [scoringutils::get_metrics.forecast_point] for details. +#' +#' **Mean forecasts:** (`output_type == "mean"`) +#' - `se_point`: squared error of the point forecast (recommended for the mean, see Gneiting (2011)) #' #' @examplesIf requireNamespace("hubExamples", quietly = TRUE) #' # compute WIS and interval coverage rates at 80% and 90% levels based on @@ -77,7 +78,11 @@ #' ) #' head(pmf_scores) #' -#' @return forecast_quantile +#' @return A data.table with scores +#' +#' @references +#' #' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, +#' Journal of the American Statistical Association. #' #' @export score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, @@ -87,9 +92,6 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, # also, retrieve that output_type output_type <- validate_output_type(model_out_tbl) - # get/validate the scoring metrics - metrics <- get_metrics(metrics, output_type, output_type_id_order) - # assemble data for scoringutils su_data <- switch(output_type, quantile = transform_quantile_model_out(model_out_tbl, target_observations), @@ -99,16 +101,15 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, NULL # default, should not happen because of the validation above ) + # get/validate the scoring metrics + metrics <- get_metrics(forecast = su_data, output_type = output_type, select = metrics) + # compute scores scores <- scoringutils::score(su_data, metrics) # switch back to hubverse naming conventions for model name scores <- dplyr::rename(scores, model_id = "model") - # if present, drop predicted and observed columns - drop_cols <- c("predicted", "observed") - scores <- scores[!colnames(scores) %in% drop_cols] - # if requested, summarize scores if (summarize) { scores <- scoringutils::summarize_scores(scores = scores, by = by) @@ -118,138 +119,56 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, } -#' Get metrics if user didn't specify anything; otherwise, process -#' and validate user inputs -#' -#' @inheritParams score_model_out -#' -#' @return a list of metric functions as required by scoringutils::score() -#' -#' @noRd -get_metrics <- function(metrics, output_type, output_type_id_order) { - if (is.null(metrics)) { - return(get_metrics_default(output_type, output_type_id_order)) - } else if (is.character(metrics)) { - return(get_metrics_character(metrics, output_type)) - } else { - cli::cli_abort( - "{.arg metrics} must be either `NULL` or a character vector of supported metrics." - ) - } -} - - -#' Default metrics if user didn't specify anything +#' Get scoring metrics #' +#' @param forecast A scoringutils `forecast` object (see +#' [scoringutils::as_forecast()] for details). 
#' @inheritParams score_model_out #' #' @return a list of metric functions as required by scoringutils::score() #' #' @noRd -get_metrics_default <- function(output_type, output_type_id_order) { - metrics <- switch(output_type, - quantile = scoringutils::get_metrics(scoringutils::example_quantile), - pmf = scoringutils::get_metrics(scoringutils::example_nominal), - mean = scoringutils::get_metrics(scoringutils::example_point, select = "se_point"), - median = scoringutils::get_metrics(scoringutils::example_point, select = "ae_point"), - NULL # default - ) - if (is.null(metrics)) { - # we have already validated `output_type`, so this case should not be - # triggered; this case is just double checking in case we add something new - # later, to ensure we update this function. - supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter - cli::cli_abort( - "Provided `model_out_tbl` contains `output_type` {.val {output_type}}; - hubEvals currently only supports the following types: - {.val {supported_types}}" - ) - } - - return(metrics) -} +get_metrics <- function(forecast, output_type, select = NULL) { + forecast_type <- class(forecast)[1] - -#' Convert character vector of metrics to list of functions -#' -#' @inheritParams score_model_out -#' -#' @return a list of metric functions as required by scoringutils::score() -#' -#' @noRd -get_metrics_character <- function(metrics, output_type) { - if (output_type == "quantile") { + # process quantile metrics separately to allow better selection of interval + # coverage metrics + if (forecast_type == "forecast_quantile") { # split into metrics for interval coverage and others - interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics) - interval_metrics <- metrics[interval_metric_inds] - other_metrics <- metrics[!interval_metric_inds] + interval_metric_inds <- grepl(pattern = "interval_coverage_", select) + interval_metrics <- select[interval_metric_inds] + other_metrics <- select[!interval_metric_inds] - # validate metrics - valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion") - invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics] - error_if_invalid_metrics( - valid_metrics = c(valid_metrics, "interval_coverage_XY"), - invalid_metrics = invalid_metrics, - output_type = output_type, - comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. 
{.val interval_coverage_95}.") - ) + other_metric_fns <- scoringutils::get_metrics(forecast, select = other_metrics) - # assemble metric functions + # assemble interval coverage functions interval_metric_fns <- lapply( interval_metrics, function(metric) { - level <- as.integer(substr(metric, 19, 20)) + level_str <- substr(metric, 19, nchar(metric)) + level <- suppressWarnings(as.numeric(level_str)) + if (is.na(level) || level <= 0 || level >= 100) { + stop(paste( + "Invalid interval coverage level:", level_str, + "- must be a number between 0 and 100 (exclusive)" + )) + } return(purrr::partial(scoringutils::interval_coverage, interval_range = level)) } ) names(interval_metric_fns) <- interval_metrics - other_metric_fns <- scoringutils::get_metrics( - scoringutils::example_quantile, - select = other_metrics - ) - - metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics] - metrics <- metric_fns - } else if (output_type == "pmf") { - valid_metrics <- c("log_score") - invalid_metrics <- metrics[!metrics %in% valid_metrics] - error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type) - - metrics <- scoringutils::get_metrics( - scoringutils::example_nominal, - select = metrics - ) - } else if (output_type %in% c("median", "mean")) { - valid_metrics <- c("ae_point", "se_point") - invalid_metrics <- metrics[!metrics %in% valid_metrics] - error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type) - - metrics <- scoringutils::get_metrics( - scoringutils::example_point, - select = metrics - ) - } else { - # we have already validated `output_type`, so this case should not be - # triggered; this case is just double checking in case we add something new - # later, to ensure we update this function. - error_if_invalid_output_type(output_type) + metric_fns <- c(other_metric_fns, interval_metric_fns) + return(metric_fns) } - return(metrics) -} - - -error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) { - n <- length(invalid_metrics) - if (n > 0) { - cli::cli_abort( - c( - "`metrics` had {n} unsupported metric{?s} for - {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}}; - supported metrics include {.val {valid_metrics}}.", - comment - ) - ) + # leave validation of user selection to scoringutils + metric_fns <- scoringutils::get_metrics(forecast, select = select) + if (output_type == "mean") { + metric_fns <- scoringutils::select_metrics(metric_fns, "se_point") + } else if (output_type == "median") { + metric_fns <- scoringutils::select_metrics(metric_fns, "ae_point") } + + return(metric_fns) } diff --git a/man/score_model_out.Rd b/man/score_model_out.Rd index 64b1209..9af9861 100644 --- a/man/score_model_out.Rd +++ b/man/score_model_out.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/score_model_out.R \name{score_model_out} \alias{score_model_out} -\title{Score model output predictions with a single \code{output_type} against observed data} +\title{Score model output predictions} \usage{ score_model_out( model_out_tbl, @@ -19,7 +19,9 @@ score_model_out( \item{target_observations}{Observed 'ground truth' data to be compared to predictions} -\item{metrics}{Optional character vector of scoring metrics to compute. See details for more.} +\item{metrics}{Character vector of scoring metrics to compute. If \code{NULL} +(the default), appropriate metrics are chosen automatically. See details +for more.} \item{summarize}{Boolean indicator of whether summaries of forecast scores should be computed. 
Defaults to \code{TRUE}.} @@ -33,57 +35,55 @@ vector of levels for pmf forecasts, in increasing order of the levels. For all other output types, this is ignored.} } \value{ -forecast_quantile +A data.table with scores } \description{ -Score model output predictions with a single \code{output_type} against observed data +Scores model outputs with a single \code{output_type} against observed data. } \details{ -If \code{metrics} is \code{NULL} (the default), this function chooses -appropriate metrics based on the \code{output_type} contained in the \code{model_out_tbl}: +Default metrics are provided by the \code{scoringutils} package. You can select +metrics by passing in a character vector of metric names to the \code{metrics} +argument. -\itemize{ -\item For \code{output_type == "quantile"}, we use the default metrics provided by -\code{scoringutils}: -wis, overprediction, underprediction, dispersion, bias, interval_coverage_50, interval_coverage_90, interval_coverage_deviation, ae_median -\item For \code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL} (indicating -that the predicted variable is a nominal variable), we use the default metric -provided by \code{scoringutils}: -log_score -\item For \code{output_type == "median"}, we use "ae_point" -\item For \code{output_type == "mean"}, we use "se_point" -} +The following metrics can be selected (all are used by default) for the +different \code{output_type}s: -Alternatively, a character vector of scoring metrics can be provided. In this -case, the following options are supported: -\itemize{ -\item \code{output_type == "median"} and \code{output_type == "mean"}: -\itemize{ -\item "ae_point": absolute error of a point prediction (generally recommended for the median) -\item "se_point": squared error of a point prediction (generally recommended for the mean) -} -\item \code{output_type == "quantile"}: +\strong{Quantile forecasts:} (\code{output_type == "quantile"}) \itemize{ -\item "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5) -\item "wis": weighted interval score (WIS) of a collection of quantile predictions -\item "overprediction": The component of WIS measuring the extent to which -predictions fell above the observation. -\item "underprediction": The component of WIS measuring the extent to which -predictions fell below the observation. -\item "dispersion": The component of WIS measuring the dispersion of forecast -distributions. +\item wis +\item overprediction +\item underprediction +\item dispersion +\item bias +\item interval_coverage_deviation +\item ae_median \item "interval_coverage_XX": interval coverage at the "XX" level. For example, "interval_coverage_95" is the 95\% interval coverage rate, which would be calculated based on quantiles at the probability levels 0.025 and 0.975. } -\item \code{output_type == "pmf"}: + +See \link[scoringutils:get_metrics.forecast_quantile]{scoringutils::get_metrics.forecast_quantile} for details. + +\strong{Nominal forecasts:} (\code{output_type == "pmf"} and \code{output_type_id_order} is \code{NULL}) \itemize{ -\item "log_score": log score +\item log_score } + +(scoring for ordinal forecasts will be added in the future). + +See \link[scoringutils:get_metrics.forecast_nominal]{scoringutils::get_metrics.forecast_nominal} for details. 
+ +\strong{Median forecasts:} (\code{output_type == "median"}) +\itemize{ +\item ae_point: absolute error of the point forecast (recommended for the median, see Gneiting (2011)) } -See \code{\link[scoringutils:get_metrics]{scoringutils::get_metrics()}} for more details on the default metrics -used by \code{scoringutils}. +See \link[scoringutils:get_metrics.forecast_point]{scoringutils::get_metrics.forecast_point} for details. + +\strong{Mean forecasts:} (\code{output_type == "mean"}) +\itemize{ +\item \code{se_point}: squared error of the point forecast (recommended for the mean, see Gneiting (2011)) +} } \examples{ \dontshow{if (requireNamespace("hubExamples", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} @@ -113,3 +113,7 @@ pmf_scores <- score_model_out( head(pmf_scores) \dontshow{\}) # examplesIf} } +\references{ +#' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, +Journal of the American Statistical Association. +} diff --git a/tests/testthat/test-score_model_out.R b/tests/testthat/test-score_model_out.R index 6d02eeb..6437827 100644 --- a/tests/testthat/test-score_model_out.R +++ b/tests/testthat/test-score_model_out.R @@ -107,7 +107,6 @@ test_that("score_model_out succeeds with valid inputs: mean output_type, charact c("model_id", "location") ))) |> dplyr::summarize( - ae_point = mean(.data[["ae"]]), se_point = mean(.data[["se"]]), .groups = "drop" ) @@ -380,7 +379,7 @@ test_that("score_model_out errors when model_out_tbl has multiple output_types", }) -test_that("score_model_out errors when invalid interval levels are requested", { +test_that("score_model_out works with all kinds of interval levels are requested", { # Forecast data from HubExamples: load(test_path("testdata/forecast_outputs.rda")) # sets forecast_outputs load(test_path("testdata/forecast_target_observations.rda")) # sets forecast_target_observations @@ -389,27 +388,36 @@ test_that("score_model_out errors when invalid interval levels are requested", { score_model_out( model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), target_observations = forecast_target_observations, - metrics = "interval_level_5" + metrics = "interval_coverage_5d2a" ), - regexp = "unsupported metric" + regexp = "must be a number between 0 and 100" ) - expect_error( + expect_warning( score_model_out( model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), target_observations = forecast_target_observations, - metrics = "interval_level_100" + metrics = "interval_coverage_55" ), - regexp = "unsupported metric" + "To compute the interval coverage for an interval range of" #scoringutils warning ) expect_error( score_model_out( model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), target_observations = forecast_target_observations, - metrics = "interval_level_XY" + metrics = "interval_coverage_100" + ), + regexp = "must be a number between 0 and 100" + ) + + expect_warning( + score_model_out( + model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), + target_observations = forecast_target_observations, + metrics = "interval_coverage_5.3" ), - regexp = "unsupported metric" + "To compute the interval coverage for an interval range of" #scoringutils warning ) }) @@ -425,7 +433,7 @@ test_that("score_model_out errors when invalid metrics are requested", { target_observations = forecast_target_observations, metrics = "log_score" ), - regexp = "unsupported metric" + 
regexp = "has additional elements" ) expect_error( @@ -434,7 +442,8 @@ test_that("score_model_out errors when invalid metrics are requested", { target_observations = forecast_target_observations, metrics = list(5, 6, "asdf") ), - regexp = "`metrics` must be either `NULL` or a character vector of supported metrics." + regexp = + "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$" ) expect_error( @@ -444,7 +453,8 @@ test_that("score_model_out errors when invalid metrics are requested", { metrics = scoringutils::get_metrics(scoringutils::example_point), by = c("model_id", "location") ), - regexp = "`metrics` must be either `NULL` or a character vector of supported metrics." + regexp = + "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$" ) }) From 823ddc9f5e9c89e350090db2aa1584eb96cae8d1 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 17 Sep 2024 00:11:47 +0200 Subject: [PATCH 06/10] update docs --- R/score_model_out.R | 2 +- man/score_model_out.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index f74350b..40205ec 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -81,7 +81,7 @@ #' @return A data.table with scores #' #' @references -#' #' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, +#' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, #' Journal of the American Statistical Association. #' #' @export diff --git a/man/score_model_out.Rd b/man/score_model_out.Rd index 9af9861..70b1f60 100644 --- a/man/score_model_out.Rd +++ b/man/score_model_out.Rd @@ -114,6 +114,6 @@ head(pmf_scores) \dontshow{\}) # examplesIf} } \references{ -#' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, +Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, Journal of the American Statistical Association. 
} From 19c6692c0e743b4474acc682745e995f6492d553 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 17 Sep 2024 13:13:57 -0400 Subject: [PATCH 07/10] regexp cleanup; match start of string for interval_coverage_XX --- R/score_model_out.R | 2 +- tests/testthat/test-score_model_out.R | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index 40205ec..32f8c86 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -135,7 +135,7 @@ get_metrics <- function(forecast, output_type, select = NULL) { # coverage metrics if (forecast_type == "forecast_quantile") { # split into metrics for interval coverage and others - interval_metric_inds <- grepl(pattern = "interval_coverage_", select) + interval_metric_inds <- grepl(pattern = "^interval_coverage_", select) interval_metrics <- select[interval_metric_inds] other_metrics <- select[!interval_metric_inds] diff --git a/tests/testthat/test-score_model_out.R b/tests/testthat/test-score_model_out.R index 6437827..1c497f7 100644 --- a/tests/testthat/test-score_model_out.R +++ b/tests/testthat/test-score_model_out.R @@ -446,6 +446,16 @@ test_that("score_model_out errors when invalid metrics are requested", { "^Assertion on 'c\\(select, exclude\\)' failed: Must be of type 'character' \\(or 'NULL'\\), not 'list'\\.$" ) + expect_error( + score_model_out( + model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "quantile"), + target_observations = forecast_target_observations, + metrics = c("asdfinterval_coverage_90") + ), + regexp = + "has additional elements" + ) + expect_error( score_model_out( model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"), From 99336ecc73bc75ce9cf03086890ed9b7a3c36506 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 17 Sep 2024 15:32:58 -0400 Subject: [PATCH 08/10] Update R/score_model_out.R Co-authored-by: Zhian N. Kamvar --- R/score_model_out.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index 32f8c86..d72b305 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -148,9 +148,9 @@ get_metrics <- function(forecast, output_type, select = NULL) { level_str <- substr(metric, 19, nchar(metric)) level <- suppressWarnings(as.numeric(level_str)) if (is.na(level) || level <= 0 || level >= 100) { - stop(paste( - "Invalid interval coverage level:", level_str, - "- must be a number between 0 and 100 (exclusive)" + cli::cli_abort(c( + "Invalid interval coverage level: {level_str}", + "i" = "must be a number between 0 and 100 (exclusive)" )) } return(purrr::partial(scoringutils::interval_coverage, interval_range = level)) From 13afc3af444e2263b2e26a18c26366f303d9eb0a Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 17 Sep 2024 15:33:10 -0400 Subject: [PATCH 09/10] Update R/score_model_out.R Co-authored-by: Zhian N. Kamvar --- R/score_model_out.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index d72b305..d14c6ba 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -81,8 +81,8 @@ #' @return A data.table with scores #' #' @references -#' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, -#' Journal of the American Statistical Association. +#' Gneiting, Tilmann. 2011. "Making and Evaluating Point Forecasts." Journal of the +#' American Statistical Association 106 (494): 746–62. . 
#' #' @export score_model_out <- function(model_out_tbl, target_observations, metrics = NULL, From db560404ffb6bc3c47ecd70717c2700b6938cdc7 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 17 Sep 2024 15:35:45 -0400 Subject: [PATCH 10/10] remove trailing whitespace --- R/score_model_out.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/score_model_out.R b/R/score_model_out.R index d14c6ba..0769ed0 100644 --- a/R/score_model_out.R +++ b/R/score_model_out.R @@ -81,7 +81,7 @@ #' @return A data.table with scores #' #' @references -#' Gneiting, Tilmann. 2011. "Making and Evaluating Point Forecasts." Journal of the +#' Gneiting, Tilmann. 2011. "Making and Evaluating Point Forecasts." Journal of the #' American Statistical Association 106 (494): 746–62. . #' #' @export
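
For reference, a minimal usage sketch of the metric selection introduced by this series, written against the hubExamples data already used in the roxygen examples and tests above. This is a sketch, not part of the patches: it assumes the hubExamples package is installed and that its forecast_outputs / forecast_target_observations objects and the "location" column match those used in the tests.

    # Sketch: score quantile forecasts with a character vector of metric names.
    # "interval_coverage_XX" names are parsed by get_metrics() into
    # scoringutils::interval_coverage(interval_range = XX) calls, where XX must
    # lie strictly between 0 and 100 (exclusive).
    library(dplyr)

    quantile_scores <- hubEvals::score_model_out(
      model_out_tbl = hubExamples::forecast_outputs |>
        filter(output_type == "quantile"),
      target_observations = hubExamples::forecast_target_observations,
      metrics = c("wis", "ae_median", "interval_coverage_80", "interval_coverage_90"),
      summarize = TRUE,                 # mean score per group (scoringutils default)
      by = c("model_id", "location")    # one summarized row per model and location
    )
    head(quantile_scores)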