Skip to content

Commit

Permalink
Replace get_complete_forecasts() by na.omit()
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse committed Nov 20, 2023
1 parent b0ef998 commit bbca6c0
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 61 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ export(correlation)
export(crps_sample)
export(dispersion)
export(dss_sample)
export(get_complete_forecasts)
export(get_duplicate_forecasts)
export(get_forecast_unit)
export(interval_coverage_deviation_quantile)
Expand Down
2 changes: 1 addition & 1 deletion R/available_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ available_forecasts <- function(data,

data <- validate(data)
forecast_unit <- attr(data, "forecast_unit")
data <- get_complete_forecasts(data)
data <- na.omit(data)

if (is.null(by)) {
by <- forecast_unit
Expand Down
22 changes: 0 additions & 22 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -269,25 +269,3 @@ get_scoringutils_attributes <- function(object) {
}
return(attr_list)
}


#' Get Complete Forecasts
#' @description Helper function to remove rows from a data.frame where the
#' value in either one of the columns `predicted` or `observed` is `NA`.
#' @inheritParams score
#' @return A data.table with the same columns as the input, but
#' without rows where either `predicted` or `observed` is `NA`.
#' @export
#' @keywords check-forecasts
get_complete_forecasts <- function(data) {
data <- ensure_data.table(data)
assert(check_columns_present(data, c("observed", "predicted")))
data <- data[!is.na(observed) & !is.na(predicted)]
if (nrow(data) == 0) {
stop(
"After removing NA values in `observed` and `predicted`, ",
"there were no observations left"
)
}
return(data[])
}
2 changes: 1 addition & 1 deletion R/pit.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ pit <- function(data,
n_replicates = 100) {

data <- validate(data)
data <- get_complete_forecasts(data)
data <- na.omit(data)
forecast_type <- get_forecast_type(data)

if (forecast_type == "quantile") {
Expand Down
8 changes: 4 additions & 4 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ score.default <- function(data, ...) {
#' @export
score.scoringutils_binary <- function(data, metrics = metrics_binary, ...) {
data <- validate(data)
data <- get_complete_forecasts(data)
data <- na.omit(data)
metrics <- validate_metrics(metrics)

data <- apply_metrics(
Expand All @@ -169,7 +169,7 @@ score.scoringutils_binary <- function(data, metrics = metrics_binary, ...) {
#' @export
score.scoringutils_point <- function(data, metrics = metrics_point, ...) {
data <- validate(data)
data <- get_complete_forecasts(data)
data <- na.omit(data)
metrics <- validate_metrics(metrics)

data <- apply_metrics(
Expand All @@ -186,7 +186,7 @@ score.scoringutils_point <- function(data, metrics = metrics_point, ...) {
#' @export
score.scoringutils_sample <- function(data, metrics = metrics_sample, ...) {
data <- validate(data)
data <- get_complete_forecasts(data)
data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
metrics <- validate_metrics(metrics)

Expand Down Expand Up @@ -223,7 +223,7 @@ score.scoringutils_sample <- function(data, metrics = metrics_sample, ...) {
#' @export
score.scoringutils_quantile <- function(data, metrics = metrics_quantile, ...) {
data <- validate(data)
data <- get_complete_forecasts(data)
data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
metrics <- validate_metrics(metrics)

Expand Down
6 changes: 6 additions & 0 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ validate_general <- function(data) {
setattr(data, "messages", messages)
}

if (nrow(na.omit(data)) == 0) {
stop(
"After removing rows with NA values in the data, nothing is left."
)
}

return(data[])
}

Expand Down
20 changes: 0 additions & 20 deletions man/get_complete_forecasts.Rd

This file was deleted.

20 changes: 8 additions & 12 deletions tests/testthat/test-get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,21 @@ fixed = TRUE


# ==============================================================================
# `get_complete_forecasts()`
# Test removing `NA` values from the data
# ==============================================================================
test_that("get_complete_forecasts() works as expected", {
expect_equal(nrow(get_complete_forecasts(example_quantile)), 20401)
test_that("removing NA rows from data works as expected", {
expect_equal(nrow(na.omit(example_quantile)), 20401)

ex <- data.frame(observed = c(NA, 1:3), predicted = 1:4)
expect_equal(nrow(get_complete_forecasts(ex)), 3)
expect_equal(nrow(na.omit(ex)), 3)

ex$predicted <- c(1:3, NA)
expect_equal(nrow(get_complete_forecasts(ex)), 2)
expect_equal(nrow(na.omit(ex)), 2)

ex <- data.table::copy(example_quantile)[, "predicted" := NA_real_]
expect_error(
get_complete_forecasts(data.frame(x = 1:2, y = 1:2)),
"Assertion on 'data' failed: Columns 'observed', 'predicted' not found in data."
)

expect_error(
get_complete_forecasts(data.frame(observed = c(NA, NA), predicted = 1:2)),
"After removing NA values in `observed` and `predicted`, there were no observations left"
validate(ex),
"After removing rows with NA values in the data, nothing is left."
)
})

Expand Down

0 comments on commit bbca6c0

Please sign in to comment.