Issue #900: Remove observed and predicted columns from output for point and binary forecasts (#904)

Merged: 2 commits, Sep 12, 2024
NEWS.md (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` paper
 - `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also mess up e.g. the grouping of forecasts according to the unit of a single forecast.
 - `score()` and many other functions now require a validated `forecast` object. `forecast` objects can be created using the functions `as_forecast_point()`, `as_forecast_binary()`, `as_forecast_quantile()`, and `as_forecast_sample()` (which replace the previous `check_forecast()`). A forecast object is a data.table with class `forecast` and an additional class corresponding to the forecast type (e.g. `forecast_quantile`).
 - `score()` now returns objects of class `scores` with a stored attribute `metrics` that holds the names of the scoring rules that were used. Users can call `get_metrics()` to access the names of those scoring rules.
-- `score()` now returns one score per forecast, instead of one score per sample or quantile.
+- `score()` now returns one score per forecast, instead of one score per sample or quantile. For binary and point forecasts, the columns "observed" and "predicted" are now removed for consistency with the other forecast types.
 - Users can now also use their own scoring rules (making use of the `metrics` argument, which takes a named list of functions). Default scoring rules can be accessed using the functions `metrics_point()`, `metrics_sample()`, `metrics_quantile()`, `metrics_binary()`, and `metrics_nominal()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
 - Instead of supplying arguments to `score()` to manipulate individual scoring rules, users should now manipulate the metric list being supplied using `purrr::partial()` and `select_metric()`. See `?score()` for more information.
 - The CRPS is now reported as a decomposition into dispersion, overprediction and underprediction.
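For context, a minimal sketch of what the highlighted NEWS entry means for users. This assumes a development version of scoringutils from around this PR, where the bundled `example_binary` data can be validated with `as_forecast_binary()` (in other versions the example data may already be a forecast object):

```r
library(scoringutils)

# Validate the bundled binary example data, then score it.
forecast <- as_forecast_binary(example_binary)
scores <- score(forecast)

# After this change, the output no longer carries the inputs: it contains
# only the forecast unit columns plus one column per scoring rule
# (e.g. "brier_score", "log_score").
any(c("observed", "predicted") %in% names(scores))  # FALSE
```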
R/score.R (2 changes: 2 additions & 0 deletions)
@@ -119,6 +119,7 @@ score.forecast_binary <- function(forecast, metrics = metrics_binary(), ...) {
     forecast, metrics,
     forecast$observed, forecast$predicted
   )
+  scores[, `:=`(predicted = NULL, observed = NULL)]

   scores <- as_scores(scores, metrics = names(metrics))
   return(scores[])
@@ -170,6 +171,7 @@ score.forecast_point <- function(forecast, metrics = metrics_point(), ...) {
     forecast, metrics,
     forecast$observed, forecast$predicted
   )
+  scores[, `:=`(predicted = NULL, observed = NULL)]

   scores <- as_scores(scores, metrics = names(metrics))
   return(scores[])
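Both additions use the same data.table idiom: assigning `NULL` inside `:=` deletes columns by reference, without copying the table. A self-contained toy example:

```r
library(data.table)

# Assigning NULL with `:=` removes columns in place; no copy is made.
dt <- data.table(id = 1:3, observed = c(0, 1, 1), predicted = c(0.2, 0.9, 0.7))
dt[, `:=`(predicted = NULL, observed = NULL)]
print(dt)  # only "id" remains
```

Because the removal happens by reference, it adds essentially no overhead even for large score tables.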
tests/testthat/test-score.R (13 changes: 12 additions & 1 deletion)
@@ -65,6 +65,12 @@ test_that("function throws an error if data is not a forecast object", {

 # test binary case -------------------------------------------------------------
 test_that("function produces output for a binary case", {
+
+  expect_equal(
+    names(scores_binary),
+    c(get_forecast_unit(example_binary), names(metrics_binary()))
+  )
+
   eval <- summarise_scores(scores_binary, by = c("model", "target_type"))

   expect_equal(
@@ -128,6 +134,11 @@ test_that(

 # test point case --------------------------------------------------------------
 test_that("function produces output for a point case", {
+  expect_equal(
+    names(scores_point),
+    c(get_forecast_unit(example_point), names(metrics_point()))
+  )
+
   eval <- summarise_scores(scores_point, by = c("model", "target_type"))

   expect_equal(
@@ -366,7 +377,7 @@ test_that("`[` preserves attributes", {
   test <- data.table::copy(scores_binary)
   class(test) <- c("scores", "data.frame")
   expect_true("metrics" %in% names(attributes(test)))
-  expect_true("metrics" %in% names(attributes(test[1:10])))
+  expect_true("metrics" %in% names(attributes(test[1:9])))
 })


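The subsetting index in the last test shrinks from `1:10` to `1:9`, presumably because single-argument `[` on a data.frame selects columns, and the binary scores table has fewer columns once `observed` and `predicted` are dropped. For readers unfamiliar with how the `metrics` attribute can survive `[`, here is a hypothetical sketch of such a method; it is not necessarily scoringutils' actual implementation:

```r
# Hypothetical sketch of a `[` method for the "scores" class that carries
# the "metrics" attribute through subsetting.
`[.scores` <- function(x, ...) {
  ret <- NextMethod()  # defer to the underlying data.table/data.frame method
  attr(ret, "metrics") <- attr(x, "metrics")
  ret
}
```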