Commit

Issue #900: Remove observed and predicted columns from output for point and binary forecasts (#904)

* remove observed and predicted columns from output for point and binary forecasts

* update news file
nikosbosse authored Sep 12, 2024
1 parent f1d926d commit 7fa754a
Showing 3 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
@@ -16,7 +16,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` paper
- `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also mess up e.g. grouping of forecasts according to the unit of a single forecast.
- `score()` and many other functions now require a validated `forecast` object. `forecast` objects can be created using the functions `as_forecast_point()`, `as_forecast_binary()`, `as_forecast_quantile()`, and `as_forecast_sample()` (which replace the previous `check_forecast()`). A forecast object is a data.table with class `forecast` and an additional class corresponding to the forecast type (e.g. `forecast_quantile`).
- `score()` now returns objects of class `scores` with a stored attribute `metrics` that holds the names of the scoring rules that were used. Users can call `get_metrics()` to access the names of those scoring rules.
- - `score()` now returns one score per forecast, instead of one score per sample or quantile.
+ - `score()` now returns one score per forecast, instead of one score per sample or quantile. For binary and point forecasts, the columns "observed" and "predicted" are now removed for consistency with the other forecast types.
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the functions `metrics_point()`, `metrics_sample()`, `metrics_quantile()`, `metrics_binary()`, and `metrics_nominal()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Instead of supplying arguments to `score()` to manipulate individual scoring rules users should now manipulate the metric list being supplied using `purrr::partial()` and `select_metric()`. See `?score()` for more information.
- the CRPS is now reported as decomposition into dispersion, overprediction and underprediction.
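The NEWS entry above can be illustrated with a short R session. This is a sketch, assuming the development version of scoringutils that contains this change and its bundled `example_binary` dataset (the same one used in the tests below):

```r
library(scoringutils)

# Validate the bundled example data as a binary forecast object, then score it.
forecast <- as_forecast_binary(example_binary)
scores <- score(forecast)

# After this change, "observed" and "predicted" no longer appear in the
# output: only the forecast unit columns plus one column per metric remain.
c("observed", "predicted") %in% names(scores)

# The names of the scoring rules used are stored in the "metrics" attribute.
get_metrics(scores)
```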
2 changes: 2 additions & 0 deletions R/score.R
@@ -119,6 +119,7 @@ score.forecast_binary <- function(forecast, metrics = metrics_binary(), ...) {
    forecast, metrics,
    forecast$observed, forecast$predicted
  )
+  scores[, `:=`(predicted = NULL, observed = NULL)]

  scores <- as_scores(scores, metrics = names(metrics))
  return(scores[])
@@ -170,6 +171,7 @@ score.forecast_point <- function(forecast, metrics = metrics_point(), ...) {
    forecast, metrics,
    forecast$observed, forecast$predicted
  )
+  scores[, `:=`(predicted = NULL, observed = NULL)]

  scores <- as_scores(scores, metrics = names(metrics))
  return(scores[])
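Both hunks drop the columns using data.table's `:=` assignment by reference: assigning `NULL` to a column deletes it in place, without copying the table. A minimal standalone sketch of the idiom, using a hypothetical toy score table rather than the package's own data:

```r
library(data.table)

# Toy score table containing the two columns this commit removes.
scores <- data.table(
  model = c("a", "b"),
  observed = c(0, 1),
  predicted = c(0.2, 0.8),
  brier_score = c(0.04, 0.04)
)

# Delete both columns by reference, as in score.forecast_binary()
# and score.forecast_point(); no copy of the table is made.
scores[, `:=`(predicted = NULL, observed = NULL)]

names(scores)  # "model" and "brier_score" remain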
13 changes: 12 additions & 1 deletion tests/testthat/test-score.R
@@ -65,6 +65,12 @@ test_that("function throws an error if data is not a forecast object", {

# test binary case -------------------------------------------------------------
test_that("function produces output for a binary case", {

+  expect_equal(
+    names(scores_binary),
+    c(get_forecast_unit(example_binary), names(metrics_binary()))
+  )

  eval <- summarise_scores(scores_binary, by = c("model", "target_type"))

  expect_equal(
@@ -128,6 +134,11 @@ test_that(

# test point case --------------------------------------------------------------
test_that("function produces output for a point case", {
+  expect_equal(
+    names(scores_point),
+    c(get_forecast_unit(example_point), names(metrics_point()))
+  )

  eval <- summarise_scores(scores_point, by = c("model", "target_type"))

  expect_equal(
@@ -366,7 +377,7 @@ test_that("`[` preserves attributes", {
  test <- data.table::copy(scores_binary)
  class(test) <- c("scores", "data.frame")
  expect_true("metrics" %in% names(attributes(test)))
-  expect_true("metrics" %in% names(attributes(test[1:10])))
+  expect_true("metrics" %in% names(attributes(test[1:9])))
})

