Merge branch 'develop' into rework-validate_forceasts
nikosbosse authored Jan 3, 2024
2 parents ba91d18 + c811d55 commit 41ccdbb
Showing 28 changed files with 481 additions and 212 deletions.
6 changes: 6 additions & 0 deletions NAMESPACE
@@ -59,10 +59,15 @@ export(plot_score_table)
export(plot_wis)
export(quantile_score)
export(quantile_to_interval)
export(rules_binary)
export(rules_point)
export(rules_quantile)
export(rules_sample)
export(run_safely)
export(sample_to_quantile)
export(score)
export(se_mean_sample)
export(select_rules)
export(set_forecast_unit)
export(squared_error)
export(summarise_scores)
@@ -86,6 +91,7 @@ importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_string)
importFrom(checkmate,assert_subset)
importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_data_frame)
13 changes: 7 additions & 6 deletions NEWS.md
@@ -1,4 +1,4 @@
# scoringutils 2.0.0

This major update addresses a variety of comments made by reviewers from the Journal of Statistical Software (see the preprint of the manuscript [here](https://arxiv.org/abs/2205.07090)).

@@ -7,20 +7,21 @@ The update introduces breaking changes. If you want to keep using the older vers
## Package updates
- In `score()`, the required columns "true_value" and "prediction" were renamed and replaced by the required columns "observed" and "predicted". Scoring functions now also use the function arguments "observed" and "predicted" consistently everywhere.
- The overall scoring workflow was updated. `score()` is now a generic function that dispatches the correct method based on the forecast type. Forecast types currently supported are "binary", "point", "sample" and "quantile", with corresponding classes "forecast_binary", "forecast_point", "forecast_sample" and "forecast_quantile". An object of class `forecast_*` can be created using the function `as_forecast()`, which also replaces the previous function `check_forecasts()` (see more information below).
- Scoring rules (functions used for scoring) received a consistent interface and input checks:
  - Scoring rules for binary forecasts:
    - `observed`: factor with exactly 2 levels
    - `predicted`: numeric, vector with probabilities
  - Scoring rules for point forecasts:
    - `observed`: numeric vector
    - `predicted`: numeric vector
  - Scoring rules for sample-based forecasts:
    - `observed`: numeric, either a scalar or a vector
    - `predicted`: numeric, a vector (if `observed` is a scalar) or a matrix (if `observed` is a vector)
  - Scoring rules for quantile-based forecasts:
    - `observed`: numeric, either a scalar or a vector
    - `predicted`: numeric, a vector (if `observed` is a scalar) or a matrix (if `observed` is a vector)
    - `quantile`: numeric, a vector with quantile levels. Can alternatively be a matrix of the same shape as `predicted`.
- Users can now supply their own scoring rules to `score()` as a list of functions. Default scoring rules can be accessed using the functions `rules_point()`, `rules_sample()`, `rules_quantile()` and `rules_binary()`, which return a list of scoring rules suitable for the respective forecast type.
- `check_forecasts()` was replaced by a different workflow. There is now a function, `as_forecast()`, that determines the forecast type of the data, constructs a forecast object and validates it using the function `validate_forecast()` (a generic that dispatches the correct method based on the forecast type). Objects of class `forecast_binary`, `forecast_point`, `forecast_sample` and `forecast_quantile` have print methods that fulfil the functionality of `check_forecasts()`.
- The functionality for computing pairwise comparisons was split out of `summarise_scores()`. Instead of performing pairwise comparisons as part of summarising scores, a new function, `add_pairwise_comparison()`, was introduced that takes summarised scores as an input and adds pairwise comparisons to them.
- `add_coverage()` was reworked completely. Its new purpose is to add coverage information to the raw forecast data (essentially fulfilling some of the functionality that was previously covered by `score_quantile()`).
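
The reworked workflow described above can be sketched end to end. This is a minimal illustration, assuming the `example_quantile` dataset that ships with scoringutils; exact names may differ between development versions:

```r
library(scoringutils)

# Validate the data and build a forecast object (replaces check_forecasts())
forecast <- as_forecast(example_quantile)

# Score with the default rules for the detected forecast type ...
scores <- score(forecast)

# ... or supply a custom subset of scoring rules as a named list of functions
scores_wis <- score(
  forecast,
  metrics = rules_quantile(select = c("wis", "ae_median"))
)
```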
1 change: 1 addition & 0 deletions R/check-input-helpers.R
@@ -211,6 +211,7 @@ assure_model_column <- function(data) {
#' @inherit document_check_functions params return
#' @keywords internal_input_check
check_number_per_forecast <- function(data, forecast_unit) {
data <- na.omit(data)
# check whether there are the same number of quantiles, samples --------------
data[, scoringutils_InternalNumCheck := length(predicted), by = forecast_unit]
n <- unique(data$scoringutils_InternalNumCheck)
48 changes: 0 additions & 48 deletions R/data.R
@@ -195,51 +195,3 @@
#'
#' @keywords info
"metrics"

#' Default metrics for binary forecasts.
#'
#' A named list with functions:
#' - "brier_score" = [brier_score()]
#' - "log_score" = [logs_binary()]
#' @keywords info
"metrics_binary"

#' Default metrics for point forecasts.
#'
#' A named list with functions:
#' - "ae_point" = [ae()][Metrics::ae()]
#' - "se_point" = [se()][Metrics::se()]
#' - "ape" = [ape()][Metrics::ape()]
#' @keywords info
"metrics_point"

#' Default metrics for sample-based forecasts.
#'
#' A named list with functions:
#' - "mad" = [mad_sample()]
#' - "bias" = [bias_sample()]
#' - "dss" = [dss_sample()]
#' - "crps" = [crps_sample()]
#' - "log_score" = [logs_sample()]
#' - "mad" = [mad_sample()]
#' - "ae_median" = [ae_median_sample()]
#' - "se_mean" = [se_mean_sample()]
#' @keywords info
"metrics_sample"

#' Default metrics for quantile-based forecasts.
#'
#' A named list with functions:
#' - "wis" = [wis]
#' - "overprediction" = [overprediction()]
#' - "underprediction" = [underprediction()]
#' - "dispersion" = [dispersion()]
#' - "bias" = [bias_quantile()]
#' - "coverage_50" = [interval_coverage_quantile()]
#' - "coverage_90" = \(...) \{
#' run_safely(..., range = 90, fun = [interval_coverage_quantile])
#' \}
#' - "coverage_deviation" = [interval_coverage_dev_quantile()],
#' - "ae_median" = [ae_median_quantile()]
#' @keywords info
"metrics_quantile"
169 changes: 169 additions & 0 deletions R/default-scoring-rules.R
@@ -0,0 +1,169 @@
#' @title Select Scoring Rules From A List of Possible Scoring Rules
#' @description Helper function to return only the scoring rules selected by
#' the user from a list of possible scoring rules.
#' @param rules A list of scoring rules.
#' @param select A character vector of scoring rules to select from the list.
#' If `select` is `NULL` (the default), all possible scoring rules are returned.
#' @param exclude A character vector of scoring rules to exclude from the list.
#' If `select` is not `NULL`, this argument is ignored.
#' @return A list of scoring rules.
#' @keywords metric
#' @importFrom checkmate assert_subset assert_list
#' @export
#' @examples
#' select_rules(
#' rules = rules_binary(),
#' select = "brier_score"
#' )
#' select_rules(
#' rules = rules_binary(),
#' exclude = "log_score"
#' )
select_rules <- function(rules, select = NULL, exclude = NULL) {
assert_character(x = c(select, exclude), null.ok = TRUE)
assert_list(rules, names = "named")
allowed <- names(rules)

if (is.null(select) && is.null(exclude)) {
return(rules)
} else if (is.null(select)) {
assert_subset(exclude, allowed)
select <- allowed[!allowed %in% exclude]
return(rules[select])
} else {
assert_subset(select, allowed)
return(rules[select])
}
}
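
To make the selection logic above concrete, here is a toy illustration (the list entries are arbitrary placeholder functions, not real scoring rules):

```r
rules <- list(a = mean, b = median, c = sum)

select_rules(rules)                 # returns all three rules
select_rules(rules, select = "a")   # returns only rules[["a"]]
select_rules(rules, exclude = "b")  # returns rules "a" and "c"
select_rules(rules, select = "b", exclude = "b")  # returns "b": `exclude` is ignored
```

Note that `select` takes precedence: when it is supplied, `exclude` has no effect.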


#' @title Scoring Rules for Binary Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for binary forecasts.
#'
#' The default scoring rules are:
#' - "brier_score" = [brier_score()]
#' - "log_score" = [logs_binary()]
#' @inherit select_rules params return
#' @export
#' @keywords metric
#' @examples
#' rules_binary()
#' rules_binary(select = "brier_score")
#' rules_binary(exclude = "log_score")
rules_binary <- function(select = NULL, exclude = NULL) {
all <- list(
brier_score = brier_score,
log_score = logs_binary
)
selected <- select_rules(all, select, exclude)
return(selected)
}


#' @title Scoring Rules for Point Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for point forecasts.
#'
#' The default scoring rules are:
#' - "ae_point" = [ae()][Metrics::ae()]
#' - "se_point" = [se()][Metrics::se()]
#' - "ape" = [ape()][Metrics::ape()]
#' @inherit select_rules params return
#' @export
#' @keywords metric
#' @examples
#' rules_point()
#' rules_point(select = "ape")
rules_point <- function(select = NULL, exclude = NULL) {
all <- list(
ae_point = Metrics::ae,
se_point = Metrics::se,
ape = Metrics::ape
)
selected <- select_rules(all, select, exclude)
return(selected)
}


#' @title Scoring Rules for Sample-Based Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for forecasts in a sample-based format.
#'
#' The default scoring rules are:
#' - "bias" = [bias_sample()]
#' - "dss" = [dss_sample()]
#' - "crps" = [crps_sample()]
#' - "log_score" = [logs_sample()]
#' - "mad" = [mad_sample()]
#' - "ae_median" = [ae_median_sample()]
#' - "se_mean" = [se_mean_sample()]
#' @inherit select_rules params return
#' @export
#' @keywords metric
#' @examples
#' rules_sample()
#' rules_sample(select = "mad")
rules_sample <- function(select = NULL, exclude = NULL) {
all <- list(
bias = bias_sample,
dss = dss_sample,
crps = crps_sample,
log_score = logs_sample,
mad = mad_sample,
ae_median = ae_median_sample,
se_mean = se_mean_sample
)
selected <- select_rules(all, select, exclude)
return(selected)
}


#' @title Scoring Rules for Quantile-Based Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for forecasts in a quantile-based format.
#'
#' The default scoring rules are:
#' - "wis" = [wis]
#' - "overprediction" = [overprediction()]
#' - "underprediction" = [underprediction()]
#' - "dispersion" = [dispersion()]
#' - "bias" = [bias_quantile()]
#' - "coverage_50" = [interval_coverage_quantile()]
#' - "coverage_90" = function(...) \{
#' run_safely(..., range = 90, fun = [interval_coverage_quantile])
#' \}
#' - "coverage_deviation" = [interval_coverage_dev_quantile()]
#' - "ae_median" = [ae_median_quantile()]
#'
#' Note: The `coverage_90` scoring rule is created as a wrapper around
#' [interval_coverage_quantile()], making use of the function [run_safely()].
#' This construct allows the function to deal with arbitrary arguments in `...`,
#' while making sure that only those that [interval_coverage_quantile()] can
#' accept get passed on to it. `range = 90` is set in the function definition,
#' as passing an argument `range = 90` to [score()] would mean it would also
#' get passed to `coverage_50`.
#' @inherit select_rules params return
#' @export
#' @keywords metric
#' @examples
#' rules_quantile()
#' rules_quantile(select = "wis")
rules_quantile <- function(select = NULL, exclude = NULL) {
all <- list(
wis = wis,
overprediction = overprediction,
underprediction = underprediction,
dispersion = dispersion,
bias = bias_quantile,
coverage_50 = interval_coverage_quantile,
coverage_90 = function(...) {
run_safely(..., range = 90, fun = interval_coverage_quantile)
},
coverage_deviation = interval_coverage_dev_quantile,
ae_median = ae_median_quantile
)
selected <- select_rules(all, select, exclude)
return(selected)
}
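
The argument filtering that `run_safely()` performs for the `coverage_90` wrapper can be illustrated with a simplified stand-in (a sketch of the idea, not the package's actual implementation):

```r
# Simplified stand-in: keep only the named arguments that `fun` accepts
run_safely_sketch <- function(..., fun) {
  args <- list(...)
  accepted <- names(formals(fun))
  do.call(fun, args[names(args) %in% accepted])
}

f <- function(x, range = 50) x + range
# `unused_arg` is silently dropped because f() does not accept it
run_safely_sketch(x = 1, range = 90, unused_arg = "dropped", fun = f)  # 91
```

This is why `rules_quantile()` can fix `range = 90` inside the wrapper while still forwarding arbitrary `...` from `score()`.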
12 changes: 6 additions & 6 deletions R/score.R
@@ -16,8 +16,8 @@
#' @inheritSection forecast_types Forecast unit
#' @param data A data.frame or data.table with predicted and observed values.
#' @param metrics A named list of scoring functions. Names will be used as
#' column names in the output. See [rules_point()], [rules_binary()],
#' [rules_quantile()], and [rules_sample()] for more information on the
#' default metrics used.
#' @param ... additional arguments
#' @return A data.table with unsummarised scores. This will generally be
@@ -78,7 +78,7 @@ score.default <- function(data, ...) {
#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_binary <- function(data, metrics = rules_binary(), ...) {
data <- validate_forecast(data)
data <- na.omit(data)
metrics <- validate_metrics(metrics)
@@ -99,7 +99,7 @@ score.forecast_binary <- function(data, metrics = rules_binary(), ...) {
#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_point <- function(data, metrics = rules_point(), ...) {
data <- validate_forecast(data)
data <- na.omit(data)
metrics <- validate_metrics(metrics)
@@ -117,7 +117,7 @@ score.forecast_point <- function(data, metrics = rules_point(), ...) {
#' @importFrom stats na.omit
#' @rdname score
#' @export
score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
data <- validate_forecast(data)
data <- na.omit(data)
forecast_unit <- get_forecast_unit(data)
@@ -155,7 +155,7 @@ score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
#' @importFrom data.table `:=` as.data.table rbindlist %like%
#' @rdname score
#' @export
score.forecast_quantile <- function(data, metrics = rules_quantile(), ...) {
data <- validate_forecast(data)
data <- na.omit(data)
forecast_unit <- get_forecast_unit(data)
2 changes: 1 addition & 1 deletion R/utils.R
@@ -133,7 +133,7 @@ strip_attributes <- function(object, attributes) {

#' @title Run a function safely
#' @description This is a wrapper function designed to run a function safely
#' when it is not completely clear what arguments could be passed to the
#' function.
#'
#' All named arguments in `...` that are not accepted by `fun` are removed.
8 changes: 4 additions & 4 deletions R/z_globalVariables.R
@@ -44,10 +44,10 @@ globalVariables(c(
"metric",
"metrics_select",
"metrics",
"rules_binary",
"rules_point",
"rules_quantile",
"rules_sample",
"model",
"n_obs",
"n_obs wis_component_name",
Binary file removed data/metrics_binary.rda
Binary file removed data/metrics_point.rda
Binary file removed data/metrics_quantile.rda
Binary file removed data/metrics_sample.rda
36 changes: 0 additions & 36 deletions inst/create-list-available-forecasts.R

This file was deleted.
