Issue #832 - Make default metrics function S3 (#903)
* first progress

* Issue #832 - Make example data pre-validated (#901)

* implement S3 method for forecast objects

* update docs

* implement `get_metrics` method for scores objects

* fix docs

* actually fix docs
nikosbosse authored Sep 16, 2024
1 parent 7fa754a commit 656d817
Showing 42 changed files with 409 additions and 295 deletions.
11 changes: 6 additions & 5 deletions NAMESPACE
@@ -15,6 +15,12 @@ S3method(assert_forecast,forecast_nominal)
S3method(assert_forecast,forecast_point)
S3method(assert_forecast,forecast_quantile)
S3method(assert_forecast,forecast_sample)
S3method(get_metrics,forecast_binary)
S3method(get_metrics,forecast_nominal)
S3method(get_metrics,forecast_point)
S3method(get_metrics,forecast_quantile)
S3method(get_metrics,forecast_sample)
S3method(get_metrics,scores)
S3method(head,forecast)
S3method(print,forecast)
S3method(score,default)
@@ -61,11 +67,6 @@ export(logs_binary)
export(logs_nominal)
export(logs_sample)
export(mad_sample)
export(metrics_binary)
export(metrics_nominal)
export(metrics_point)
export(metrics_quantile)
export(metrics_sample)
export(new_forecast)
export(overprediction_quantile)
export(overprediction_sample)
2 changes: 1 addition & 1 deletion NEWS.md
@@ -17,7 +17,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` paper
- `score()` and many other functions now require a validated `forecast` object. `forecast` objects can be created using the functions `as_forecast_point()`, `as_forecast_binary()`, `as_forecast_quantile()`, and `as_forecast_sample()` (which replace the previous `check_forecast()`). A forecast object is a data.table with class `forecast` and an additional class corresponding to the forecast type (e.g. `forecast_quantile`).
`score()` now returns objects of class `scores` with a stored attribute `metrics` that holds the names of the scoring rules that were used. Users can call `get_metrics()` to access the names of those scoring rules.
- `score()` now returns one score per forecast, instead of one score per sample or quantile. For binary and point forecasts, the columns "observed" and "predicted" are now removed for consistency with the other forecast types.
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the functions `metrics_point()`, `metrics_sample()`, `metrics_quantile()`, `metrics_binary()`, and `metrics_nominal()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the function `get_metrics()`, which is a generic with S3 methods for each forecast type. It returns a named list of scoring rules suitable for the respective forecast object. For example, you could call `get_metrics(example_quantile)`. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Instead of supplying arguments to `score()` to manipulate individual scoring rules, users should now manipulate the metric list being supplied using `purrr::partial()` and `select_metrics()`. See `?score()` for more information.
- The CRPS is now reported as a decomposition into dispersion, overprediction, and underprediction.
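The changelog entries above can be sketched as a short usage example (a sketch only, assuming the `scoringutils` package at this commit and its bundled `example_quantile` data):

```r
library(scoringutils)

# get_metrics() dispatches on the class of its input: on a validated
# forecast object it returns the default named list of scoring rules
metrics <- get_metrics(example_quantile)

# score() uses those defaults unless a custom `metrics` list is supplied
scores <- score(example_quantile)

# on the resulting scores object, get_metrics() instead returns the
# names of the scoring rules that were used
get_metrics(scores)
```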

2 changes: 1 addition & 1 deletion R/correlations.R
@@ -27,7 +27,7 @@
#'
#' get_correlations(scores)
get_correlations <- function(scores,
metrics = get_metrics(scores),
metrics = get_metrics.scores(scores),
...) {
scores <- ensure_data.table(scores)
assert_subset(metrics, colnames(scores), empty.ok = FALSE)
144 changes: 80 additions & 64 deletions R/default-scoring-rules.R
@@ -15,11 +15,11 @@
#' @export
#' @examples
#' select_metrics(
#' metrics = metrics_binary(),
#' metrics = get_metrics(example_binary),
#' select = "brier_score"
#' )
#' select_metrics(
#' metrics = metrics_binary(),
#' metrics = get_metrics(example_binary),
#' exclude = "log_score"
#' )
select_metrics <- function(metrics, select = NULL, exclude = NULL) {
@@ -37,92 +37,114 @@ select_metrics <- function(metrics, select = NULL, exclude = NULL) {
}
assert_subset(select, allowed)
return(metrics[select])
}
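As a self-contained sketch of the `select`/`exclude` semantics documented above (the metric names here are hypothetical, and `select` taking precedence over `exclude` is assumed from the docs):

```r
library(scoringutils)

# a user-defined named list of scoring rules, as the `metrics`
# argument of score() expects
my_metrics <- list(
  mae = function(observed, predicted) abs(observed - predicted),
  mse = function(observed, predicted) (observed - predicted)^2
)

select_metrics(my_metrics, select = "mae")   # keeps only "mae"
select_metrics(my_metrics, exclude = "mse")  # equivalent here
select_metrics(my_metrics)                   # NULL select: keeps everything
```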

#' Get metrics
#'
#' @description
#' Generic function to obtain default metrics available for scoring or metrics
#' that were used for scoring.
#'
#' - If called on a `forecast` object, it returns a list of functions that can be
#' used for scoring.
#' - If called on a `scores` object (see [score()]), it returns a character vector
#' with the names of the metrics that were used for scoring.
#'
#' See the documentation for the actual methods in the `See Also` section below
#' for more details. Alternatively, call `?get_metrics.<forecast_type>` or
#' `?get_metrics.scores`.
#'
#' @param x A `forecast` or `scores` object.
#' @param ... Additional arguments passed to the method.
#' @details
#' See [as_forecast()] for more information on `forecast` objects and [score()]
#' for more information on `scores` objects.
#'
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @export
get_metrics <- function(x, ...) {
UseMethod("get_metrics")
}
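A minimal, self-contained illustration of the dispatch pattern the generic above uses (demo names only, not the package's actual implementation):

```r
# define a demo generic mirroring get_metrics(); UseMethod() dispatches
# on the class of `x`
get_metrics_demo <- function(x, ...) UseMethod("get_metrics_demo")

# forecast-type method: returns a named list of scoring functions
get_metrics_demo.forecast_binary <- function(x, ...) {
  list(brier_score = function(observed, predicted) (observed - predicted)^2)
}

# scores method: returns the names stored in the "metrics" attribute
get_metrics_demo.scores <- function(x, ...) attr(x, "metrics")

fc <- structure(list(), class = c("forecast_binary", "forecast"))
names(get_metrics_demo(fc))  # "brier_score"

sc <- structure(data.frame(brier_score = 0.1),
                class = c("scores", "data.frame"),
                metrics = "brier_score")
get_metrics_demo(sc)  # "brier_score"
```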


#' @title Default metrics and scoring rules for binary forecasts
#' @description
#' Helper function that returns a named list of default
#' scoring rules suitable for binary forecasts.
#' Get default metrics for binary forecasts
#'
#' The default scoring rules are:
#' @description
#' For binary forecasts, the default scoring rules are:
#' - "brier_score" = [brier_score()]
#' - "log_score" = [logs_binary()]
#'
#' @inheritSection illustration-input-metric-binary-point Input format
#' @inherit select_metrics params return
#' @param x A forecast object (a validated data.table with predicted and
#' observed values, see [as_forecast()]).
#' @param select A character vector of scoring rules to select from the list. If
#' `select` is `NULL` (the default), all possible scoring rules are returned.
#' @param exclude A character vector of scoring rules to exclude from the list.
#' If `select` is not `NULL`, this argument is ignored.
#' @param ... unused
#' @return A list of scoring functions.
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' metrics_binary()
#' metrics_binary(select = "brier_score")
#' metrics_binary(exclude = "log_score")
metrics_binary <- function(select = NULL, exclude = NULL) {
#' get_metrics(example_binary)
#' get_metrics(example_binary, select = "brier_score")
#' get_metrics(example_binary, exclude = "log_score")
get_metrics.forecast_binary <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
brier_score = brier_score,
log_score = logs_binary
)
selected <- select_metrics(all, select, exclude)
return(selected)
select_metrics(all, select, exclude)
}


#' @title Scoring rules for nominal forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for nominal forecasts.
#'
#' The default scoring rules are:
#' Get default metrics for nominal forecasts
#' @inheritParams get_metrics.forecast_binary
#' @description
#' For nominal forecasts, the default scoring rule is:
#' - "log_score" = [logs_nominal()]
#' @inherit select_metrics params return
#' @export
#' @keywords metric
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' metrics_nominal()
#' metrics_nominal(select = "log_score")
metrics_nominal <- function(select = NULL, exclude = NULL) {
#' get_metrics(example_nominal)
get_metrics.forecast_nominal <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
log_score = logs_nominal
)
selected <- select_metrics(all, select, exclude)
return(selected)
select_metrics(all, select, exclude)
}


#' @title Default metrics and scoring rules for point forecasts
#' @description
#' Helper function that returns a named list of default
#' scoring rules suitable for point forecasts.
#' Get default metrics for point forecasts
#'
#' The default scoring rules are:
#' @description
#' For point forecasts, the default scoring rules are:
#' - "ae_point" = [ae()][Metrics::ae()]
#' - "se_point" = [se()][Metrics::se()]
#' - "ape" = [ape()][Metrics::ape()]
#'
#' @inheritSection illustration-input-metric-binary-point Input format
#' @inherit select_metrics params return
#' @inheritParams get_metrics.forecast_binary
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' metrics_point()
#' metrics_point(select = "ape")
metrics_point <- function(select = NULL, exclude = NULL) {
#' get_metrics(example_point, select = "ape")
get_metrics.forecast_point <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
ae_point = Metrics::ae,
se_point = Metrics::se,
ape = Metrics::ape
)
selected <- select_metrics(all, select, exclude)
return(selected)
select_metrics(all, select, exclude)
}


#' @title Default metrics and scoring rules sample-based forecasts
#' @description
#' Helper function that returns a named list of default
#' scoring rules suitable for forecasts in a sample-based format.
#' Get default metrics for sample-based forecasts
#'
#' The default scoring rules are:
#' @description
#' For sample-based forecasts, the default scoring rules are:
#' - "crps" = [crps_sample()]
#' - "overprediction" = [overprediction_sample()]
#' - "underprediction" = [underprediction_sample()]
@@ -133,15 +155,14 @@ metrics_point <- function(select = NULL, exclude = NULL) {
#' - "bias" = [bias_sample()]
#' - "ae_median" = [ae_median_sample()]
#' - "se_mean" = [se_mean_sample()]
#'
#' @inheritSection illustration-input-metric-sample Input format
#' @inherit select_metrics params return
#' @inheritParams get_metrics.forecast_binary
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' metrics_sample()
#' metrics_sample(select = "mad")
metrics_sample <- function(select = NULL, exclude = NULL) {
#' get_metrics(example_sample_continuous, exclude = "mad")
get_metrics.forecast_sample <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
bias = bias_sample,
dss = dss_sample,
@@ -154,17 +175,14 @@ metrics_sample <- function(select = NULL, exclude = NULL) {
ae_median = ae_median_sample,
se_mean = se_mean_sample
)
selected <- select_metrics(all, select, exclude)
return(selected)
select_metrics(all, select, exclude)
}


#' @title Default metrics and scoring rules for quantile-based forecasts
#' @description
#' Helper function that returns a named list of default
#' scoring rules suitable for forecasts in a quantile-based format.
#' Get default metrics for quantile-based forecasts
#'
#' The default scoring rules are:
#' @description
#' For quantile-based forecasts, the default scoring rules are:
#' - "wis" = [wis]
#' - "overprediction" = [overprediction_quantile()]
#' - "underprediction" = [underprediction_quantile()]
@@ -184,16 +202,15 @@ metrics_sample <- function(select = NULL, exclude = NULL) {
#' accept get passed on to it. `interval_range = 90` is set in the function
#' definition, as passing an argument `interval_range = 90` to [score()] would
#' mean it would also get passed to `interval_coverage_50`.
#'
#' @inheritSection illustration-input-metric-quantile Input format
#' @inherit select_metrics params return
#' @inheritParams get_metrics.forecast_binary
#' @export
#' @importFrom purrr partial
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @importFrom purrr partial
#' @examples
#' metrics_quantile()
#' metrics_quantile(select = "wis")
metrics_quantile <- function(select = NULL, exclude = NULL) {
#' get_metrics(example_quantile, select = "wis")
get_metrics.forecast_quantile <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
wis = wis,
overprediction = overprediction_quantile,
@@ -207,6 +224,5 @@ metrics_quantile <- function(select = NULL, exclude = NULL) {
interval_coverage_deviation = interval_coverage_deviation,
ae_median = ae_median_quantile
)
selected <- select_metrics(all, select, exclude)
return(selected)
select_metrics(all, select, exclude)
}
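The note above about passing arguments via `purrr::partial()` can be sketched as follows (a sketch under the assumption, suggested by the surrounding docs, that `interval_coverage()` accepts an `interval_range` argument):

```r
library(scoringutils)
library(purrr)

metrics <- get_metrics(example_quantile)

# swap the 90% interval coverage for a 50% version by pre-filling
# the interval_range argument, rather than passing it to score()
metrics[["interval_coverage_90"]] <- purrr::partial(
  interval_coverage, interval_range = 50
)

scores <- score(example_quantile, metrics = metrics)
```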
19 changes: 11 additions & 8 deletions R/get_-functions.R
@@ -106,19 +106,22 @@ get_type <- function(x) {
#' `attr(scores, "metrics") <- c("names", "of", "your", "scores")` (the
#' order does not matter).
#'
#' @param scores A data.table with an attribute `metrics`.
#' @param x A `scores` object, (a data.table with an attribute `metrics` as
#' produced by [score()]).
#' @param error Throw an error if there is no attribute called `metrics`?
#' Default is FALSE.
#' Default is FALSE.
#' @param ... unused
#' @importFrom cli cli_abort cli_warn
#' @importFrom checkmate assert_data_frame
#' @return
#' Character vector with the names of the scoring rules that were used
#' for scoring or `NULL` if no scores were computed previously.
#' for scoring.
#' @keywords handle-metrics
#' @family `get_metrics` functions
#' @export
get_metrics <- function(scores, error = FALSE) {
assert_data_frame(scores)
metrics <- attr(scores, "metrics")
get_metrics.scores <- function(x, error = FALSE, ...) {
assert_data_frame(x)
metrics <- attr(x, "metrics")
if (error && is.null(metrics)) {
#nolint start: keyword_quote_linter
cli_abort(
@@ -131,9 +134,9 @@ get_metrics <- function(scores, error = FALSE) {
#nolint end
}

if (!all(metrics %in% names(scores))) {
if (!all(metrics %in% names(x))) {
#nolint start: keyword_quote_linter object_usage_linter
missing <- setdiff(metrics, names(scores))
missing <- setdiff(metrics, names(x))
cli_warn(
c(
"!" = "The following scores have been previously computed, but are no
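Based on the behaviour documented above, the `scores` accessor might be exercised like this (a sketch; the exact warning and error messages are not reproduced):

```r
library(scoringutils)

scores <- score(example_binary)
get_metrics(scores)  # character vector of the rules used for scoring

# dropping a score column triggers the warning path shown above:
# the attribute still lists the metric, but the column is gone
scores$log_score <- NULL
get_metrics(scores)

# error = TRUE aborts when the "metrics" attribute is missing entirely
df <- data.frame(model = "a")
try(get_metrics.scores(df, error = TRUE))
```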
4 changes: 2 additions & 2 deletions R/pairwise-comparisons.R
@@ -124,7 +124,7 @@ get_pairwise_comparisons <- function(

# we need the score names attribute to make sure we can determine the
# forecast unit correctly, so here we check it exists
metrics <- get_metrics(scores, error = TRUE)
metrics <- get_metrics.scores(scores, error = TRUE)

# check that metric is a subset of the scores and is of length 1
assert_subset(metric, metrics, empty.ok = FALSE)
@@ -577,7 +577,7 @@ add_relative_skill <- function(
)

# store original metrics
metrics <- get_metrics(scores)
metrics <- get_metrics.scores(scores)

# delete unnecessary columns
pairwise[, c(
2 changes: 1 addition & 1 deletion R/plot.R
@@ -673,7 +673,7 @@ plot_forecast_counts <- function(forecast_counts,
plot_correlations <- function(correlations, digits = NULL) {

assert_data_frame(correlations)
metrics <- get_metrics(correlations, error = TRUE)
metrics <- get_metrics.scores(correlations, error = TRUE)

lower_triangle <- get_lower_tri(correlations[, .SD, .SDcols = metrics])

