Skip to content

Commit

Permalink
move tests around
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse committed Oct 6, 2024
1 parent 8f4d680 commit 95ee27c
Show file tree
Hide file tree
Showing 87 changed files with 2,616 additions and 2,445 deletions.
44 changes: 43 additions & 1 deletion R/check-input-helpers.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,46 @@
# to be deleted
#' Ensure that an object is a `data.table`
#' @description
#' This function ensures that an object is a `data table`.
#' If the object is not a data table, it is converted to one. If the object
#' is a data table, a copy of the object is returned.
#' @param data An object to ensure is a data table.
#' @return A data.table/a copy of an existing data.table.
#' @keywords internal
#' @importFrom data.table copy is.data.table as.data.table
ensure_data.table <- function(data) {
if (is.data.table(data)) {
data <- copy(data)
} else {
data <- as.data.table(data)
}
return(data)
}


#' @title Check whether an input is an atomic vector of mode 'numeric'
#'
#' @description Helper function to check whether an input is a numeric vector.
#' @param x input to check
#' @inheritDotParams checkmate::check_numeric
#' @importFrom checkmate check_atomic_vector check_numeric
#' @inherit document_check_functions return
#' @keywords internal_input_check
check_numeric_vector <- function(x, ...) {
# check functions must return TRUE on success
# and a custom error message otherwise
numeric <- check_numeric(x, ...)
vector <- check_atomic_vector(x)
if (!isTRUE(numeric)) {
return(numeric)
} else if (!isTRUE(vector)) {
return(vector)
}
return(TRUE)
}


# ==============================================================================
# functinos below will be deleted in the future

#' @title Helper function to convert assert statements into checks
#'
Expand Down
98 changes: 49 additions & 49 deletions R/class-forecast-nominal.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,34 @@
#' @title Create a `forecast` object for nominal forecasts
#' @description
#' Nominal forecasts are a form of categorical forecasts where the possible
#' outcomes that the observed values can assume are not ordered. In that sense,
#' Nominal forecasts represent a generalisation of binary forecasts.
#' @inheritParams as_forecast
#' @param predicted_label (optional) Name of the column in `data` that denotes
#' the outcome to which a predicted probability corresponds to.
#' This column will be renamed to "predicted_label". Only applicable to
#' nominal forecasts.
#' @family functions to create forecast objects
#' @keywords as_forecast
#' @export
as_forecast_nominal <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
predicted_label = NULL) {
assert_character(predicted_label, len = 1, null.ok = TRUE)
assert_subset(predicted_label, names(data), empty.ok = TRUE)
if (!is.null(predicted_label)) {
setnames(data, old = predicted_label, new = "predicted_label")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_nominal")
assert_forecast(data)
return(data)
}


#' @export
#' @keywords check-forecasts
#' @importFrom checkmate assert_names assert_set_equal test_set_equal
Expand Down Expand Up @@ -40,62 +71,13 @@ assert_forecast.forecast_nominal <- function(
}


#' @title Create a `forecast` object for nominal forecasts
#' @description
#' Nominal forecasts are a form of categorical forecasts where the possible
#' outcomes that the observed values can assume are not ordered. In that sense,
#' Nominal forecasts represent a generalisation of binary forecasts.
#' @inheritParams as_forecast
#' @param predicted_label (optional) Name of the column in `data` that denotes
#' the outcome to which a predicted probability corresponds to.
#' This column will be renamed to "predicted_label". Only applicable to
#' nominal forecasts.
#' @family functions to create forecast objects
#' @keywords as_forecast
#' @export
as_forecast_nominal <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
predicted_label = NULL) {
assert_character(predicted_label, len = 1, null.ok = TRUE)
assert_subset(predicted_label, names(data), empty.ok = TRUE)
if (!is.null(predicted_label)) {
setnames(data, old = predicted_label, new = "predicted_label")
}

data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_nominal")
assert_forecast(data)
return(data)
}


#' @export
#' @rdname is_forecast
is_forecast_nominal <- function(x) {
inherits(x, "forecast_nominal") && inherits(x, "forecast")
}


#' Get default metrics for nominal forecasts
#' @inheritParams get_metrics.forecast_binary
#' @description
#' For nominal forecasts, the default scoring rule is:
#' - "log_score" = [logs_nominal()]
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' get_metrics(example_nominal)
get_metrics.forecast_nominal <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
log_score = logs_nominal
)
select_metrics(all, select, exclude)
}


#' @importFrom stats na.omit
#' @importFrom data.table setattr
#' @rdname score
Expand Down Expand Up @@ -127,6 +109,24 @@ score.forecast_nominal <- function(forecast, metrics = get_metrics(forecast), ..
}


#' Get default metrics for nominal forecasts
#' @inheritParams get_metrics.forecast_binary
#' @description
#' For nominal forecasts, the default scoring rule is:
#' - "log_score" = [logs_nominal()]
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' get_metrics(example_nominal)
get_metrics.forecast_nominal <- function(x, select = NULL, exclude = NULL, ...) {
all <- list(
log_score = logs_nominal
)
select_metrics(all, select, exclude)
}


#' Nominal example data
#'
#' A data set with predictions for COVID-19 cases and deaths submitted to the
Expand Down
58 changes: 29 additions & 29 deletions R/class-forecast-point.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
#' @title Create a `forecast` object for point forecasts
#' @description
#' Create a `forecast` object for point forecasts. See more information on
#' forecast types and expected input formats by calling `?`[as_forecast()].
#' @inherit as_forecast params
#' @param ... Unused
#' @family functions to create forecast objects
#' @export
#' @keywords as_forecast transform
as_forecast_point <- function(data, ...) {
UseMethod("as_forecast_point")
}


#' @rdname as_forecast_point
#' @export
#' @importFrom cli cli_warn
as_forecast_point.default <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
...) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_point")
assert_forecast(data)
return(data)
}


#' @export
#' @rdname assert_forecast
#' @importFrom cli cli_abort
Expand Down Expand Up @@ -29,35 +58,6 @@ is_forecast_point <- function(x) {
}


#' @title Create a `forecast` object for point forecasts
#' @description
#' Create a `forecast` object for point forecasts. See more information on
#' forecast types and expected input formats by calling `?`[as_forecast()].
#' @inherit as_forecast params
#' @param ... Unused
#' @family functions to create forecast objects
#' @export
#' @keywords as_forecast transform
as_forecast_point <- function(data, ...) {
UseMethod("as_forecast_point")
}


#' @rdname as_forecast_point
#' @export
#' @importFrom cli cli_warn
as_forecast_point.default <- function(data,
forecast_unit = NULL,
observed = NULL,
predicted = NULL,
...) {
data <- as_forecast_generic(data, forecast_unit, observed, predicted)
data <- new_forecast(data, "forecast_point")
assert_forecast(data)
return(data)
}


#' @importFrom Metrics se ae ape
#' @importFrom stats na.omit
#' @importFrom data.table setattr copy
Expand Down
73 changes: 37 additions & 36 deletions R/class-forecast-quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,50 @@ as_forecast_quantile.default <- function(data,
}


#' @export
#' @rdname assert_forecast
#' @keywords validate-forecast-object
assert_forecast.forecast_quantile <- function(
forecast, forecast_type = NULL, verbose = TRUE, ...

Check warning on line 44 in R/class-forecast-quantile.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/class-forecast-quantile.R,line=44,col=4,[indentation_linter] Indentation should be 2 spaces but is 4 spaces.
) {
forecast <- assert_forecast_generic(forecast, verbose)
assert_forecast_type(forecast, actual = "quantile", desired = forecast_type)
assert_numeric(forecast$quantile_level, lower = 0, upper = 1)
return(invisible(NULL))
}


#' @export
#' @rdname is_forecast
is_forecast_quantile <- function(x) {
inherits(x, "forecast_quantile") && inherits(x, "forecast")
}


#' @rdname as_forecast_point
#' @description
#' When converting a `forecast_quantile` object into a `forecast_point` object,
#' the 0.5 quantile is extracted and returned as the point forecast.
#' @export
#' @keywords as_forecast
as_forecast_point.forecast_quantile <- function(data, ...) {
assert_forecast(data, verbose = FALSE)
assert_subset(0.5, unique(data$quantile_level))

# At end of this function, the object will have be turned from a
# forecast_quantile to a forecast_point and we don't want to validate it as a
# forecast_point during the conversion process. The correct class is restored
# at the end.
data <- as.data.table(data)

forecast <- data[quantile_level == 0.5]
forecast[, "quantile_level" := NULL]

point_forecast <- new_forecast(forecast, "forecast_point")
return(point_forecast)
}


#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table rbindlist %like% setattr copy
#' @rdname score
Expand Down Expand Up @@ -90,42 +127,6 @@ score.forecast_quantile <- function(forecast, metrics = get_metrics(forecast), .
return(scores[])
}

#' @export
#' @rdname assert_forecast
#' @keywords validate-forecast-object
assert_forecast.forecast_quantile <- function(
forecast, forecast_type = NULL, verbose = TRUE, ...
) {
forecast <- assert_forecast_generic(forecast, verbose)
assert_forecast_type(forecast, actual = "quantile", desired = forecast_type)
assert_numeric(forecast$quantile_level, lower = 0, upper = 1)
return(invisible(NULL))
}


#' @rdname as_forecast_point
#' @description
#' When converting a `forecast_quantile` object into a `forecast_point` object,
#' the 0.5 quantile is extracted and returned as the point forecast.
#' @export
#' @keywords as_forecast
as_forecast_point.forecast_quantile <- function(data, ...) {
assert_forecast(data, verbose = FALSE)
assert_subset(0.5, unique(data$quantile_level))

# At end of this function, the object will have be turned from a
# forecast_quantile to a forecast_point and we don't want to validate it as a
# forecast_point during the conversion process. The correct class is restored
# at the end.
data <- as.data.table(data)

forecast <- data[quantile_level == 0.5]
forecast[, "quantile_level" := NULL]

point_forecast <- new_forecast(forecast, "forecast_point")
return(point_forecast)
}


#' Get default metrics for quantile-based forecasts
#'
Expand Down
Loading

0 comments on commit 95ee27c

Please sign in to comment.