Skip to content

Commit

Permalink
update default args to assert_forecast_type()
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse authored and seabbs committed Mar 27, 2024
1 parent e545da3 commit 232a832
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
11 changes: 5 additions & 6 deletions R/forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ as_forecast.default <- function(data,
}

# assert forecast type is as expected
assert_forecast_type(data, forecast_type)
assert_forecast_type(data, desired = forecast_type)
forecast_type <- get_forecast_type(data)

# produce warning if old format is suspected
Expand Down Expand Up @@ -205,7 +205,7 @@ validate_forecast.default <- function(data, forecast_type = NULL, ...) {
#' @keywords check-forecasts
validate_forecast.forecast_binary <- function(data, forecast_type = NULL, ...) {
data <- validate_general(data)
assert_forecast_type(data, forecast_type)
assert_forecast_type(data, actual = "binary", desired = forecast_type)

columns_correct <- test_columns_not_present(
data, c("sample_id", "quantile_level")
Expand Down Expand Up @@ -239,7 +239,7 @@ validate_forecast.forecast_binary <- function(data, forecast_type = NULL, ...) {
#' @keywords check-forecasts
validate_forecast.forecast_point <- function(data, forecast_type = NULL, ...) {
data <- validate_general(data)
assert_forecast_type(data, forecast_type)
assert_forecast_type(data, actual = "point", desired = forecast_type)
#nolint start: keyword_quote_linter object_usage_linter
input_check <- check_input_point(data$observed, data$predicted)
if (!is.logical(input_check)) {
Expand All @@ -261,7 +261,7 @@ validate_forecast.forecast_point <- function(data, forecast_type = NULL, ...) {
validate_forecast.forecast_quantile <- function(data,
forecast_type = NULL, ...) {
data <- validate_general(data)
assert_forecast_type(data, forecast_type)
assert_forecast_type(data, actual = "quantile", desired = forecast_type)
assert_numeric(data$quantile_level, lower = 0, upper = 1)
return(data[])
}
Expand All @@ -271,9 +271,8 @@ validate_forecast.forecast_quantile <- function(data,
#' @rdname validate_forecast
#' @keywords check-forecasts
validate_forecast.forecast_sample <- function(data, forecast_type = NULL, ...) {

data <- validate_general(data)
assert_forecast_type(data, forecast_type)
assert_forecast_type(data, actual = "sample", desired = forecast_type)
return(data[])
}

Expand Down
16 changes: 9 additions & 7 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,22 @@ test_forecast_type_is_quantile <- function(data) {

#' Assert that forecast type is as expected
#' @param data A forecast object as produced by [as_forecast()].
#' @inheritParams as_forecast
#' @param actual The actual forecast type of the data
#' @param desired The desired forecast type of the data
#' @inherit document_assert_functions return
#' @importFrom cli cli_abort
#' @importFrom checkmate assert_character
assert_forecast_type <- function(data, forecast_type = NULL) {
assert_character(forecast_type, null.ok = TRUE)
desired <- forecast_type
forecast_type <- get_forecast_type(data)
if (!is.null(desired) && desired != forecast_type) {
#' @keywords internal_input_check
assert_forecast_type <- function(data,
actual = get_forecast_type(data),
desired = NULL) {
assert_character(desired, null.ok = TRUE)
if (!is.null(desired) && desired != actual) {
#nolint start: object_usage_linter keyword_quote_linter
cli_abort(
c(
"!" = "Forecast type determined by scoringutils based on input:
{.val {forecast_type}}.",
{.val {actual}}.",
"i" = "Desired forecast type: {.val {desired}}."
)
)
Expand Down
9 changes: 4 additions & 5 deletions man/assert_forecast_type.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 232a832

Please sign in to comment.