Score model out improvements #54

Merged · 10 commits · Sep 17, 2024
1 change: 1 addition & 0 deletions .Rbuildignore
@@ -15,3 +15,4 @@
^\.Rdata$
^\.httr-oauth$
^\.secrets$
^.vscode
2 changes: 2 additions & 0 deletions .gitignore
@@ -56,3 +56,5 @@ docs
.Rdata
.secrets
.quarto

.vscode/
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -51,7 +51,7 @@ Remotes:
hubverse-org/hubExamples,
hubverse-org/hubUtils
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
URL: https://hubverse-org.github.io/hubEvals/
Depends:
R (>= 2.10)
221 changes: 77 additions & 144 deletions R/score_model_out.R
@@ -1,9 +1,13 @@
#' Score model output predictions with a single `output_type` against observed data
#' Score model output predictions
#'
#' Scores model outputs with a single `output_type` against observed data.
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared to
#' predictions
#' @param metrics Optional character vector of scoring metrics to compute. See details for more.
#' @param metrics Character vector of scoring metrics to compute. If `NULL`
#' (the default), appropriate metrics are chosen automatically. See details
#' for more.
#' @param summarize Boolean indicator of whether summaries of forecast scores
#' should be computed. Defaults to `TRUE`.
#' @param by Character vector naming columns to summarize by. For example,
@@ -13,38 +17,40 @@
#' vector of levels for pmf forecasts, in increasing order of the levels. For
#' all other output types, this is ignored.
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
#' \itemize{
#' \item For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils::metrics_quantile()`: `r names(scoringutils::metrics_quantile())`
#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metric
#' provided by `scoringutils::metrics_nominal()`,
#' `r names(scoringutils::metrics_nominal())`
#' \item For `output_type == "median"`, we use "ae_point"
#' \item For `output_type == "mean"`, we use "se_point"
#' }
#'
#' Alternatively, a character vector of scoring metrics can be provided. In this
#' case, the following options are supported:
#' - `output_type == "median"` and `output_type == "mean"`:
#' - "ae_point": absolute error of a point prediction (generally recommended for the median)
#' - "se_point": squared error of a point prediction (generally recommended for the mean)
#' - `output_type == "quantile"`:
#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
#' - "wis": weighted interval score (WIS) of a collection of quantile predictions
#' - "overprediction": The component of WIS measuring the extent to which
#' predictions fell above the observation.
#' - "underprediction": The component of WIS measuring the extent to which
#' predictions fell below the observation.
#' - "dispersion": The component of WIS measuring the dispersion of forecast
#' distributions.
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' @details
#' Default metrics are provided by the `scoringutils` package. You can select
#' metrics by passing in a character vector of metric names to the `metrics`
#' argument.
#'
#' The following metrics can be selected (all are used by default) for the
#' different `output_type`s:
#'
#' **Quantile forecasts:** (`output_type == "quantile"`)
#' `r exclude <- c("interval_coverage_50", "interval_coverage_90")`
#' `r metrics <- scoringutils::get_metrics(scoringutils::example_quantile, exclude = exclude)`
#' `r paste("- ", names(metrics), collapse = "\n")`
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
#' based on quantiles at the probability levels 0.025 and 0.975.
#' - `output_type == "pmf"`:
#' - "log_score": log score
#'
#' See [scoringutils::get_metrics.forecast_quantile] for details.
#'
#' **Nominal forecasts:** (`output_type == "pmf"` and `output_type_id_order` is `NULL`)
#'
#' `r paste("- ", names(scoringutils::get_metrics(scoringutils::example_nominal)), collapse = "\n")`
#'
#' (scoring for ordinal forecasts will be added in the future).
#'
#' See [scoringutils::get_metrics.forecast_nominal] for details.
#'
#' **Median forecasts:** (`output_type == "median"`)
#'
#' - `ae_point`: absolute error of the point forecast (recommended for the median, see Gneiting (2011))
#'
#' See [scoringutils::get_metrics.forecast_point] for details.
#'
#' **Mean forecasts:** (`output_type == "mean"`)
#' - `se_point`: squared error of the point forecast (recommended for the mean, see Gneiting (2011))
#'
#' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
#' # compute WIS and interval coverage rates at 80% and 90% levels based on
@@ -72,7 +78,11 @@
#' )
#' head(pmf_scores)
#'
#' @return forecast_quantile
#' @return A data.table with scores
#'
Review comment from the Collaborator (PR author):

Technically it's an object of class `scores` which is itself a `data.table`, but probably easiest not to confuse users with scoringutils stuff.

Reply from a Member:

This is a much better descriptor. Thank you for catching it, especially since `data.table` can catch people off guard: it doesn't behave like other R objects (i.e., modifying a data.table modifies the object in place rather than making a copy).
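A minimal sketch of the `data.table` reference semantics being discussed, with illustrative objects only (assuming the `data.table` package is available):

```r
# Sketch: data.table modifies columns by reference (in place), unlike base
# data.frames or tibbles, which copy on modification.
library(data.table)

dt <- data.table(model_id = c("A", "B"), wis = c(1.2, 0.8))
dt2 <- dt                      # no copy is made; dt2 points at the same table
dt2[, wis_rank := rank(wis)]   # := adds the column by reference

names(dt)                      # "model_id" "wis" "wis_rank" -- dt gained the column too
df <- as.data.frame(dt)        # converting restores ordinary copy-on-modify behavior
```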

#' @references
#' Gneiting, Tilmann. 2011. "Making and Evaluating Point Forecasts." Journal of the
#' American Statistical Association 106 (494): 746–62. <doi: 10.1198/jasa.2011.r10138>.
#'
#' @export
score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
@@ -82,9 +92,6 @@ score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
# also, retrieve that output_type
output_type <- validate_output_type(model_out_tbl)

# get/validate the scoring metrics
metrics <- get_metrics(metrics, output_type, output_type_id_order)

# assemble data for scoringutils
su_data <- switch(output_type,
quantile = transform_quantile_model_out(model_out_tbl, target_observations),
@@ -94,16 +101,15 @@
NULL # default, should not happen because of the validation above
)

# get/validate the scoring metrics
metrics <- get_metrics(forecast = su_data, output_type = output_type, select = metrics)

# compute scores
scores <- scoringutils::score(su_data, metrics)

# switch back to hubverse naming conventions for model name
scores <- dplyr::rename(scores, model_id = "model")

# if present, drop predicted and observed columns
drop_cols <- c("predicted", "observed")
scores <- scores[!colnames(scores) %in% drop_cols]

# if requested, summarize scores
if (summarize) {
scores <- scoringutils::summarize_scores(scores = scores, by = by)
@@ -113,129 +119,56 @@
}
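A usage sketch of the `metrics` argument documented above: `quantile_outputs` and `target_obs` are hypothetical placeholders for a quantile model output tibble and its observed data (they are not objects defined in this PR), and the metric names follow the quantile options listed in the details section.

```r
# Hypothetical inputs: `quantile_outputs` has output_type == "quantile";
# `target_obs` holds the observed 'ground truth' data.
library(hubEvals)

quantile_scores <- score_model_out(
  model_out_tbl = quantile_outputs,
  target_observations = target_obs,
  metrics = c("wis", "ae_median", "interval_coverage_80", "interval_coverage_95"),
  by = "model_id"
)
head(quantile_scores)
```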


#' Get metrics if user didn't specify anything; otherwise, process
#' and validate user inputs
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics <- function(metrics, output_type, output_type_id_order) {
if (is.null(metrics)) {
return(get_metrics_default(output_type, output_type_id_order))
} else if (is.character(metrics)) {
return(get_metrics_character(metrics, output_type))
} else {
cli::cli_abort(
"{.arg metrics} must be either `NULL` or a character vector of supported metrics."
)
}
}


#' Default metrics if user didn't specify anything
#' Get scoring metrics
#'
#' @param forecast A scoringutils `forecast` object (see
#' [scoringutils::as_forecast()] for details).
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
metrics <- switch(output_type,
quantile = scoringutils::metrics_quantile(),
pmf = scoringutils::metrics_nominal(),
mean = scoringutils::metrics_point(select = "se_point"),
median = scoringutils::metrics_point(select = "ae_point"),
NULL # default
)
if (is.null(metrics)) {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}
get_metrics <- function(forecast, output_type, select = NULL) {
forecast_type <- class(forecast)[1]

return(metrics)
}


#' Convert character vector of metrics to list of functions
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_character <- function(metrics, output_type) {
if (output_type == "quantile") {
# process quantile metrics separately to allow better selection of interval
# coverage metrics
if (forecast_type == "forecast_quantile") {
# split into metrics for interval coverage and others
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
other_metrics <- metrics[!interval_metric_inds]
interval_metric_inds <- grepl(pattern = "^interval_coverage_", select)
interval_metrics <- select[interval_metric_inds]
other_metrics <- select[!interval_metric_inds]

# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
error_if_invalid_metrics(
valid_metrics = c(valid_metrics, "interval_coverage_XY"),
invalid_metrics = invalid_metrics,
output_type = output_type,
comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.")
)
other_metric_fns <- scoringutils::get_metrics(forecast, select = other_metrics)

# assemble metric functions
# assemble interval coverage functions
interval_metric_fns <- lapply(
interval_metrics,
function(metric) {
level <- as.integer(substr(metric, 19, 20))
level_str <- substr(metric, 19, nchar(metric))
level <- suppressWarnings(as.numeric(level_str))
if (is.na(level) || level <= 0 || level >= 100) {
cli::cli_abort(c(
"Invalid interval coverage level: {level_str}",
"i" = "must be a number between 0 and 100 (exclusive)"
))
}
return(purrr::partial(scoringutils::interval_coverage, interval_range = level))
}
)
names(interval_metric_fns) <- interval_metrics

other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)

metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
metrics <- metric_fns
} else if (output_type == "pmf") {
valid_metrics <- c("log_score")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_nominal(select = metrics)
} else if (output_type %in% c("median", "mean")) {
valid_metrics <- c("ae_point", "se_point")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_point(select = metrics)
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
error_if_invalid_output_type(output_type)
metric_fns <- c(other_metric_fns, interval_metric_fns)
return(metric_fns)
}

return(metrics)
}


error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) {
n <- length(invalid_metrics)
if (n > 0) {
cli::cli_abort(
c(
"`metrics` had {n} unsupported metric{?s} for
{.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
supported metrics include {.val {valid_metrics}}.",
comment
)
)
# leave validation of user selection to scoringutils
metric_fns <- scoringutils::get_metrics(forecast, select = select)
if (output_type == "mean") {
metric_fns <- scoringutils::select_metrics(metric_fns, "se_point")
} else if (output_type == "median") {
metric_fns <- scoringutils::select_metrics(metric_fns, "ae_point")
}

return(metric_fns)
}
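The interval-coverage branch above builds per-level metric functions by partially applying `scoringutils::interval_coverage()`. A standalone sketch of that pattern, using "interval_coverage_90" purely as an illustration:

```r
# Sketch of the purrr::partial() pattern used in get_metrics(): fix the
# interval_range argument so the result can be passed to scoringutils::score()
# alongside the metrics returned by scoringutils::get_metrics().
library(purrr)
library(scoringutils)

metric_name <- "interval_coverage_90"
level <- as.numeric(sub("^interval_coverage_", "", metric_name))  # 90

coverage_90 <- purrr::partial(scoringutils::interval_coverage, interval_range = level)
```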