Skip to content

Commit

Permalink
Issue #832 - Make example data pre-validated (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse authored Sep 10, 2024
1 parent 3b59595 commit c011628
Show file tree
Hide file tree
Showing 23 changed files with 81 additions and 82 deletions.
4 changes: 0 additions & 4 deletions R/convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,5 @@ set_forecast_unit <- function(data, forecast_unit) {
assert_subset(forecast_unit, names(data), empty.ok = FALSE)
keep_cols <- c(get_protected_columns(data), forecast_unit)
out <- unique(data[, .SD, .SDcols = keep_cols])
# validate that output remains a valid forecast object if input was one before
if (is_forecast(out)) {
assert_forecast(out)
}
return(out)
}
18 changes: 12 additions & 6 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' @format An object of class `forecast_quantile` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{target_end_date}{the date for which a prediction was made}
Expand Down Expand Up @@ -34,7 +35,8 @@
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' @format An object of class `forecast_point` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{target_end_date}{the date for which a prediction was made}
Expand All @@ -60,7 +62,8 @@
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' @format An object of class `forecast_sample` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{target_end_date}{the date for which a prediction was made}
Expand All @@ -87,7 +90,8 @@
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' @format An object of class `forecast_sample` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{target_end_date}{the date for which a prediction was made}
Expand Down Expand Up @@ -121,7 +125,8 @@
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' @format An object of class `forecast_binary` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{location_name}{name of the country for which a prediction was made}
Expand All @@ -147,7 +152,8 @@
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format A data frame with the following columns:
#' @format An object of class `forecast_nominal` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#' \item{location}{the country for which a prediction was made}
#' \item{target_end_date}{the date for which a prediction was made}
Expand Down
Binary file modified data/example_binary.rda
Binary file not shown.
Binary file modified data/example_nominal.rda
Binary file not shown.
Binary file modified data/example_point.rda
Binary file not shown.
Binary file modified data/example_quantile.rda
Binary file not shown.
Binary file modified data/example_sample_continuous.rda
Binary file not shown.
Binary file modified data/example_sample_discrete.rda
Binary file not shown.
8 changes: 6 additions & 2 deletions inst/create-example-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ usethis::use_data(example_truth_only, overwrite = TRUE)
# merge forecast data and truth data and save
example_quantile <- merge_pred_and_obs(hub_data, truth)
data.table::setDT(example_quantile)
example_quantile <- as_forecast_quantile(example_quantile)
usethis::use_data(example_quantile, overwrite = TRUE)


# create data with point forecasts ---------------------------------------------
example_point <- data.table::copy(example_quantile)
example_point <- example_point[quantile %in% c(NA, 0.5)][, quantile_level := NULL]
example_point <- as_forecast_point(example_point)
usethis::use_data(example_point, overwrite = TRUE)


Expand Down Expand Up @@ -184,12 +186,14 @@ by = c(
# remove unnecessary rows where no predictions are available
example_sample_continuous[is.na(predicted), sample_id := NA]
example_sample_continuous <- unique(example_sample_continuous)
example_sample_continuous <- as_forecast_sample(example_sample_continuous)
usethis::use_data(example_sample_continuous, overwrite = TRUE)


# get integer sample data ------------------------------------------------------
example_sample_discrete <- data.table::copy(example_sample_continuous)
example_sample_discrete <- example_sample_discrete[, predicted := round(predicted)]
example_sample_discrete <- as_forecast_sample(example_sample_discrete)
usethis::use_data(example_sample_discrete, overwrite = TRUE)


Expand Down Expand Up @@ -226,7 +230,7 @@ example_binary[, `:=`(
observed = factor(as.numeric(observed))
)]
example_binary <- unique(example_binary)

example_binary <- as_forecast_binary(example_binary)
usethis::use_data(example_binary, overwrite = TRUE)


Expand Down Expand Up @@ -270,5 +274,5 @@ example_nominal[, `:=`(
observed = factor(observed, levels = c("low", "medium", "high")),
predicted_label = factor(predicted_label, levels = c("low", "medium", "high"))
)]

example_nominal <- as_forecast_nominal(example_nominal)
usethis::use_data(example_nominal, overwrite = TRUE)
3 changes: 2 additions & 1 deletion man/example_binary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/example_nominal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/example_point.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/example_quantile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/example_sample_continuous.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/example_sample_discrete.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 6 additions & 13 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@ metrics_no_cov_no_ae <- metrics_quantile(
example_quantile_df <- as.data.frame(na.omit(example_quantile))
checkmate::assert_number(length(class(example_quantile_df)))

# pre-validated forecast objects
forecast_quantile <- as_forecast_quantile(na.omit(example_quantile))
forecast_sample_continuous <- as_forecast_sample(na.omit(example_sample_continuous))
forecast_sample_discrete <- as_forecast_sample(na.omit(example_sample_discrete))
forecast_point <- as_forecast_point(na.omit(example_point))
forecast_binary <- as_forecast_binary(na.omit(example_binary))
forecast_nominal <- as_forecast_nominal(na.omit(example_nominal))

# pre-computed scores
scores_quantile <- score(forecast_quantile)
scores_sample_continuous <- score(forecast_sample_continuous)
scores_sample_discrete <- score(forecast_sample_discrete)
scores_point <- score(forecast_point)
scores_binary <- score(forecast_binary)
scores_nominal <- score(forecast_nominal)
scores_quantile <- score(example_quantile)
scores_sample_continuous <- score(example_sample_continuous)
scores_sample_discrete <- score(example_sample_discrete)
scores_point <- score(example_point)
scores_binary <- score(example_binary)
scores_nominal <- score(example_nominal)
21 changes: 9 additions & 12 deletions tests/testthat/test-convenience-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
test_that("function transform_forecasts works", {
predictions_original <- example_quantile$predicted
predictions <- example_quantile %>%
as_forecast_quantile() %>%
transform_forecasts(
fun = function(x) pmax(0, x),
append = FALSE
Expand Down Expand Up @@ -58,8 +57,7 @@ test_that("function transform_forecasts works", {
})

test_that("transform_forecasts() outputs an object of class forecast_*", {
ex <- as_forecast_binary(na.omit(example_binary))
transformed <- transform_forecasts(ex, fun = identity, append = FALSE)
transformed <- transform_forecasts(example_binary, fun = identity, append = FALSE)
expect_s3_class(transformed, "forecast_binary")
})

Expand Down Expand Up @@ -106,11 +104,17 @@ test_that("function set_forecast_unit() works", {
# these and see whether the result stays the same.
scores1 <- scores_quantile[order(location, target_end_date, target_type, horizon, model), ]

# test that if setting the forecast unit results in an invalid object,
# a warning occurs.
expect_warning(
set_forecast_unit(example_quantile, "model"),
"Assertion on 'data' failed: There are instances with more"
)

ex2 <- set_forecast_unit(
example_quantile,
c("location", "target_end_date", "target_type", "horizon", "model")
) %>%
as_forecast_quantile()
)
scores2 <- score(na.omit(ex2))
scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ]

Expand Down Expand Up @@ -145,13 +149,6 @@ test_that("set_forecast_unit() revalidates a forecast object", {
expect_no_condition(
set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model", "horizon"))
)
expect_error(
# `[.forecast()` will warn even before the error is thrown
suppressWarnings(
set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model"))
),
"There are instances with more than one forecast for the same target."
)
})


Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test-forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ test_that("as_forecast() function has an error for empty data.frame", {
})

test_that("as_forecast() errors if there is both a sample_id and a quantile_level column", {
example <- data.table::copy(example_quantile)[, sample_id := 1]
example <- as.data.table(example_quantile)[, sample_id := 1]
expect_error(
as_forecast_quantile(example),
"Found columns `quantile_level` and `sample_id`. Only one of these is allowed"
)
})

test_that("as_forecast() warns if there are different numbers of quantiles", {
example <- data.table::copy(example_quantile)[-1000, ]
example <- as.data.table(example_quantile)[-1000, ]
expect_warning(
w <- as_forecast_quantile(na.omit(example)),
"Some forecasts have different numbers of rows"
Expand Down Expand Up @@ -129,7 +129,7 @@ test_that("as_forecast() function throws an error with duplicate forecasts", {
})

test_that("as_forecast_quantile() function warns when no model column is present", {
no_model <- data.table::copy(example_quantile[model == "EuroCOVIDhub-ensemble"])[, model := NULL][]
no_model <- as.data.table(example_quantile[model == "EuroCOVIDhub-ensemble"])[, model := NULL][]
expect_warning(
as_forecast_quantile(no_model),
"There is no column called `model` in the data.")
Expand Down Expand Up @@ -180,7 +180,7 @@ test_that("as_forecast.forecast_nominal() works as expected", {
})

test_that("as_forecast.forecast_nominal() breaks when rows with zero probability are missing", {
ex_faulty <- data.table::copy(example_nominal)
ex_faulty <- as.data.table(example_nominal)
ex_faulty <- ex_faulty[predicted != 0]
expect_warning(
expect_error(
Expand Down Expand Up @@ -228,7 +228,7 @@ test_that("assert_forecast() works as expected", {
})

test_that("assert_forecast.forecast_binary works as expected", {
test <- na.omit(data.table::copy(example_binary))
test <- na.omit(as.data.table(example_binary))
test[, "sample_id" := 1:nrow(test)]

# error if there is a superfluous sample_id column
Expand All @@ -238,7 +238,7 @@ test_that("assert_forecast.forecast_binary works as expected", {
)

# expect error if probabilities are not in [0, 1]
test <- na.omit(data.table::copy(example_binary))
test <- na.omit(as.data.table(example_binary))
test[, "predicted" := predicted + 1]
expect_error(
as_forecast_binary(test),
Expand All @@ -247,7 +247,7 @@ test_that("assert_forecast.forecast_binary works as expected", {
})

test_that("assert_forecast.forecast_point() works as expected", {
test <- na.omit(data.table::copy(example_point))
test <- na.omit(as.data.table(example_point))
test <- as_forecast_point(test)

# expect an error if column is changed to character after initial validation.
Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test-get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# `get_forecast_type`
# ==============================================================================
test_that("get_forecast_type() works as expected", {
expect_equal(get_forecast_type(forecast_quantile), "quantile")
expect_equal(get_forecast_type(forecast_sample_continuous), "sample")
expect_equal(get_forecast_type(forecast_sample_discrete), "sample")
expect_equal(get_forecast_type(forecast_binary), "binary")
expect_equal(get_forecast_type(forecast_point), "point")
expect_equal(get_forecast_type(forecast_nominal), "nominal")
expect_equal(get_forecast_type(example_quantile), "quantile")
expect_equal(get_forecast_type(example_sample_continuous), "sample")
expect_equal(get_forecast_type(example_sample_discrete), "sample")
expect_equal(get_forecast_type(example_binary), "binary")
expect_equal(get_forecast_type(example_point), "point")
expect_equal(get_forecast_type(example_nominal), "nominal")

expect_error(
get_forecast_type(data.frame(x = 1:10)),
Expand Down Expand Up @@ -210,7 +210,7 @@ test_that("get_duplicate_forecasts() works as expected for point", {
test_that("get_duplicate_forecasts() returns the expected class", {
expect_equal(
class(get_duplicate_forecasts(example_point)),
class(example_point)
c("data.table", "data.frame")
)
})

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-input-check-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ test_that("Check equal length works if all arguments have length 1", {
})

test_that("ensure_model_column works", {
test <- data.table::copy(example_binary)
test <- as.data.table(example_binary)
expect_warning(
ensure_model_column(test[, model := NULL]),
"There is no column called `model` in the data."
Expand Down Expand Up @@ -48,7 +48,7 @@ test_that("check_number_per_forecast works", {


test_that("check_duplicates works", {
example_bin <- rbind(example_binary[1:2, ], example_binary[1:2, ])
example_bin <- rbind(example_binary[1000:1002, ], example_binary[1000:1002, ])
expect_identical(
capture.output(
check_duplicates(example_bin)
Expand Down
13 changes: 7 additions & 6 deletions tests/testthat/test-print.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
test_that("print() works on forecast_* objects", {
# Check print works on each forecast object
test_dat <- list(forecast_binary, forecast_quantile,
forecast_point, forecast_sample_continuous,
forecast_sample_discrete)
test_dat <- list(example_binary, example_quantile,
example_point, example_sample_continuous,
example_sample_discrete)
test_dat <- lapply(test_dat, na.omit)
for (dat in test_dat){
forecast_type <- get_forecast_type(dat)
forecast_type <- scoringutils:::get_forecast_type(dat)
forecast_unit <- get_forecast_unit(dat)

fn_name <- paste0("as_forecast_", forecast_type)
Expand All @@ -19,8 +20,8 @@ test_that("print() works on forecast_* objects", {
expect_snapshot(print(dat))

# Check print.data.table works.
output_original <- capture.output(print(dat))
output_test <- capture.output(print(data.table(dat)))
output_original <- suppressMessages(capture.output(print(dat)))
output_test <- suppressMessages(capture.output(print(data.table(dat))))
expect_contains(output_original, output_test)
}
})
Loading

0 comments on commit c011628

Please sign in to comment.