Skip to content

Commit

Permalink
Create custom [.forecast() method (#884)
Browse files Browse the repository at this point in the history
* Create custom `[.forecast()` method

* Drop validation in print()

* Drop forecast class where necessary

* Adjust tests for earlier warnings

* Add test

* Write custom head and tail methods

* fix merge conflict

* Convert to data.table before subsetting (#892)

* Add setter methods

* Add basic tests for extended [.data.table() features

* Do not attempt to validate atomic vectors

---------

Co-authored-by: Nikos Bosse <[email protected]>
Co-authored-by: nikosbosse <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent 3d4a0ce commit c20e2aa
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 59 deletions.
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method("$<-",forecast)
S3method("[",forecast)
S3method("[<-",forecast)
S3method("[[<-",forecast)
S3method(`[`,scores)
S3method(as_forecast_point,default)
S3method(as_forecast_point,forecast_quantile)
Expand All @@ -11,13 +15,15 @@ S3method(assert_forecast,forecast_nominal)
S3method(assert_forecast,forecast_point)
S3method(assert_forecast,forecast_quantile)
S3method(assert_forecast,forecast_sample)
S3method(head,forecast)
S3method(print,forecast)
S3method(score,default)
S3method(score,forecast_binary)
S3method(score,forecast_nominal)
S3method(score,forecast_point)
S3method(score,forecast_quantile)
S3method(score,forecast_sample)
S3method(tail,forecast)
export(add_relative_skill)
export(ae_median_quantile)
export(ae_median_sample)
Expand Down Expand Up @@ -197,3 +203,5 @@ importFrom(stats,sd)
importFrom(stats,weighted.mean)
importFrom(stats,wilcox.test)
importFrom(utils,combn)
importFrom(utils,head)
importFrom(utils,tail)
4 changes: 3 additions & 1 deletion R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ ensure_model_column <- function(data) {
#' @inherit document_check_functions params return
#' @keywords internal_input_check
check_number_per_forecast <- function(data, forecast_unit) {
data <- ensure_data.table(data)
# This function doesn't return a forecast object so it's fine to unclass it
# to avoid validation error while subsetting
data <- as.data.table(data)
data <- na.omit(data)
# check whether there are the same number of quantiles, samples --------------
data[, scoringutils_InternalNumCheck := length(predicted), by = forecast_unit]
Expand Down
117 changes: 116 additions & 1 deletion R/forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ as_forecast_point.forecast_quantile <- function(data, ...) {
assert_forecast(data, verbose = FALSE)
assert_subset(0.5, unique(data$quantile_level))

# At the end of this function, the object will have been turned from a
# forecast_quantile to a forecast_point and we don't want to validate it as a
# forecast_point during the conversion process. The correct class is restored
# at the end.
data <- as.data.table(data)

forecast <- data[quantile_level == 0.5]
forecast[, "quantile_level" := NULL]

Expand Down Expand Up @@ -474,7 +480,7 @@ assert_forecast.forecast_nominal <- function(

# forecasts need to be complete
forecast_unit <- get_forecast_unit(forecast)
complete <- forecast[, .(
complete <- as.data.table(forecast)[, .(
correct = test_set_equal(as.character(predicted_label), outcomes)
), by = forecast_unit]

Expand Down Expand Up @@ -681,3 +687,112 @@ is_forecast_quantile <- function(x) {
is_forecast_nominal <- function(x) {
  # An object counts as a nominal forecast only when it carries both the
  # specific "forecast_nominal" class and the general "forecast" class.
  has_nominal_class <- inherits(x, "forecast_nominal")
  has_forecast_class <- inherits(x, "forecast")
  has_nominal_class && has_forecast_class
}

#' @export
`[.forecast` <- function(x, ...) {
  # Delegate the actual subsetting to the next method in the dispatch chain.
  subsetted <- NextMethod()

  # Revalidation is skipped in two situations:
  # - at most one argument was supplied inside the brackets (e.g. `x[]`),
  # - the result is not a data.table (e.g. an atomic vector): it clearly
  #   cannot be a valid forecast object, and that is likely what the user
  #   intended.
  needs_validation <- ...length() > 1 &&
    data.table::is.data.table(subsetted)
  if (!needs_validation) {
    return(subsetted)
  }

  # Run the validation but capture any error instead of propagating it, so
  # that an invalid result only produces a warning and is still returned.
  validation <- try(
    assert_forecast(forecast = subsetted, verbose = FALSE),
    silent = TRUE
  )
  if (inherits(validation, "try-error")) {
    cli_warn(c("!" = "Error in validating forecast object: {validation}"))
  }

  subsetted
}

#' @export
`$<-.forecast` <- function(x, ..., value) {
  # Let the next method in the dispatch chain perform the assignment.
  modified <- NextMethod()

  # Re-run validation on the modified object. A failure is captured and
  # reported as a warning only; the modified object is returned either way.
  validation <- try(
    assert_forecast(forecast = modified, verbose = FALSE),
    silent = TRUE
  )
  is_invalid <- inherits(validation, "try-error")
  if (is_invalid) {
    cli_warn(c("!" = "Error in validating forecast object: {validation}"))
  }

  modified
}

#' @export
`[[<-.forecast` <- function(x, ..., value) {
  # Let the next method in the dispatch chain perform the assignment.
  modified <- NextMethod()

  # Validate the object after the assignment. Errors raised by the validation
  # are captured so they can be surfaced as a warning rather than aborting.
  validation <- try(
    assert_forecast(forecast = modified, verbose = FALSE),
    silent = TRUE
  )
  if (!inherits(validation, "try-error")) {
    return(modified)
  }

  cli_warn(c("!" = "Error in validating forecast object: {validation}"))
  modified
}

#' @export
`[<-.forecast` <- function(x, ..., value) {
  # Let the next method in the dispatch chain perform the assignment.
  modified <- NextMethod()

  # Check whether the object is still a valid forecast after the assignment.
  # A validation error is captured and turned into a warning; the modified
  # object is returned regardless of the outcome.
  validation <- try(
    assert_forecast(forecast = modified, verbose = FALSE),
    silent = TRUE
  )
  validation_failed <- inherits(validation, "try-error")
  if (validation_failed) {
    cli_warn(c("!" = "Error in validating forecast object: {validation}"))
  }

  modified
}

#' @export
#' @importFrom utils head
head.forecast <- function(x, ...) {
  # Convert to a plain data.table before forwarding: this avoids validation in
  # a situation where we expect (and don't care) that the object is
  # invalidated. The generic is called with an explicit namespace, in the same
  # way as `tail.forecast()` calls `utils::tail()`.
  utils::head(as.data.table(x), ...)
}

#' @export
#' @importFrom utils tail
tail.forecast <- function(x, ...) {
  # Work on a plain data.table rather than the forecast object: this avoids
  # validation in a situation where we expect (and don't care) that the
  # object is invalidated.
  plain_table <- as.data.table(x)
  utils::tail(plain_table, ...)
}
3 changes: 2 additions & 1 deletion R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ get_forecast_counts <- function(forecast,
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_unit <- get_forecast_unit(forecast)
assert_subset(by, names(forecast))
forecast <- as.data.table(forecast)

# collapse several rows to 1, e.g. treat a set of 10 quantiles as one,
# because they all belong to one single forecast that should be counted once
Expand All @@ -508,7 +509,7 @@ get_forecast_counts <- function(forecast,
forecast <- forecast[forecast[, .I[1], by = collapse_by]$V1]

# count number of rows = number of forecasts
out <- as.data.table(forecast)[, .(count = .N), by = by]
out <- forecast[, .(count = .N), by = by]

# make sure that all combinations in "by" are included in the output (with
# count = 0). To achieve that, take unique values in `forecast` and expand grid
Expand Down
3 changes: 2 additions & 1 deletion R/pit.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ get_pit <- function(forecast,

forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_type <- get_forecast_type(forecast)
forecast <- as.data.table(forecast)

if (forecast_type == "quantile") {
forecast[, quantile_coverage := (observed <= predicted)]
Expand All @@ -162,7 +163,7 @@ get_pit <- function(forecast,
),
by = c(get_forecast_unit(quantile_coverage))
]
return(as.data.table(quantile_coverage)[])
return(quantile_coverage[])
}

# if prediction type is not quantile, calculate PIT values based on samples
Expand Down
16 changes: 2 additions & 14 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,6 @@
#' print(dat)
print.forecast <- function(x, ...) {

# check whether object passes validation
validation <- try(
do.call(assert_forecast, list(forecast = x, verbose = FALSE)),
silent = TRUE
)
if (inherits(validation, "try-error")) {
cli_warn(
c(
"!" = "Error in validating forecast object: {validation}."
)
)
}

# get forecast type, forecast unit and score columns
forecast_type <- try(
do.call(get_forecast_type, list(data = x)),
Expand Down Expand Up @@ -73,7 +60,8 @@ print.forecast <- function(x, ...) {
}

cat("\n")
NextMethod(x, ...)

NextMethod()

return(invisible(x))
}
5 changes: 5 additions & 0 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ score.default <- function(forecast, metrics, ...) {
score.forecast_binary <- function(forecast, metrics = metrics_binary(), ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
metrics <- validate_metrics(metrics)
forecast <- as.data.table(forecast)

scores <- apply_metrics(
forecast, metrics,
Expand All @@ -132,6 +133,7 @@ score.forecast_nominal <- function(forecast, metrics = metrics_nominal(), ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_unit <- get_forecast_unit(forecast)
metrics <- validate_metrics(metrics)
forecast <- as.data.table(forecast)

# transpose the forecasts that belong to the same forecast unit
# make sure the labels and predictions are ordered in the same way
Expand Down Expand Up @@ -162,6 +164,7 @@ score.forecast_nominal <- function(forecast, metrics = metrics_nominal(), ...) {
score.forecast_point <- function(forecast, metrics = metrics_point(), ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
metrics <- validate_metrics(metrics)
forecast <- as.data.table(forecast)

scores <- apply_metrics(
forecast, metrics,
Expand All @@ -180,6 +183,7 @@ score.forecast_sample <- function(forecast, metrics = metrics_sample(), ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_unit <- get_forecast_unit(forecast)
metrics <- validate_metrics(metrics)
forecast <- as.data.table(forecast)

# transpose the forecasts that belong to the same forecast unit
f_transposed <- forecast[, .(predicted = list(predicted),
Expand Down Expand Up @@ -217,6 +221,7 @@ score.forecast_quantile <- function(forecast, metrics = metrics_quantile(), ...)
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_unit <- get_forecast_unit(forecast)
metrics <- validate_metrics(metrics)
forecast <- as.data.table(forecast)

# transpose the forecasts that belong to the same forecast unit
# make sure the quantiles and predictions are ordered in the same way
Expand Down
4 changes: 3 additions & 1 deletion R/utils_data_handling.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ quantile_to_interval_dataframe <- function(forecast,
format = "long",
keep_quantile_col = FALSE,
...) {
forecast <- ensure_data.table(forecast)
# After this transformation, the object will no longer be a valid forecast
# object so we unclass it
forecast <- as.data.table(forecast)

forecast[, boundary := ifelse(quantile_level <= 0.5, "lower", "upper")]
forecast[, interval_range := get_range_from_quantile(quantile_level)]
Expand Down
5 changes: 4 additions & 1 deletion tests/testthat/test-convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ test_that("set_forecast_unit() revalidates a forecast object", {
set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model", "horizon"))
)
expect_error(
set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model")),
# `[.forecast()` will warn even before the error is thrown
suppressWarnings(
set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model"))
),
"There are instances with more than one forecast for the same target."
)
})
Expand Down
62 changes: 61 additions & 1 deletion tests/testthat/test-forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,10 @@ test_that("assert_forecast.forecast_point() works as expected", {
test <- as_forecast_point(test)

# expect an error if column is changed to character after initial validation.
test <- test[, "predicted" := as.character(predicted)]
expect_warning(
test <- test[, "predicted" := as.character(predicted)],
"Input looks like a point forecast, but found the following issue"
)
expect_error(
assert_forecast(test),
"Input looks like a point forecast, but found the following issue"
Expand Down Expand Up @@ -303,3 +306,60 @@ test_that("new_forecast() works as expected with a data.frame", {
c("forecast_quantile", "data.table", "data.frame")
)
})

# ==============================================================================
# [.forecast()
# ==============================================================================

# Checks that modifying a forecast object through `[`, `[<-`, `$<-` and `[[<-`
# emits a warning as soon as the modification is made.
test_that("[.forecast() immediately invalidates on change when necessary", {
# Start from a quantile forecast built from the example data without NA rows.
test <- as_forecast_quantile(na.omit(example_quantile))

# For cols; various ways to drop.
# We use local() to avoid actual deletion in this frame and having to recreate
# the input multiple times
# Drop `observed` by selecting every other column through `[`.
expect_warning(
local(test[, colnames(test) != "observed", with = FALSE]),
"Error in validating"
)

# Drop `observed` through `[<-`.
expect_warning(
local(test[, "observed"] <- NULL),
"Error in validating"
)

# Drop `observed` through `$<-`.
expect_warning(
local(test$observed <- NULL),
"Error in validating"
)

# Drop `observed` through `[[<-`.
expect_warning(
local(test[["observed"]] <- NULL),
"Error in validating"
)

# For rows
# Overwrite row 2 with row 1. No message pattern is given, so any warning
# satisfies this expectation.
# NOTE(review): presumably the warning stems from row 2 duplicating row 1;
# confirm against assert_forecast().
expect_warning(
local(test[2, ] <- test[1, ])
)
})

# Checks that extracting a single column with `test[, location]` signals no
# condition at all.
test_that("[.forecast() doesn't warn on cases where the user likely didn't intend getting a forecast object", {
test <- as_forecast_quantile(na.omit(example_quantile))

# NOTE(review): presumably this returns an atomic vector, which `[.forecast()`
# does not validate; confirm against data.table's `[` semantics.
expect_no_condition(test[, location])
})

# Checks that common data.table subsetting expressions on a forecast object
# run without signalling any condition.
test_that("[.forecast() is compatible with data.table syntax", {

test <- as_forecast_quantile(na.omit(example_quantile))

# Row filter using only the `i` argument.
expect_no_condition(
test[location == "DE"]
)

# Row filter combined with a column selection through `.()` in `j`.
# NOTE(review): the selected columns are presumably sufficient for the result
# to still pass forecast validation; confirm against assert_forecast().
expect_no_condition(
test[target_type == "Cases",
.(location, target_end_date, observed, location_name, forecast_date, quantile_level, predicted, model)]
)

})
Loading

0 comments on commit c20e2aa

Please sign in to comment.