Score model out #46

Merged
merged 35 commits into from
Sep 12, 2024
Changes from 8 commits
35 commits
1f51495
initial draft of score_model_out
elray1 Aug 20, 2024
08d0edd
Merge branch 'transform_pmf' into score_model_out
elray1 Aug 22, 2024
e101808
some partial progress on score_model_out
elray1 Aug 22, 2024
d438a7a
Merge branch 'transform_pmf' into score_model_out
elray1 Aug 22, 2024
448e830
updates to score_model_out
elray1 Aug 22, 2024
b92891a
appease the linter
elray1 Aug 22, 2024
d6a6b36
update package imports
elray1 Aug 22, 2024
c38d738
Merge branch 'transform_pmf' into score_model_out
elray1 Aug 22, 2024
28724e4
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
41c8c50
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
fd06ad7
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
fa64a8e
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
7e3ce73
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
5b8eb40
Update R/score_model_out.R
elray1 Aug 26, 2024
a68fabb
Update R/score_model_out.R
elray1 Aug 26, 2024
a6a9c3b
Update R/score_model_out.R
elray1 Aug 26, 2024
7c8e86d
Update R/score_model_out.R
elray1 Aug 26, 2024
47f29a2
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
fed85ac
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
4c6eabc
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
6c443e6
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
778be41
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
f27b164
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
2367518
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
2c740d7
Update tests/testthat/test-score_model_out.R
elray1 Aug 26, 2024
79a4828
updates from review: document lack of checks for statistical validity…
elray1 Aug 26, 2024
6e09adc
fix up comment spacing
elray1 Aug 26, 2024
5d45a2a
refactor metric validation messaging
elray1 Aug 26, 2024
ba70640
add examples for score_model_out
elray1 Aug 26, 2024
7df1803
Update R/score_model_out.R
elray1 Aug 28, 2024
a9c2795
Update R/score_model_out.R
elray1 Aug 28, 2024
18ea11a
updates to score_model_out docs
elray1 Aug 29, 2024
285b4d9
Merge branch 'score_model_out' of https://github.com/Infectious-Disea…
elray1 Aug 29, 2024
587c922
do not support lists of functions in score_model_out
elray1 Sep 11, 2024
78f7295
update docs
elray1 Sep 11, 2024
4 changes: 3 additions & 1 deletion DESCRIPTION
@@ -38,10 +38,12 @@ Description:
basic tools for scoring hubverse forecasts.
License: MIT + file LICENSE
Encoding: UTF-8
Imports:
Imports:
checkmate,
cli,
dplyr,
hubUtils,
purrr,
rlang,
scoringutils (>= 1.2.2.9000)
Remotes:
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,3 +1,4 @@
# Generated by roxygen2: do not edit by hand

export(score_model_out)
importFrom(rlang,.data)
229 changes: 229 additions & 0 deletions R/score_model_out.R
@@ -0,0 +1,229 @@
#' Score model output predictions with a single `output_type` against observed data
#'
#' @param model_out_tbl Model output tibble with predictions
#' @param target_observations Observed 'ground truth' data to be compared to
#' predictions
#' @param metrics Optional list of scoring metrics to compute. See details for more.
#' @param summarize boolean indicator of whether summaries of forecast scores
#' should be computed. Defaults to `TRUE`.
#' @param by character vector naming columns to summarize by. For example,
#' specifying `by = "model_id"` (the default) will compute average scores for
#' each model.
#' @param output_type_id_order For ordinal variables in pmf format, this is a
#' vector of levels for pmf forecasts, in increasing order of the levels. For
#' all other output types, this is ignored.
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
#' - For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils::metrics_quantile()`: "wis", "overprediction", "underprediction",
#' "dispersion", "bias", "interval_coverage_50", "interval_coverage_90",
#' "interval_coverage_deviation", and "ae_median"
#' - For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metrics
#' provided by `scoringutils::metrics_nominal()`, currently just "log_score"
#' - For `output_type == "median"`, we use "ae_point"
#' - For `output_type == "mean"`, we use "se_point"
#'
#' It is also possible to directly provide a list of metric functions, e.g. as would be
#' created by one of those function calls. Alternatively, a character vector of
#' scoring metrics can be provided. In this case, the following options are supported:
#' - `output_type == "median"` and `output_type == "median"`:
#' - "ae": absolute error of a point prediction (generally recommended for the median)
#' - "se": squared error of a point prediction (generally recommended for the mean)
#' - `output_type == "quantile"`:
#' - "ae_median": absolute error of the predictive median (i.e., the quantile at probability level 0.5)
#' - "wis": weighted interval score (WIS) of a collection of quantile predictions
#' - "overprediction": The component of WIS measuring the extent to which
#' predictions fell above the observation.
#' - "underprediction": The component of WIS measuring the extent to which
#' predictions fell below the observation.
#' - "dispersion": The component of WIS measuring the dispersion of forecast
#' distributions.
#' - "interval_coverage_XX": interval coverage at the "XX" level. For example,
#' "interval_coverage_95" is the 95% interval coverage rate, which would be calculated
#' based on quantiles at the probability levels 0.025 and 0.975.
#' - `output_type == "pmf"`:
#' - "log_score": log score
#'
#' @return A data frame of scores, as computed by `scoringutils::score()`, with
#' a `model_id` column. If `summarize = TRUE`, scores are averaged across the
#' columns named in `by`.
#'
#' @export
score_model_out <- function(model_out_tbl, target_observations, metrics = NULL,
summarize = TRUE, by = "model_id",
output_type_id_order = NULL) {
# check that model_out_tbl has a single output_type that is supported by this package
# also, retrieve that output_type
output_type <- validate_output_type(model_out_tbl)

# get/validate the scoring metrics
metrics <- get_metrics(metrics, output_type, output_type_id_order)

# assemble data for scoringutils and select output_type-specific metrics
if (output_type == "quantile") {
su_data <- transform_quantile_model_out(model_out_tbl, target_observations)
} else if (output_type == "pmf") {
su_data <- transform_pmf_model_out(model_out_tbl, target_observations, output_type_id_order)
} else if (output_type %in% c("mean", "median")) {
su_data <- transform_point_model_out(model_out_tbl, target_observations, output_type)
}

# compute scores
scores <- scoringutils::score(su_data, metrics)

# switch back to hubverse naming conventions for model name
scores <- dplyr::rename(
scores,
model_id = "model"
)

# if present, drop predicted and observed columns
drop_cols <- c("predicted", "observed")
scores <- scores[!colnames(scores) %in% drop_cols]

# if requested, summarize scores
if (summarize) {
scores <- scoringutils::summarize_scores(scores = scores, by = by)
}

return(scores)
}
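For orientation while reviewing, here is a minimal usage sketch of the exported function. `quantile_outputs` and `observations` are assumed placeholder objects (a hubverse model output tibble with `output_type == "quantile"` and the matching observed data), and "location" is an assumed task id column; none of these are defined in this PR.

# default metrics for quantile forecasts, summarized by model_id
scores <- score_model_out(
  model_out_tbl = quantile_outputs,
  target_observations = observations
)

# character-vector metrics, including a two-digit coverage level,
# summarized by model and an assumed "location" column
scores_by_loc <- score_model_out(
  model_out_tbl = quantile_outputs,
  target_observations = observations,
  metrics = c("wis", "ae_median", "interval_coverage_80"),
  by = c("model_id", "location")
)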


#' Get metrics if user didn't specify anything; otherwise, process
#' and validate user inputs
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics <- function(metrics, output_type, output_type_id_order) {
if (is.null(metrics)) {
return(get_metrics_default(output_type, output_type_id_order))
} else if (is.character(metrics)) {
return(get_metrics_character(metrics, output_type))
} else {
# We do a minimal preliminary check that the provided `metrics` is a list of
# functions, leaving further checks to scoringutils
checkmate::assert_list(metrics, types = "function")
return(metrics)
}
}
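A quick sketch of the three forms `metrics` can take under this dispatch; the list-of-functions case uses `scoringutils::metrics_quantile(select = ...)`, the same call used elsewhere in this file.

# NULL: fall back to package defaults for the output_type
get_metrics(NULL, output_type = "quantile", output_type_id_order = NULL)

# character vector: mapped to metric functions by get_metrics_character()
get_metrics(c("wis", "ae_median"), output_type = "quantile", output_type_id_order = NULL)

# list of metric functions: passed through after a minimal checkmate check
get_metrics(scoringutils::metrics_quantile(select = "wis"), "quantile", NULL)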


#' Default metrics if user didn't specify anything
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
if (output_type == "quantile") {
metrics <- scoringutils::metrics_quantile()
} else if (output_type == "pmf") {
metrics <- scoringutils::metrics_nominal()
} else if (output_type == "mean") {
metrics <- scoringutils::metrics_point(select = "se_point")
} else if (output_type == "median") {
metrics <- scoringutils::metrics_point(select = "ae_point")
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}

return(metrics)
}
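For reference, a sketch of what the defaults resolve to; the metric names are those listed in the roxygen details above.

names(get_metrics_default("quantile", NULL))
# roughly: "wis", "overprediction", "underprediction", "dispersion", "bias",
#          "interval_coverage_50", "interval_coverage_90",
#          "interval_coverage_deviation", "ae_median"
names(get_metrics_default("mean", NULL))
# "se_point"
names(get_metrics_default("median", NULL))
# "ae_point"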


#' Convert character vector of metrics to list of functions
#'
#' @inheritParams score_model_out
#'
#' @return a list of metric functions as required by scoringutils::score()
#'
#' @noRd
get_metrics_character <- function(metrics, output_type) {
if (output_type == "quantile") {
# split into metrics for interval coverage and others
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
other_metrics <- metrics[!interval_metric_inds]

# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
valid_metrics <- c(valid_metrics, "interval_coverage_XY")
Member

This simplifies the logic a bit and gives logical indicator variables that will make it easier to debug.

Suggested change
# split into metrics for interval coverage and others
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
other_metrics <- metrics[!interval_metric_inds]
# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion")
invalid_metrics <- other_metrics[!other_metrics %in% valid_metrics]
valid_metrics <- c(valid_metrics, "interval_coverage_XY")
# identify interval metrics for later use
interval_metric_inds <- grepl(pattern = "^interval_coverage_[[:digit:]][[:digit:]]$", metrics)
interval_metrics <- metrics[interval_metric_inds]
# validate metrics
valid_metrics <- c("ae_median", "wis", "overprediction", "underprediction", "dispersion", "interval_coverage_XY")
not_valid <- !metrics %in% valid_metrics
# excluding the interval metrics from strict matching validation because we have
# identified them earlier with pattern match.
not_interval <- !interval_metric_inds
invalid_metrics <- metrics[not_valid & not_interval]

Contributor Author

I had considered something like this, but I implemented the original logic because the exact string "interval_coverage_XY" is not actually a valid metric. This is sort of a clumsy indication to the package user that they can ask for any 2-digit interval coverage level, like "interval_coverage_80". So my idea was to add "interval_coverage_XY" to the list of valid metrics that's shown to the user in an error message (with an explanation of what we mean), but to not include that in the list of valid_metrics.
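As a side note on that convention, a quick sketch of the pattern match in question; the strings below are only illustrative.

grepl("^interval_coverage_[[:digit:]][[:digit:]]$",
      c("interval_coverage_80", "interval_coverage_XY"))
# [1]  TRUE FALSE  -- so the literal "interval_coverage_XY" is never a valid metric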

if (length(invalid_metrics) > 0) {
cli::cli_abort(
c(
"`metrics` had {length(invalid_metrics)} unsupported metric{?s} for",
" `output_type` {.val {output_type}}: {.val {invalid_metrics}};",
" supported metrics include {.val {valid_metrics}}",
" where `XY` denotes the coverage level, e.g. \"interval_coverage_95\"."
)
)
}

# assemble metric functions
interval_metric_fns <- lapply(
interval_metrics,
function(metric) {
level <- as.integer(substr(metric, 19, 20))
return(purrr::partial(scoringutils::interval_coverage, interval_range = level))
}
)
names(interval_metric_fns) <- interval_metrics

other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)

metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
return(metric_fns)
} else if (output_type == "pmf") {
valid_metrics <- c("log_score")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
if (length(invalid_metrics) > 0) {
cli::cli_abort(
c(
"`metrics` had {length(invalid_metrics)} unsupported metric{?s} for",
" `output_type` {.val{output_type}}: {.val {invalid_metrics}};",
elray1 marked this conversation as resolved.
Show resolved Hide resolved
" supported metrics include {.val {valid_metrics}}."
)
)
}
Member

This is repeated three times. It could be a separate validation function, which would improve readability and maintainability of the error.

Note I've included a comment argument to accommodate notes as for the case in quantiles and I've formatted the output a bit better so that the invalid metric stands out from the other options.

  validate_metrics <- function(metrics, valid_metrics, output_type, comment = NULL) {
    invalid_metrics <- metrics[!metrics %in% valid_metrics]
    n <- length(invalid_metrics) 
    if (n > 0) {
      cli::cli_abort(
        c(
          "`metrics` had {n} unsupported metric{?s} for 
          {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};",
          " supported metrics include {.val {valid_metrics}}.",
          comment
        )
      )
    }
  }

With this, each of the different else if statements could be written as:

if (output_type == "quantile") {
  # --- find interval metrics here ...
  validate_metrics(metrics[not_interval], 
    valid_metrics = c("ae_median", "wis", "overprediction", "underprediction", "dispersion", "interval_coverage_XY"), 
    output_type,
    comment = c("i" = "NOTE: `XY` denotes the coverage level, e.g. {.val interval_coverage_95}.")
  )
  # --- define metrics
} else if (output_type == "pmf") {
  validate_metrics(metrics, valid_metrics = c("quantile"), output_type)
  # --- define metrics
} else if (output_type %in% c("mean", "median")) {
  validate_metrics(metrics, valid_metrics = c("ae_point", "se_point"), output_type)
  # --- define metrics
}

Contributor Author

I've implemented a variation on this suggestion in 5d45a2a. I felt it was more straightforward to keep the calculation of which metrics were invalid outside of this function, so the function itself just issues the error message.

metrics <- scoringutils::metrics_nominal(select = metrics)
} else if (output_type %in% c("median", "mean")) {
valid_metrics <- c("ae_point", "se_point")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
if (length(invalid_metrics) > 0) {
cli::cli_abort(
c(
"`metrics` had {length(invalid_metrics)} unsupported metric{?s} for",
" `output_type` {.val{output_type}}: {.val {invalid_metrics}};",
" supported metrics include {.val {valid_metrics}}."
)
)
}
metrics <- scoringutils::metrics_point(select = metrics)
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
# later, to ensure we update this function.
supported_types <- c("mean", "median", "pmf", "quantile") # nolint object_use_linter
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}

return(metrics)
}
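To make the interval-coverage handling above concrete, a sketch of how one requested name is turned into a metric function, mirroring the lapply()/purrr::partial() step in the function.

metric <- "interval_coverage_80"
level <- as.integer(substr(metric, 19, 20))  # "interval_coverage_" is 18 characters, so this extracts 80
coverage_80 <- purrr::partial(scoringutils::interval_coverage, interval_range = level)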
28 changes: 28 additions & 0 deletions R/validate.R
@@ -66,3 +66,31 @@ validate_model_out_target_obs <- function(model_out_tbl, target_observations) {

return(model_out_tbl)
}


#' Check that `model_out_tbl` has a single `output_type` that is one of the
#' output types supported by this function.
#'
#' @return if valid, the output_type in model_out_tbl
#'
#' @noRd
validate_output_type <- function(model_out_tbl) {
output_type <- unique(model_out_tbl$output_type)
if (length(output_type) != 1) {
cli::cli_abort(
"model_out_tbl must contain a single output_type, but it has multiple:
{.val {output_type}}"
)
}

supported_types <- c("mean", "median", "pmf", "quantile")
if (!output_type %in% supported_types) {
cli::cli_abort(
"Provided `model_out_tbl` contains `output_type` {.val {output_type}};
hubEvals currently only supports the following types:
{.val {supported_types}}"
)
}

return(output_type)
}
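A sketch of the failure mode this guards against; the tibble below is a made-up, minimal example, not data from the tests in this PR.

mixed_tbl <- data.frame(
  model_id = "baseline",
  output_type = c("mean", "quantile"),
  output_type_id = c(NA, "0.5"),
  value = c(10, 12)
)
validate_output_type(mixed_tbl)
# errors, roughly: model_out_tbl must contain a single output_type, but it has
# multiple: "mean" and "quantile"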
84 changes: 84 additions & 0 deletions man/score_model_out.Rd
