Skip to content

Commit

Permalink
Add an argument forecast_type to as_forecast()
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse authored and seabbs committed Feb 26, 2024
1 parent 00952b5 commit 842feac
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
16 changes: 16 additions & 0 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ as_forecast <- function(data,
#' If `NULL` (the default), all columns that are not required columns are
#' assumed to form the unit of a single forecast. If specified, all columns
#' that are not part of the forecast unit (or required columns) will be removed.
#' @param forecast_type (optional) The forecast type you expect the forecasts
#' to have. If the forecast type as determined by `scoringutils` based on the
#' input does not match this, an error will be thrown. If `NULL` (the default),
#' the forecast type will be inferred from the data.
#' @param observed (optional) Name of the column in `data` that contains the
#' observed values. This column will be renamed to "observed".
#' @param predicted (optional) Name of the column in `data` that contains the
Expand All @@ -61,6 +65,7 @@ as_forecast <- function(data,
#' @export
as_forecast.default <- function(data,
forecast_unit = NULL,
forecast_type = NULL,
observed = NULL,
predicted = NULL,
model = NULL,
Expand Down Expand Up @@ -110,8 +115,19 @@ as_forecast.default <- function(data,
}

# find forecast type
desired_forecast_type <- forecast_type
forecast_type <- get_forecast_type(data)

if (!is.null(desired_forecast_type)) {
if (forecast_type != desired_forecast_type) {
stop(
"Forecast type determined by scoringutils based on input: `",
forecast_type,
"`. Desired forecast type: `", desired_forecast_type, "`."
)
}
}

# construct class
data <- new_forecast(data, paste0("forecast_", forecast_type))

Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test-as_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ test_that("as_forecast() works as expected", {
"target_end_date", "horizon"),
sample_id = "sample")
)

# test if desired forecast type does not correspond to inferred one
test <- na.omit(data.table::copy(example_continuous))
expect_error(
as_forecast(test, forecast_type = "quantile"),
"Forecast type determined by scoringutils based on input"
)
})


Expand Down

0 comments on commit 842feac

Please sign in to comment.