Score model out #46

Merged: 35 commits, Sep 12, 2024
Changes shown are from 29 commits.

Commits (35)
1f51495 initial draft of score_model_out (elray1, Aug 20, 2024)
08d0edd Merge branch 'transform_pmf' into score_model_out (elray1, Aug 22, 2024)
e101808 some partial progress on score_model_out (elray1, Aug 22, 2024)
d438a7a Merge branch 'transform_pmf' into score_model_out (elray1, Aug 22, 2024)
448e830 updates to score_model_out (elray1, Aug 22, 2024)
b92891a appease the linter (elray1, Aug 22, 2024)
d6a6b36 update package imports (elray1, Aug 22, 2024)
c38d738 Merge branch 'transform_pmf' into score_model_out (elray1, Aug 22, 2024)
28724e4 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
41c8c50 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
fd06ad7 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
fa64a8e Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
7e3ce73 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
5b8eb40 Update R/score_model_out.R (elray1, Aug 26, 2024)
a68fabb Update R/score_model_out.R (elray1, Aug 26, 2024)
a6a9c3b Update R/score_model_out.R (elray1, Aug 26, 2024)
7c8e86d Update R/score_model_out.R (elray1, Aug 26, 2024)
47f29a2 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
fed85ac Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
4c6eabc Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
6c443e6 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
778be41 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
f27b164 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
2367518 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
2c740d7 Update tests/testthat/test-score_model_out.R (elray1, Aug 26, 2024)
79a4828 updates from review: document lack of checks for statistical validity… (elray1, Aug 26, 2024)
6e09adc fix up comment spacing (elray1, Aug 26, 2024)
5d45a2a refactor metric validation messaging (elray1, Aug 26, 2024)
ba70640 add examples for score_model_out (elray1, Aug 26, 2024)
7df1803 Update R/score_model_out.R (elray1, Aug 28, 2024)
a9c2795 Update R/score_model_out.R (elray1, Aug 28, 2024)
18ea11a updates to score_model_out docs (elray1, Aug 29, 2024)
285b4d9 Merge branch 'score_model_out' of https://github.com/Infectious-Disea… (elray1, Aug 29, 2024)
587c922 do not support lists of functions in score_model_out (elray1, Sep 11, 2024)
78f7295 update docs (elray1, Sep 11, 2024)
6 changes: 5 additions & 1 deletion DESCRIPTION
@@ -38,14 +38,17 @@ Description:
basic tools for scoring hubverse forecasts.
License: MIT + file LICENSE
Encoding: UTF-8
Imports:
checkmate,
cli,
dplyr,
hubUtils,
purrr,
rlang,
scoringutils (>= 1.2.2.9000)
Remotes:
epiforecasts/scoringutils,
hubverse-org/hubExamples,
hubverse-org/hubUtils
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
@@ -55,5 +58,6 @@ Depends:
LazyData: true
Config/Needs/website: hubverse-org/hubStyle
Suggests:
hubExamples,
testthat (>= 3.0.0)
Config/testthat/edition: 3
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,3 +1,4 @@
# Generated by roxygen2: do not edit by hand

export(score_model_out)
importFrom(rlang,.data)
247 changes: 247 additions & 0 deletions R/score_model_out.R
@@ -0,0 +1,247 @@
#' Score model output predictions with a single `output_type` against observed data
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared to
#' predictions
#' @param metrics Optional list of scoring metrics to compute. See details for more.
#' @param summarize boolean indicator of whether summaries of forecast scores
#' should be computed. Defaults to `TRUE`.
#' @param by character vector naming columns to summarize by. For example,
#' specifying `by = "model_id"` (the default) will compute average scores for
#' each model.
#' @param output_type_id_order For ordinal variables in pmf format, a vector
#' giving the levels of the predicted variable in increasing order. Ignored
#' for all other output types.
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
#' - For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils::metrics_quantile()`: "wis", "overprediction", "underprediction",
#' "dispersion", "bias", "interval_coverage_50", "interval_coverage_90",
#' "interval_coverage_deviation", and "ae_median"
#' - For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metrics
#' provided by `scoringutils::metrics_nominal()`, currently just "log_score"
#' - For `output_type == "median"`, we use "ae_point"
#' - For `output_type == "mean"`, we use "se_point"
#'
#' Alternatively, a character vector of scoring metrics can be provided. In this
#' case, the following options are supported:
#' - `output_type == "mean"` and `output_type == "median"`:
#' - "ae_point": absolute error of a point prediction (generally recommended for the median)
#' - "se_point": squared error of a point prediction (generally recommended for the mean)
#' - `output_type == "quantile"`:
#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
#' - "wis": weighted interval score (WIS) of a collection of quantile predictions
#' - "overprediction": The component of WIS measuring the extent to which
#' predictions fell above the observation.
#' - "underprediction": The component of WIS measuring the extent to which
#' predictions fell below the observation.
#' - "dispersion": The component of WIS measuring the dispersion of forecast
#' distributions.
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
#' based on quantiles at the probability levels 0.025 and 0.975.
#' - `output_type == "pmf"`:
#' - "log_score": log score
#'
#' For more flexibility, it is also possible to directly provide a list of
#' functions to compute the desired metrics, e.g. as would be created by one of
#' the `scoringutils::metrics_*` methods. Note that in this case, `hubEvals`
#' only validates that a list of functions has been provided; no checks for the
#' statistical validity of these metric functions are done.
#'
#' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
#' # compute WIS and interval coverage rates at 80% and 90% levels based on
#' # quantile forecasts, summarized by the mean score for each model
#' quantile_scores <- score_model_out(
#' model_out_tbl = hubExamples::forecast_outputs |>
#' dplyr::filter(.data[["output_type"]] == "quantile"),
#' target_observations = hubExamples::forecast_target_observations,
#' metrics = c("wis", "interval_coverage_80", "interval_coverage_90"),
#' by = c("model_id")
#' )
#' quantile_scores
#'
#' # compute log scores based on pmf predictions for categorical targets,
#' # summarized by the mean score for each combination of model, location, and horizon.
#' # Note: if the model_out_tbl had forecasts for multiple targets using a
#' # pmf output_type with different bins, it would be necessary to score the
#' # predictions for those targets separately.
#' pmf_scores <- score_model_out(
#' model_out_tbl = hubExamples::forecast_outputs |>
#' dplyr::filter(.data[["output_type"]] == "pmf"),
#' target_observations = hubExamples::forecast_target_observations,
#' metrics = "log_score",
#' by = c("model_id", "location", "horizon")
#' )
#' head(pmf_scores)
#'
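#' # A minimal illustrative call (assuming the same hubExamples data as above):
#' # omitting `metrics` falls back to the default metrics for the output_type;
#' # for quantile forecasts these come from scoringutils::metrics_quantile().
#' default_quantile_scores <- score_model_out(
#'   model_out_tbl = hubExamples::forecast_outputs |>
#'     dplyr::filter(.data[["output_type"]] == "quantile"),
#'   target_observations = hubExamples::forecast_target_observations
#' )
#' head(default_quantile_scores)
#'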
#' @return A data frame of forecast scores, with one column per metric. If
#' `summarize = TRUE`, scores are averaged over the combinations of the `by` columns.
#'
#' @export
score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
summarize = TRUE, by = "model_id",
output_type_id_order = NULL) {
# check that model_out_tbl has a single output_type that is supported by this package
# also, retrieve that output_type
output_type <- validate_output_type(model_out_tbl)

# get/validate the scoring metrics
metrics <- get_metrics(metrics, output_type, output_type_id_order)

# assemble data for scoringutils
su_data <- switch(output_type,
[Review comment] @nikosbosse (Collaborator), Sep 9, 2024:
small note - long-term it might be a bit more elegant to implement transform_pmf_model() as an S3 method. Then we could omit the switch and just call transform_model_out(model_out_tbl, target_observations)
(don't think it makes much sense to do that now though)

[Reply] elray1 (Contributor, author):
At minimum, this switch will be refactored into a transform_model_out function per issue #11.
Currently, the hubverse tooling does not have separate S3 classes per output type; mostly, our tools accept data frames containing a mix of output types. But it could be worth thinking about whether there are other places (e.g. plotting?) where functionality is specific to the output type and having this kind of class structure would be helpful.

quantile = transform_quantile_model_out(model_out_tbl, target_observations),
pmf = transform_pmf_model_out(model_out_tbl, target_observations, output_type_id_order),
mean = transform_point_model_out(model_out_tbl, target_observations, output_type),
median = transform_point_model_out(model_out_tbl, target_observations, output_type),
NULL # default, should not happen because of the validation above
)

# compute scores
scores <- scoringutils::score(su_data, metrics)

# switch back to hubverse naming conventions for model name
scores <- dplyr::rename(scores, model_id = "model")

# if present, drop predicted and observed columns
drop_cols <- c("predicted", "observed")
scores <- scores[!colnames(scores) %in% drop_cols]

# if requested, summarize scores
if (summarize) {
scores <- scoringutils::summarize_scores(scores = scores, by = by)
}

return(scores)
}


#' Get metrics if user didn't specify anything; otherwise, process
#' and validate user inputs
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics <- function(metrics, output_type, output_type_id_order) {
if (is.null(metrics)) {
return(get_metrics_default(output_type, output_type_id_order))
} else if (is.character(metrics)) {
return(get_metrics_character(metrics, output_type))
} else {
# We do a minimal preliminary check that the provided `metrics` is a list of
# functions, leaving further checks to scoringutils
checkmate::assert_list(metrics, types = "function")
return(metrics)
}
}


#' Default metrics if user didn't specify anything
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
metrics <- switch(output_type,
[Review comment] (Collaborator):
we'll be changing this in scoringutils by replacing functions like metrics_quantile(), metrics_point() etc. with a single get_metrics() function that is an S3 generic. So if you called get_metrics(forecast_quantile_object) you'd get the quantile metrics. This would simplify this code because one could transform the forecasts into a forecast object first and then rely on S3 to get the correct metrics.

quantile = scoringutils::metrics_quantile(),
pmf = scoringutils::metrics_nominal(),
mean = scoringutils::metrics_point(select = "se_point"),
median = scoringutils::metrics_point(select = "ae_point"),
NULL # default
)
if (is.null(metrics)) {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}

return(metrics)
}


#' Convert character vector of metrics to list of functions
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_character <- function(metrics, output_type) {
if (output_type == "quantile") {
# split into metrics for interval coverage and others
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
other_metrics <- metrics[!interval_metric_inds]

# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
error_if_invalid_metrics(
valid_metrics = c(valid_metrics, "interval_coverage_XY"),
invalid_metrics = invalid_metrics,
output_type = output_type,
comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.")
)

# assemble metric functions
interval_metric_fns <- lapply(
interval_metrics,
function(metric) {
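# the metric name is "interval_coverage_" (18 characters) followed by a
# two-digit level, so characters 19-20 give the level (e.g. "95");
# purrr::partial then fixes that level as the interval_range argument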
level <- as.integer(substr(metric, 19, 20))
return(purrr::partial(scoringutils::interval_coverage, interval_range = level))
}
)
names(interval_metric_fns) <- interval_metrics

other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)

metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
metrics <- metric_fns
} else if (output_type == "pmf") {
valid_metrics <- c("log_score")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_nominal(select = metrics)
} else if (output_type %in% c("median", "mean")) {
valid_metrics <- c("ae_point", "se_point")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_point(select = metrics)
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
error_if_invalid_output_type(output_type)
}

return(metrics)
}


error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type, comment = NULL) {
n <- length(invalid_metrics)
if (n > 0) {
cli::cli_abort(
c(
"`metrics` had {n} unsupported metric{?s} for
{.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
supported metrics include {.val {valid_metrics}}.",
comment
)
)
}
}
33 changes: 33 additions & 0 deletions R/validate.R
@@ -66,3 +66,36 @@ validate_model_out_target_obs <- function(model_out_tbl, target_observations) {

return(model_out_tbl)
}


#' Check that model_out_tbl has a single `output_type` that is one of the
#' output types supported by this package.
#'
#' @return if valid, the output_type in model_out_tbl
#'
#' @noRd
validate_output_type <- function(model_out_tbl) {
output_type <- unique(model_out_tbl$output_type)
if (length(output_type) != 1) {
cli::cli_abort(
"model_out_tbl must contain a single output_type, but it has multiple:
{.val {output_type}}"
)
}

error_if_invalid_output_type(output_type)

return(output_type)
}


error_if_invalid_output_type <- function(output_type) {
supported_types <- c("mean", "median", "pmf", "quantile")
if (!output_type %in% supported_types) {
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}
}