Skip to content

Commit

Permalink
Merge branch 'main' into update-as_forecast()
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse authored Feb 23, 2024
2 parents 062bfb3 + 29986a6 commit 393cc54
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 13 deletions.
38 changes: 30 additions & 8 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,45 @@ ensure_data.table <- function(data) {
#' dat <- as_forecast(example_quantile)
#' print(dat)
print.forecast_binary <- function(x, ...) {
# Obtain forecast object information for printing
forecast_type <- get_forecast_type(x)
score_cols <- get_score_names(x)

# check whether object passes validation
validation <- try(do.call(validate_forecast, list(data = x)), silent = TRUE)
if (inherits(validation, "try-error")) {
validation_msg <- conditionMessage(attr(validation, "condition"))
warning(
"Error in validating forecast object:\n",
validation_msg
)
}

# get forecast type, forecast unit and score columns
forecast_type <- try(
do.call(get_forecast_type, list(data = x)),
silent = TRUE
)
forecast_unit <- get_forecast_unit(x)
score_cols <- get_score_names(x)

# Print forecast object information
cat("Forecast type:\n")
print(forecast_type)
if (inherits(forecast_type, "try-error")) {
message("Could not determine forecast type due to error in validation.")
} else {
cat("Forecast type:\n")
print(forecast_type)
}

if (length(forecast_unit) == 0) {
message("Could not determine forecast unit")
} else {
cat("\nForecast unit:\n")
print(forecast_unit)
}

if (!is.null(score_cols)) {
cat("\nScore columns:\n")
print(score_cols)
}

cat("\nForecast unit:\n")
print(forecast_unit)

cat("\n")
NextMethod(x, ...)

Expand Down
49 changes: 44 additions & 5 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,17 @@ test_that("get_score_names() works as expected", {
)
})


# ==============================================================================
# print
# ==============================================================================

test_that("print() works on forecast_* objects", {
# Check print works on each forecast object
test_dat <- list(example_binary, example_quantile,
example_point, example_continuous, example_integer)
test_dat <- list(na.omit(example_binary), na.omit(example_quantile),
na.omit(example_point), na.omit(example_continuous), na.omit(example_integer))
for (dat in test_dat){
dat <- suppressMessages(as_forecast(dat))
dat <- as_forecast(dat)
forecast_type <- get_forecast_type(dat)
forecast_unit <- get_forecast_unit(dat)

Expand All @@ -103,13 +108,47 @@ test_that("print() works on forecast_* objects", {

# Check Score columns are printed
dat <- example_quantile %>%
na.omit %>%
set_forecast_unit(c("location", "target_end_date",
"target_type", "horizon", "model")) %>%
as_forecast() %>%
add_coverage() %>%
suppressMessages
add_coverage()

expect_output(print(dat), "Score columns")
score_cols <- get_score_names(dat)
expect_output(print(dat), pattern = paste(score_cols, collapse = " "))
})

test_that("print methods fail gracefully", {
test <- as_forecast(na.omit(example_quantile))
test$observed <- NULL

# message if forecast type can't be computed
expect_warning(
expect_message(
expect_output(
print(test),
pattern = "Forecast unit:"
),
"Could not determine forecast type due to error in validation."
),
"Error in validating forecast object:"
)

# message if forecast unit can't be computed
test <- 1:10
class(test) <- "forecast_point"
expect_warning(
expect_message(
expect_message(
expect_output(
print(test),
pattern = "Forecast unit:"
),
"Could not determine forecast unit."
),
"Could not determine forecast type"
),
"Error in validating forecast object:"
)
})

0 comments on commit 393cc54

Please sign in to comment.