Simplify score_model_out
nikosbosse committed Sep 16, 2024
1 parent 0d12262 commit fbd7cba
Showing 6 changed files with 144 additions and 208 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
@@ -15,3 +15,4 @@
^\.Rdata$
^\.httr-oauth$
^\.secrets$
^.vscode
2 changes: 2 additions & 0 deletions .gitignore
@@ -56,3 +56,5 @@ docs
.Rdata
.secrets
.quarto

.vscode/
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -51,7 +51,7 @@ Remotes:
hubverse-org/hubExamples,
hubverse-org/hubUtils
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
URL: https://hubverse-org.github.io/hubEvals/
Depends:
R (>= 2.10)
233 changes: 76 additions & 157 deletions R/score_model_out.R
@@ -1,9 +1,13 @@
#' Score model output predictions with a single `output_type` against observed data
#' Score model output predictions
#'
#' Scores model outputs with a single `output_type` against observed data.
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared to
#' predictions
#' @param metrics Optional character vector of scoring metrics to compute. See details for more.
#' @param metrics Character vector of scoring metrics to compute. If `NULL`
#' (the default), appropriate metrics are chosen automatically. See details
#' for more.
#' @param summarize Boolean indicator of whether summaries of forecast scores
#' should be computed. Defaults to `TRUE`.
#' @param by Character vector naming columns to summarize by. For example,
@@ -13,43 +17,40 @@
#' vector of levels for pmf forecasts, in increasing order of the levels. For
#' all other output types, this is ignored.
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
#'
#' \itemize{
#' \item For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils`:
#' `r names(scoringutils::get_metrics(scoringutils::example_quantile))`
#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metric
#' provided by `scoringutils`:
#' `r names(scoringutils::get_metrics(scoringutils::example_nominal))`
#' \item For `output_type == "median"`, we use "ae_point"
#' \item For `output_type == "mean"`, we use "se_point"
#' }
#'
#' Alternatively, a character vector of scoring metrics can be provided. In this
#' case, the following options are supported:
#' - `output_type == "median"` and `output_type == "mean"`:
#' - "ae_point": absolute error of a point prediction (generally recommended for the median)
#' - "se_point": squared error of a point prediction (generally recommended for the mean)
#' - `output_type == "quantile"`:
#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
#' - "wis": weighted interval score (WIS) of a collection of quantile predictions
#' - "overprediction": The component of WIS measuring the extent to which
#' predictions fell above the observation.
#' - "underprediction": The component of WIS measuring the extent to which
#' predictions fell below the observation.
#' - "dispersion": The component of WIS measuring the dispersion of forecast
#' distributions.
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' @details
#' Default metrics are provided by the `scoringutils` package. You can select
#' specific metrics by passing a character vector of metric names to the
#' `metrics` argument.
#'
#' The following metrics can be selected for each `output_type` (all of them
#' are used by default):
#'
#' **Quantile forecasts:** (`output_type == "quantile"`)
#' `r exclude <- c("interval_coverage_50", "interval_coverage_90")`
#' `r metrics <- scoringutils::get_metrics(scoringutils::example_quantile, exclude = exclude)`
#' `r paste("- ", names(metrics), collapse = "\n")`
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
#' based on quantiles at the probability levels 0.025 and 0.975.
#' - `output_type == "pmf"`:
#' - "log_score": log score
#'
#' See [scoringutils::get_metrics()] for more details on the default metrics
#' used by `scoringutils`.
#' See [scoringutils::get_metrics.forecast_quantile] for details.
#'
#' **Nominal forecasts:** (`output_type == "pmf"` and `output_type_id_order` is `NULL`)
#'
#' `r paste("- ", names(scoringutils::get_metrics(scoringutils::example_nominal)), collapse = "\n")`
#'
#' (Scoring for ordinal forecasts will be added in the future.)
#'
#' See [scoringutils::get_metrics.forecast_nominal] for details.
#'
#' **Median forecasts:** (`output_type == "median"`)
#'
#' - `ae_point`: absolute error of the point forecast (recommended for the median, see Gneiting (2011))
#'
#' See [scoringutils::get_metrics.forecast_point] for details.
#'
#' **Mean forecasts:** (`output_type == "mean"`)
#'
#' - `se_point`: squared error of the point forecast (recommended for the mean, see Gneiting (2011))
#'
#' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
#' # compute WIS and interval coverage rates at 80% and 90% levels based on
@@ -77,7 +78,11 @@
#' )
#' head(pmf_scores)
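#'
#' # Illustrative sketch (not from this commit): selecting a single metric for
#' # point forecasts. `point_model_out` and `point_obs` are hypothetical
#' # hubverse-format objects with `output_type == "mean"`.
#' # mean_scores <- score_model_out(
#' #   model_out_tbl = point_model_out,
#' #   target_observations = point_obs,
#' #   metrics = "se_point"
#' # )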
#'
#' @return forecast_quantile
#' @return A data.table with scores
#'
#' @references
#' Gneiting, T. (2011). Making and Evaluating Point Forecasts.
#' Journal of the American Statistical Association.
#'
#' @export
score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
@@ -87,9 +92,6 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
# also, retrieve that output_type
output_type <- validate_output_type(model_out_tbl)

# get/validate the scoring metrics
metrics <- get_metrics(metrics, output_type, output_type_id_order)

# assemble data for scoringutils
su_data <- switch(output_type,
quantile = transform_quantile_model_out(model_out_tbl, target_observations),
@@ -99,16 +101,15 @@
NULL # default, should not happen because of the validation above
)

# get/validate the scoring metrics
metrics <- get_metrics(forecast = su_data, output_type = output_type, select = metrics)

# compute scores
scores <- scoringutils::score(su_data, metrics)

# switch back to hubverse naming conventions for model name
scores <- dplyr::rename(scores, model_id = "model")

# if present, drop predicted and observed columns
drop_cols <- c("predicted", "observed")
scores <- scores[!colnames(scores) %in% drop_cols]

# if requested, summarize scores
if (summarize) {
scores <- scoringutils::summarize_scores(scores = scores, by = by)
@@ -118,138 +119,56 @@
}


#' Get metrics if user didn't specify anything; otherwise, process
#' and validate user inputs
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics <- function(metrics, output_type, output_type_id_order) {
if (is.null(metrics)) {
return(get_metrics_default(output_type, output_type_id_order))
} else if (is.character(metrics)) {
return(get_metrics_character(metrics, output_type))
} else {
cli::cli_abort(
"{.arg metrics} must be either `NULL` or a character vector of supported metrics."
)
}
}


#' Default metrics if user didn't specify anything
#' Get scoring metrics
#'
#' @param forecast A scoringutils `forecast` object (see
#' [scoringutils::as_forecast()] for details).
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
metrics <- switch(output_type,
quantile = scoringutils::get_metrics(scoringutils::example_quantile),
pmf = scoringutils::get_metrics(scoringutils::example_nominal),
mean = scoringutils::get_metrics(scoringutils::example_point, select = "se_point"),
median = scoringutils::get_metrics(scoringutils::example_point, select = "ae_point"),
NULL # default
)
if (is.null(metrics)) {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}

return(metrics)
}
get_metrics <- function(forecast, output_type, select = NULL) {
forecast_type <- class(forecast)[1]


#' Convert character vector of metrics to list of functions
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_character <- function(metrics, output_type) {
if (output_type == "quantile") {
# process quantile metrics separately to allow better selection of interval
# coverage metrics
if (forecast_type == "forecast_quantile") {
# split into metrics for interval coverage and others
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
other_metrics <- metrics[!interval_metric_inds]
interval_metric_inds <- grepl(pattern = "interval_coverage_", select)
interval_metrics <- select[interval_metric_inds]
other_metrics <- select[!interval_metric_inds]

# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
error_if_invalid_metrics(
valid_metrics = c(valid_metrics, "interval_coverage_XY"),
invalid_metrics = invalid_metrics,
output_type = output_type,
comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.")
)
other_metric_fns <- scoringutils::get_metrics(forecast, select = other_metrics)

# assemble metric functions
# assemble interval coverage functions
interval_metric_fns <- lapply(
interval_metrics,
function(metric) {
level <- as.integer(substr(metric, 19, 20))
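# note: the prefix "interval_coverage_" is 18 characters long, so the
# requested coverage level starts at character 19 of the metric name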
level_str <- substr(metric, 19, nchar(metric))
level <- suppressWarnings(as.numeric(level_str))
if (is.na(level) || level <= 0 || level >= 100) {
stop(paste(
"Invalid interval coverage level:", level_str,
"- must be a number between 0 and 100 (exclusive)"
))
}
return(purrr::partial(scoringutils::interval_coverage, interval_range = level))
}
)
names(interval_metric_fns) <- interval_metrics

other_metric_fns <- scoringutils::get_metrics(
scoringutils::example_quantile,
select = other_metrics
)

metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
metrics <- metric_fns
} else if (output_type == "pmf") {
valid_metrics <- c("log_score")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::get_metrics(
scoringutils::example_nominal,
select = metrics
)
} else if (output_type %in% c("median", "mean")) {
valid_metrics <- c("ae_point", "se_point")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::get_metrics(
scoringutils::example_point,
select = metrics
)
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
error_if_invalid_output_type(output_type)
metric_fns <- c(other_metric_fns, interval_metric_fns)
return(metric_fns)
}

return(metrics)
}


error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) {
n <- length(invalid_metrics)
if (n > 0) {
cli::cli_abort(
c(
"`metrics` had {n} unsupported metric{?s} for
{.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
supported metrics include {.val {valid_metrics}}.",
comment
)
)
# leave validation of user selection to scoringutils
metric_fns <- scoringutils::get_metrics(forecast, select = select)
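# for point forecasts, keep only the metric that matches the point summary:
# squared error for the mean, absolute error for the median (Gneiting 2011)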
if (output_type == "mean") {
metric_fns <- scoringutils::select_metrics(metric_fns, "se_point")
} else if (output_type == "median") {
metric_fns <- scoringutils::select_metrics(metric_fns, "ae_point")
}

return(metric_fns)
}
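
A minimal sketch (not part of this commit) of how the reworked internal helper resolves a user selection for quantile forecasts. `get_metrics()` is not exported, so it is accessed with `:::` purely for illustration, and `scoringutils::example_quantile` stands in for a transformed hubverse forecast:

fns <- hubEvals:::get_metrics(
  forecast = scoringutils::example_quantile,  # any forecast_quantile object
  output_type = "quantile",
  select = c("wis", "interval_coverage_80")
)
names(fns)
#> [1] "wis"                  "interval_coverage_80"
# "wis" comes directly from scoringutils::get_metrics(); "interval_coverage_80"
# is a purrr::partial() wrapper around scoringutils::interval_coverage() with
# interval_range = 80.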
