Issue #604 - Add support for nominal forecasts #837

Merged: 46 commits, Aug 10, 2024

Commits (all commits included in the diff below):
- d930c6a: Add skeleton for a score method for categorical forecasts (nikosbosse, Jan 17, 2024)
- be05422: skeleton for validate_forecast method for categorical forecasts (nikosbosse, Jan 17, 2024)
- b1ac284: add skeleton for default scoring rules for categorical forecasts (nikosbosse, Jan 17, 2024)
- d0f57fb: empty skeleton for check functions for categorical forecasts (nikosbosse, Jan 17, 2024)
- b0ee467: fix merge conflict (nikosbosse, Jun 2, 2024)
- 3c5c23e: implement nominal forecast class (nikosbosse, Jun 7, 2024)
- 34be31a: fix issues (nikosbosse, Jun 7, 2024)
- bf71982: make code work (nikosbosse, Jun 7, 2024)
- 5eadf04: add example data (nikosbosse, Jun 7, 2024)
- 5e80e29: fix warnings (nikosbosse, Jun 7, 2024)
- 907511d: add tests (nikosbosse, Jun 7, 2024)
- 4681f92: Refine tests and docs (nikosbosse, Jun 7, 2024)
- b5ef322: improve tests (nikosbosse, Jun 7, 2024)
- d666e64: fix linting issues (nikosbosse, Jun 7, 2024)
- 9aa1e0a: fix issues (nikosbosse, Jun 8, 2024)
- 0f15226: try fix for failing test (nikosbosse, Jun 8, 2024)
- fec609f: try fixing test again... (nikosbosse, Jun 8, 2024)
- 14a5fb7: update test to work with old R version (nikosbosse, Jun 8, 2024)
- dfc6c50: round and round and round it goes (nikosbosse, Jun 8, 2024)
- 13cd0f9: Merge branch 'main' into multiclass (nikosbosse, Jun 14, 2024)
- 1cad15a: Require R4.0 (nikosbosse, Jun 12, 2024)
- de851e3: remove R3.6 from CI checks (nikosbosse, Jun 13, 2024)
- be1eab0: update NEWS file (nikosbosse, Jun 13, 2024)
- efa1ec7: add CI check for 4.0 back in (nikosbosse, Jun 13, 2024)
- 3bd0db6: fix typo in news file (nikosbosse, Jun 14, 2024)
- dd6cc28: update manual figure (nikosbosse, Jun 14, 2024)
- 2f12516: Merge branch 'main' into multiclass (nikosbosse, Jun 14, 2024)
- fb3f3ad: Merge branch 'main' into multiclass (nikosbosse, Jun 16, 2024)
- 3ba8a46: update tests after merge conflict (nikosbosse, Jun 16, 2024)
- 7f922ac: Merge branch 'main' into multiclass (nikosbosse, Jul 21, 2024)
- 7b65d58: update docs (nikosbosse, Jul 21, 2024)
- bac88d9: fix linter issue (nikosbosse, Jul 21, 2024)
- 7ff58b8: fix tests (nikosbosse, Jul 21, 2024)
- 708ad74: update tests (nikosbosse, Jul 21, 2024)
- 9082a0a: update tests (nikosbosse, Jul 21, 2024)
- 746b3c8: use magrittr pipe (nikosbosse, Jul 21, 2024)
- 340e9d5: update docs (nikosbosse, Jul 22, 2024)
- 4adee5a: Merge branch 'main' into multiclass (nikosbosse, Jul 23, 2024)
- 92c8d80: address comments from Nick (nikosbosse, Jul 27, 2024)
- b800007: Merge branch 'main' into multiclass (nikosbosse, Jul 27, 2024)
- fd36025: fix test (nikosbosse, Jul 27, 2024)
- dedbc97: Merge branch 'main' into multiclass (seabbs, Jul 30, 2024)
- 7ac15f1: Merge branch 'main' into multiclass (nikosbosse, Aug 10, 2024)
- 013816a: update docs (nikosbosse, Aug 10, 2024)
- 8d78327: Merge branch 'main' into multiclass (nikosbosse, Aug 10, 2024)
- f453895: Update docs (nikosbosse, Aug 10, 2024)
9 changes: 9 additions & 0 deletions NAMESPACE
@@ -7,19 +7,22 @@ S3method(as_forecast_quantile,default)
S3method(as_forecast_quantile,forecast_sample)
S3method(assert_forecast,default)
S3method(assert_forecast,forecast_binary)
S3method(assert_forecast,forecast_nominal)
S3method(assert_forecast,forecast_point)
S3method(assert_forecast,forecast_quantile)
S3method(assert_forecast,forecast_sample)
S3method(print,forecast)
S3method(score,default)
S3method(score,forecast_binary)
S3method(score,forecast_nominal)
S3method(score,forecast_point)
S3method(score,forecast_quantile)
S3method(score,forecast_sample)
export(add_relative_skill)
export(ae_median_quantile)
export(ae_median_sample)
export(as_forecast_binary)
export(as_forecast_nominal)
export(as_forecast_point)
export(as_forecast_quantile)
export(as_forecast_sample)
@@ -44,14 +47,17 @@ export(interval_coverage)
export(interval_coverage_deviation)
export(is_forecast)
export(is_forecast_binary)
export(is_forecast_nominal)
export(is_forecast_point)
export(is_forecast_quantile)
export(is_forecast_sample)
export(log_shift)
export(logs_binary)
export(logs_nominal)
export(logs_sample)
export(mad_sample)
export(metrics_binary)
export(metrics_nominal)
export(metrics_point)
export(metrics_quantile)
export(metrics_sample)
@@ -93,8 +99,10 @@ importFrom(checkmate,assert_function)
importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_matrix)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_set_equal)
importFrom(checkmate,assert_subset)
importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
@@ -109,6 +117,7 @@ importFrom(checkmate,test_factor)
importFrom(checkmate,test_list)
importFrom(checkmate,test_names)
importFrom(checkmate,test_numeric)
importFrom(checkmate,test_set_equal)
importFrom(checkmate,test_subset)
importFrom(cli,cli_abort)
importFrom(cli,cli_inform)
8 changes: 4 additions & 4 deletions NEWS.md
@@ -12,12 +12,12 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
### `score()`
- The main function of the package is still the function `score()`. However, we reworked the function and updated and clarified its input requirements.
- The previous columns "true_value" and "prediction" were renamed. `score()` now requires columns called "observed", "predicted" and "model". The column `quantile` was renamed to `quantile_level` and `sample` was renamed to `sample_id`.
- `score()` is now a generic. It has S3 methods for the classes `forecast_point`, `forecast_binary`, `forecast_quantile` and `forecast_sample`, which correspond to the different forecast types that can be scored with `scoringutils`.
- `score()` is now a generic. It has S3 methods for the classes `forecast_point`, `forecast_binary`, `forecast_quantile`, `forecast_sample`, and `forecast_nominal`, which correspond to the different forecast types that can be scored with `scoringutils`.
- `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also mess up e.g. grouping of forecasts according to the unit of a single forecast.
- `score()` and many other functions now require a validated `forecast` object. `forecast` objects can be created using the functions `as_forecast_point()`, `as_forecast_binary()`, `as_forecast_quantile()`, and `as_forecast_sample()` (which replace the previous `check_forecast()`). A forecast object is a data.table with class `forecast` and an additional class corresponding to the forecast type (e.g. `forecast_quantile`).
- `score()` now returns objects of class `scores` with a stored attribute `metrics` that holds the names of the scoring rules that were used. Users can call `get_metrics()` to access the names of those scoring rules.
- `score()` now returns one score per forecast, instead of one score per sample or quantile.
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the functions `metrics_point()`, `metrics_sample()`, `metrics_quantile()` and `metrics_binary()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the functions `metrics_point()`, `metrics_sample()`, `metrics_quantile()`, `metrics_binary()`, and `metrics_nominal()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Instead of supplying arguments to `score()` to manipulate individual scoring rules, users should now manipulate the metric list being supplied using `purrr::partial()` and `select_metrics()`. See `?score()` for more information.
- The CRPS is now reported as a decomposition into dispersion, overprediction and underprediction.

@@ -36,8 +36,8 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
quantile_level = "quantile_level",
forecast_unit = c("model", "location", "target_end_date", "forecast_date", "target_type")
)
scores <- score(forecast_quantile)
```
- Overall, we updated the suggested workflows for how users should work with the package. The following gives an overview (see the [updated paper](https://drive.google.com/file/d/1URaMsXmHJ1twpLpMl1sl2HW4lPuUycoj/view?usp=drive_link) for more details).
![package workflows](./man/figures/workflow.png)
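By analogy with the quantile example above, here is a minimal sketch of the corresponding nominal workflow added in this PR. It is hedged: it assumes `example_nominal` is available in the long format documented in `R/data.R` below (if it already ships as a validated forecast object, `as_forecast_nominal()` simply re-validates it), and the chosen forecast unit is illustrative.

```r
library(scoringutils)

# validate the nominal example data and create a forecast object
forecast_nominal <- as_forecast_nominal(
  example_nominal,
  forecast_unit = c(
    "model", "location", "target_end_date",
    "forecast_date", "target_type", "horizon"
  )
)

# score with the default scoring rules for nominal forecasts (the log score)
scores <- score(forecast_nominal, metrics = metrics_nominal())
```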

65 changes: 65 additions & 0 deletions R/check-inputs-scoring-functions.R
@@ -193,6 +193,71 @@ check_input_binary <- function(observed, predicted) {
}


#' @title Assert that inputs are correct for nominal forecasts
#' @description Function assesses whether the inputs correspond to the
#' requirements for scoring nominal forecasts.
#' @param observed Input to be checked. Should be a factor of length n with
#' N levels holding the observed values. n is the number of observations and
#' N is the number of possible outcomes the observed values can assume.
#' @param predicted Input to be checked. Should be an nxN matrix of predicted
#' probabilities, n (number of rows) being the number of data points and N
#' (number of columns) the number of possible outcomes the observed values
#' can assume.
#' If `observed` is just a single observation, then `predicted` can also be a
#' vector of size N.
#' @param predicted_label Factor of length N with N levels, where N is the
#' number of possible outcomes the observed values can assume.
#' @importFrom checkmate assert_factor assert_numeric assert_set_equal
#' @inherit document_assert_functions return
#' @keywords internal_input_check
assert_input_nominal <- function(observed, predicted, predicted_label) {
  # observed
  assert_factor(observed, min.len = 1, min.levels = 2)
  levels <- levels(observed)
  n <- length(observed)
  N <- length(levels)

  # predicted label
  assert_factor(
    predicted_label, len = N,
    any.missing = FALSE, empty.levels.ok = FALSE
  )
  assert_set_equal(levels(observed), levels(predicted_label))

  # predicted
  assert_numeric(predicted, min.len = 1, lower = 0, upper = 1)
  if (n == 1) {
    assert(
      # allow one of two options
      check_vector(predicted, len = N),
      check_matrix(predicted, nrows = n, ncols = N)
    )
    summed_predictions <- .rowSums(predicted, m = 1, n = N, na.rm = TRUE)
  } else {
    assert_matrix(predicted, nrows = n)
    summed_predictions <- round(rowSums(predicted, na.rm = TRUE), 10) # avoid numeric errors
  }
  if (!all(summed_predictions == 1)) {
    #nolint start: keyword_quote_linter object_usage_linter
    row_indices <- as.character(which(summed_predictions != 1))
    cli_abort(
      c(
        `!` = "Probabilities belonging to a single forecast must sum to one",
        `i` = "Found issues in row{?s} {row_indices} of {.var predicted}"
      )
    )
    #nolint end
  }
  return(invisible(NULL))
}
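For orientation, a minimal sketch of inputs that would pass this internal check, using hypothetical values (the function is normally only called from within the package):

```r
# three observations over three possible outcomes
observed <- factor(
  c("low", "high", "medium"),
  levels = c("low", "medium", "high")
)
predicted_label <- factor(
  c("low", "medium", "high"),
  levels = c("low", "medium", "high")
)
# n x N matrix of probabilities; every row sums to one
predicted <- matrix(
  c(0.7, 0.2, 0.1,
    0.1, 0.3, 0.6,
    0.2, 0.5, 0.3),
  nrow = 3, byrow = TRUE
)
assert_input_nominal(observed, predicted, predicted_label)
# returns invisible(NULL); a row not summing to one would trigger cli_abort()
```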


#' @title Assert that inputs are correct for point forecast
#' @description
#' Function assesses whether the inputs correspond to the
27 changes: 27 additions & 0 deletions R/data.R
@@ -137,3 +137,30 @@
#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
# nolint end
"example_binary"


#' Nominal example data
#'
#' A data set with predictions for COVID-19 cases and deaths submitted to the
#' European Forecast Hub.
#'
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{target_end_date}{the date for which a prediction was made}
#' \item{target_type}{the target to be predicted (cases or deaths)}
#' \item{observed}{Factor: observed values}
#' \item{location_name}{name of the country for which a prediction was made}
#' \item{forecast_date}{the date on which a prediction was made}
#' \item{predicted_label}{outcome that a probability corresponds to}
#' \item{predicted}{predicted probability for the corresponding outcome}
#' \item{model}{name of the model that generated the forecasts}
#' \item{horizon}{forecast horizon in weeks}
#' }
# nolint start
#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
# nolint end
"example_nominal"
21 changes: 21 additions & 0 deletions R/default-scoring-rules.R
@@ -68,6 +68,27 @@ metrics_binary <- function(select = NULL, exclude = NULL) {
}


#' @title Scoring rules for nominal forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for nominal forecasts.
#'
#' The default scoring rules are:
#' - "log_score" = [logs_nominal()]
#' @inherit select_metrics params return
#' @export
#' @keywords metric
#' @examples
#' metrics_nominal()
#' metrics_nominal(select = "log_score")
metrics_nominal <- function(select = NULL, exclude = NULL) {
  all <- list(
    log_score = logs_nominal
  )
  selected <- select_metrics(all, select, exclude)
  return(selected)
}
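Because `score()` accepts any named list of functions through its `metrics` argument, the default list returned here can be extended. A hedged sketch: it assumes, by analogy with `assert_input_nominal()` above, that nominal scoring rules receive `observed` as a factor, `predicted` as an n x N probability matrix, and `predicted_label` as a factor of outcome labels; the `zero_one_loss` rule is hypothetical and not part of the package.

```r
# hypothetical extra scoring rule: 0/1 loss on the modal predicted category
zero_one_loss <- function(observed, predicted, predicted_label) {
  modal <- predicted_label[max.col(as.matrix(predicted), ties.method = "first")]
  as.numeric(modal != observed)
}

# combine with the default log score and pass the list to score()
my_metrics <- c(metrics_nominal(), list(zero_one = zero_one_loss))
# scores <- score(forecast_nominal, metrics = my_metrics)
```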


#' @title Default metrics and scoring rules for point forecasts
#' @description
#' Helper function that returns a named list of default
10 changes: 9 additions & 1 deletion R/documentation-templates.R
@@ -6,6 +6,7 @@
#' moment, those are:
#' - point forecasts
#' - binary forecasts ("soft binary classification")
#' - nominal forecasts ("soft classification with multiple unordered classes")
#' - Probabilistic forecasts in a quantile-based format (a forecast is
#' represented as a set of predictive quantiles)
#' - Probabilistic forecasts in a sample-based format (a forecast is represented
@@ -33,6 +34,13 @@
#' corresponding to the probability that `observed` is equal to the second
#' factor level. See details [here][brier_score()] for more information.
#'
#' *Nominal forecasts* require a column `observed` of type factor with N levels
#' (where N is the number of possible outcomes), a column `predicted` of type
#' numeric with probabilities (which sum to one across all possible outcomes),
#' and a column `predicted_label` of type factor with N levels, denoting the
#' outcome for which a probability is given. Forecasts must be complete, i.e.
#' there must be a probability assigned to every possible outcome.
#'
#' *Quantile-based forecasts* require a column `observed` of type numeric,
#' a column `predicted` of type numeric, and a column `quantile_level` of type
#' numeric with quantile-levels (between 0 and 1).
@@ -43,7 +51,7 @@
#'
#' For more information see the vignettes and the example data
#' ([example_quantile], [example_sample_continuous], [example_sample_discrete],
#' [example_point()], and [example_binary]).
#' [example_point()], [example_binary], and [example_nominal]).
#'
#' @details # Forecast unit
#'
85 changes: 81 additions & 4 deletions R/forecast.R
@@ -288,6 +288,39 @@ as_forecast_sample <- function(data,
return(data)
}


#' @title Create a `forecast` object for nominal forecasts
#' @description
#' Nominal forecasts are a form of categorical forecasts where the possible
#' outcomes that the observed values can assume are not ordered. In that sense,
#' nominal forecasts represent a generalisation of binary forecasts.
#' @inheritParams as_forecast
#' @param predicted_label (optional) Name of the column in `data` that denotes
#' the outcome to which a predicted probability corresponds.
#' This column will be renamed to "predicted_label". Only applicable to
#' nominal forecasts.
#' @family functions to create forecast objects
#' @keywords as_forecast
#' @export
as_forecast_nominal <- function(data,
                                forecast_unit = NULL,
                                observed = NULL,
                                predicted = NULL,
                                model = NULL,
                                predicted_label = NULL) {
  assert_character(predicted_label, len = 1, null.ok = TRUE)
  assert_subset(predicted_label, names(data), empty.ok = TRUE)
  if (!is.null(predicted_label)) {
    setnames(data, old = predicted_label, new = "predicted_label")
  }

  data <- as_forecast_generic(data, forecast_unit, observed, predicted, model)
  data <- new_forecast(data, "forecast_nominal")
  assert_forecast(data)
  return(data)
}
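A small usage sketch for the `predicted_label` argument, with hypothetical data (column names other than the required ones are made up for illustration, and the data is assumed to pass the validation implemented below):

```r
library(data.table)
library(scoringutils)

dat <- data.table(
  model = "example-model",
  target = "incidence category",
  observed = factor(rep("medium", 3), levels = c("low", "medium", "high")),
  outcome = factor(c("low", "medium", "high"),
                   levels = c("low", "medium", "high")),
  predicted = c(0.2, 0.5, 0.3)
)

# rename the "outcome" column to "predicted_label" while validating
forecast <- as_forecast_nominal(dat, predicted_label = "outcome")
# score(forecast)
```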


#' @title Assert that input is a forecast object and passes validations
#'
#' @description
@@ -421,6 +454,48 @@ assert_forecast.forecast_sample <- function(
}


#' @export
#' @keywords check-forecasts
#' @importFrom checkmate assert_names assert_set_equal test_set_equal
assert_forecast.forecast_nominal <- function(
  forecast, forecast_type = NULL, verbose = TRUE, ...
) {
  forecast <- assert_forecast_generic(forecast, verbose)
  assert(check_columns_present(forecast, "predicted_label"))
  assert_names(
    colnames(forecast),
    disjunct.from = c("sample_id", "quantile_level")
  )
  assert_forecast_type(forecast, actual = "nominal", desired = forecast_type)

  # levels need to be the same
  outcomes <- levels(forecast$observed)
  assert_set_equal(levels(forecast$predicted_label), outcomes)

  # forecasts need to be complete
  forecast_unit <- get_forecast_unit(forecast)
  complete <- forecast[, .(
    correct = test_set_equal(as.character(predicted_label), outcomes)
  ), by = forecast_unit]

  if (!all(complete$correct)) {
    # identify the first forecast unit with missing outcome probabilities
    first_issue <- complete[!(correct), ..forecast_unit][1]
    first_issue <- lapply(first_issue, FUN = as.character)
    #nolint start: keyword_quote_linter object_usage_linter duplicate_argument_linter
    issue_location <- paste(names(first_issue), "==", first_issue)
    cli_abort(
      c(
        `!` = "Found incomplete forecasts",
        `i` = "For a nominal forecast, all possible outcomes must be assigned
        a probability explicitly.",
        `i` = "Found first missing probabilities in the forecast identified by
        {.emph {issue_location}}"
      )
    )
    #nolint end
  }
  return(forecast[])
}
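To illustrate the completeness check above, a hedged sketch in the spirit of the `as_forecast_nominal()` example earlier: a forecast that omits one of the possible outcomes should be rejected.

```r
library(data.table)
library(scoringutils)

# probabilities are only assigned to two of the three possible outcomes
incomplete <- data.table(
  model = "example-model",
  target = "incidence category",
  observed = factor(rep("medium", 2), levels = c("low", "medium", "high")),
  predicted_label = factor(c("low", "medium"),
                           levels = c("low", "medium", "high")),
  predicted = c(0.2, 0.5)
)

# expected to abort with "Found incomplete forecasts"
try(as_forecast_nominal(incomplete))
```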


#' @title Re-validate an existing forecast object
#'
#' @description
@@ -579,28 +654,30 @@ is_forecast <- function(x) {

#' @export
#' @rdname is_forecast
#' @keywords validate-forecast-object
is_forecast_sample <- function(x) {
inherits(x, "forecast_sample") && inherits(x, "forecast")
}

#' @export
#' @rdname is_forecast
#' @keywords validate-forecast-object
is_forecast_binary <- function(x) {
inherits(x, "forecast_binary") && inherits(x, "forecast")
}

#' @export
#' @rdname is_forecast
#' @keywords validate-forecast-object
is_forecast_point <- function(x) {
inherits(x, "forecast_point") && inherits(x, "forecast")
}

#' @export
#' @rdname is_forecast
#' @keywords validate-forecast-object
is_forecast_quantile <- function(x) {
inherits(x, "forecast_quantile") && inherits(x, "forecast")
}

#' @export
#' @rdname is_forecast
is_forecast_nominal <- function(x) {
inherits(x, "forecast_nominal") && inherits(x, "forecast")
}