Skip to content

Commit

Permalink
Merge pull request #46 from hubverse-org/score_model_out
Browse files Browse the repository at this point in the history
Score model out
  • Loading branch information
elray1 authored Sep 12, 2024
2 parents 2afa410 + 78f7295 commit a6cb8a8
Show file tree
Hide file tree
Showing 6 changed files with 856 additions and 1 deletion.
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ Description:
basic tools for scoring hubverse forecasts.
License: MIT + file LICENSE
Encoding: UTF-8
Imports:
Imports:
checkmate,
cli,
dplyr,
hubUtils,
purrr,
rlang,
scoringutils (>= 1.2.2.9000)
Remotes:
epiforecasts/scoringutils,
hubverse-org/hubExamples,
hubverse-org/hubUtils
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
Expand All @@ -55,5 +58,6 @@ Depends:
LazyData: true
Config/Needs/website: hubverse-org/hubStyle
Suggests:
hubExamples,
testthat (>= 3.0.0)
Config/testthat/edition: 3
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Generated by roxygen2: do not edit by hand

export(score_model_out)
importFrom(rlang,.data)
241 changes: 241 additions & 0 deletions R/score_model_out.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
#' Score model output predictions with a single `output_type` against observed data
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared to
#' predictions
#' @param metrics Optional character vector of scoring metrics to compute. See details for more.
#' @param summarize Boolean indicator of whether summaries of forecast scores
#' should be computed. Defaults to `TRUE`.
#' @param by Character vector naming columns to summarize by. For example,
#' specifying `by = "model_id"` (the default) will compute average scores for
#' each model.
#' @param output_type_id_order For ordinal variables in pmf format, this is a
#' vector of levels for pmf forecasts, in increasing order of the levels. For
#' all other output types, this is ignored.
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
#' \itemize{
#' \item For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils::metrics_quantile()`: `r names(scoringutils::metrics_quantile())`
#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metric
#' provided by `scoringutils::metrics_nominal()`,
#' `r names(scoringutils::metrics_nominal())`
#' \item For `output_type == "median"`, we use "ae_point"
#' \item For `output_type == "mean"`, we use "se_point"
#' }
#'
#' Alternatively, a character vector of scoring metrics can be provided. In this
#' case, the following options are supported:
#' - `output_type == "median"` and `output_type == "mean"`:
#' - "ae_point": absolute error of a point prediction (generally recommended for the median)
#' - "se_point": squared error of a point prediction (generally recommended for the mean)
#' - `output_type == "quantile"`:
#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
#' - "wis": weighted interval score (WIS) of a collection of quantile predictions
#' - "overprediction": The component of WIS measuring the extent to which
#' predictions fell above the observation.
#' - "underprediction": The component of WIS measuring the extent to which
#' predictions fell below the observation.
#' - "dispersion": The component of WIS measuring the dispersion of forecast
#' distributions.
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
#' based on quantiles at the probability levels 0.025 and 0.975.
#' - `output_type == "pmf"`:
#' - "log_score": log score
#'
#' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
#' # compute WIS and interval coverage rates at 80% and 90% levels based on
#' # quantile forecasts, summarized by the mean score for each model
#' quantile_scores <- score_model_out(
#' model_out_tbl = hubExamples::forecast_outputs |>
#' dplyr::filter(.data[["output_type"]] == "quantile"),
#' target_observations = hubExamples::forecast_target_observations,
#' metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
#' by = "model_id"
#' )
#' quantile_scores
#'
#' # compute log scores based on pmf predictions for categorical targets,
#' # summarized by the mean score for each combination of model and location.
#' # Note: if the model_out_tbl had forecasts for multiple targets using a
#' # pmf output_type with different bins, it would be necessary to score the
#' # predictions for those targets separately.
#' pmf_scores <- score_model_out(
#' model_out_tbl = hubExamples::forecast_outputs |>
#' dplyr::filter(.data[["output_type"]] == "pmf"),
#' target_observations = hubExamples::forecast_target_observations,
#' metrics = "log_score",
#' by = c("model_id", "location", "horizon")
#' )
#' head(pmf_scores)
#'
#' @return forecast_quantile
#'
#' @export
score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
summarize = TRUE, by = "model_id",
output_type_id_order = NULL) {
# check that model_out_tbl has a single output_type that is supported by this package
# also, retrieve that output_type
output_type <- validate_output_type(model_out_tbl)

# get/validate the scoring metrics
metrics <- get_metrics(metrics, output_type, output_type_id_order)

# assemble data for scoringutils
su_data <- switch(output_type,
quantile = transform_quantile_model_out(model_out_tbl, target_observations),
pmf = transform_pmf_model_out(model_out_tbl, target_observations, output_type_id_order),
mean = transform_point_model_out(model_out_tbl, target_observations, output_type),
median = transform_point_model_out(model_out_tbl, target_observations, output_type),
NULL # default, should not happen because of the validation above
)

# compute scores
scores <- scoringutils::score(su_data, metrics)

# switch back to hubverse naming conventions for model name
scores <- dplyr::rename(scores, model_id = "model")

# if present, drop predicted and observed columns
drop_cols <- c("predicted", "observed")
scores <- scores[!colnames(scores) %in% drop_cols]

# if requested, summarize scores
if (summarize) {
scores <- scoringutils::summarize_scores(scores = scores, by = by)
}

return(scores)
}


#' Get metrics if user didn't specify anything; otherwise, process
#' and validate user inputs
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics <- function(metrics, output_type, output_type_id_order) {
if (is.null(metrics)) {
return(get_metrics_default(output_type, output_type_id_order))
} else if (is.character(metrics)) {
return(get_metrics_character(metrics, output_type))
} else {
cli::cli_abort(
"{.arg metrics} must be either `NULL` or a character vector of supported metrics."
)
}
}


#' Default metrics if user didn't specify anything
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
metrics <- switch(output_type,
quantile = scoringutils::metrics_quantile(),
pmf = scoringutils::metrics_nominal(),
mean = scoringutils::metrics_point(select = "se_point"),
median = scoringutils::metrics_point(select = "ae_point"),
NULL # default
)
if (is.null(metrics)) {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}

return(metrics)
}


#' Convert character vector of metrics to list of functions
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_character <- function(metrics, output_type) {
if (output_type == "quantile") {
# split into metrics for interval coverage and others
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
other_metrics <- metrics[!interval_metric_inds]

# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
error_if_invalid_metrics(
valid_metrics = c(valid_metrics, "interval_coverage_XY"),
invalid_metrics = invalid_metrics,
output_type = output_type,
comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.")
)

# assemble metric functions
interval_metric_fns <- lapply(
interval_metrics,
function(metric) {
level <- as.integer(substr(metric, 19, 20))
return(purrr::partial(scoringutils::interval_coverage, interval_range = level))
}
)
names(interval_metric_fns) <- interval_metrics

other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)

metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
metrics <- metric_fns
} else if (output_type == "pmf") {
valid_metrics <- c("log_score")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_nominal(select = metrics)
} else if (output_type %in% c("median", "mean")) {
valid_metrics <- c("ae_point", "se_point")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_point(select = metrics)
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
error_if_invalid_output_type(output_type)
}

return(metrics)
}


error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) {
n <- length(invalid_metrics)
if (n > 0) {
cli::cli_abort(
c(
"`metrics` had {n} unsupported metric{?s} for
{.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
supported metrics include {.val {valid_metrics}}.",
comment
)
)
}
}
33 changes: 33 additions & 0 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,36 @@ validate_model_out_target_obs <- function(model_out_tbl, target_observations) {

return(model_out_tbl)
}


#' Check that model_out_tble has a single `output_type` that is one of the
#' `output_types` that is supported by this function.
#'
#' @return if valid, the output_type in model_out_tbl
#'
#' @noRd
validate_output_type <- function(model_out_tbl) {
output_type <- unique(model_out_tbl$output_type)
if (length(output_type) != 1) {
cli::cli_abort(
"model_out_tbl must contain a single output_type, but it has multiple:
{.val {output_type}}"
)
}

error_if_invalid_output_type(output_type)

return(output_type)
}


error_if_invalid_output_type <- function(output_type) {
supported_types <- c("mean", "median", "pmf", "quantile")
if (!output_type %in% supported_types) {
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}
}
Loading

0 comments on commit a6cb8a8

Please sign in to comment.