Skip to content

Commit

Permalink
Merge pull request #377 from epiforecasts/rework-quantile-to-interval…
Browse files Browse the repository at this point in the history
…-format

Rework quantile to interval format
  • Loading branch information
nikosbosse authored Nov 15, 2023
2 parents a0fad34 + 88acb9a commit 9ba205a
Show file tree
Hide file tree
Showing 14 changed files with 382 additions and 106 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method(plot,scoringutils_available_forecasts)
S3method(print,scoringutils_check)
S3method(quantile_to_interval,data.frame)
S3method(quantile_to_interval,numeric)
S3method(score,default)
S3method(score,scoringutils_binary)
S3method(score,scoringutils_point)
Expand Down
23 changes: 2 additions & 21 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ assert_not_null <- function(...) {
#'
#' @keywords internal
assert_equal_length <- function(...,
one_allowed = TRUE,
call_levels_up = 2) {
one_allowed = TRUE,
call_levels_up = 2) {
vars <- list(...)
lengths <- lengths(vars)

Expand Down Expand Up @@ -283,25 +283,6 @@ check_duplicates <- function(data, forecast_unit = NULL) {
}


# Function to check input for methods
# there should not be a name clash between a metric and a column name
# --> maybe this should be checked by the actual method that computes scores
# check whether any column name is a scoringutils metric
# clashing_colnames <- intersect(colnames(data), available_metrics())
# if (length(clashing_colnames) > 0) {
# clashing_colnames <- paste0('"', clashing_colnames, '"')
# warnings <- c(
# warnings,
# paste0(
# "At least one column in the data ",
# "(", toString(clashing_colnames), ") ",
# "corresponds to the name of a metric that will be computed by ",
# "scoringutils. Please check `available_metrics()`"
# )
# )
# }


#' Check column names are present in a data.frame
#' @description
#' The function loops over the column names and checks whether they are
Expand Down
8 changes: 5 additions & 3 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ get_forecast_type <- function(data) {
} else if (test_forecast_type_is_point(data)) {
forecast_type <- "point"
} else {
stop("Checking `data`: input doesn't satisfy the criteria for any forecast type.",
"Are you missing a column `quantile` or `sample_id`?",
"Please check the vignette for additional info.")
stop(
"Checking `data`: input doesn't satisfy criteria for any forecast type.",
"Are you missing a column `quantile` or `sample_id`?",
"Please check the vignette for additional info."
)
}
conflict <- check_attribute_conflict(data, "forecast_type", forecast_type)
if (!is.logical(conflict)) {
Expand Down
3 changes: 0 additions & 3 deletions R/pairwise-comparisons.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ pairwise_comparison <- function(scores,
metric = "auto",
baseline = NULL,
...) {

# metric_names <- get_metrics(scores)
metric <- match.arg(metric, c("auto", available_metrics()))

if (!is.data.table(scores)) {
scores <- as.data.table(scores)
} else {
Expand Down
2 changes: 1 addition & 1 deletion R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ plot_predictions <- function(data,
# range data

if (test_forecast_type_is_quantile(data)) {
forecasts <- quantile_to_range_long(
forecasts <- quantile_to_interval(
forecasts,
keep_quantile_col = FALSE
)
Expand Down
27 changes: 16 additions & 11 deletions R/score_quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ score_quantile <- function(data,
data <- remove_na_observed_predicted(data)

# make sure to have both quantile as well as range format --------------------
range_data <- quantile_to_range_long(data,
range_data <- quantile_to_interval(
data,
keep_quantile_col = FALSE
)
# adds the range column to the quantile data set
quantile_data <- range_long_to_quantile(range_data,
quantile_data <- range_long_to_quantile(
range_data,
keep_range_col = TRUE
)

Expand Down Expand Up @@ -96,14 +98,15 @@ score_quantile <- function(data,

# compute absolute and squared error for point forecasts
# these are marked by an NA in range, and a numeric value for point
if (any(c("se_point, se_mean, ae_point", "ae_median", "absolute_error") %in% metrics)) {
if ("point" %in% colnames(res)) {
res[
is.na(range) & is.numeric(point),
`:=`(ae_point = abs_error(predicted = point, observed),
se_point = squared_error(predicted = point, observed))
]
}
compute_point <- any(
c("se_point, se_mean, ae_point", "ae_median", "absolute_error") %in% metrics
)
if (compute_point && "point" %in% colnames(res)) {
res[
is.na(range) & is.numeric(point),
`:=`(ae_point = abs_error(predicted = point, observed),
se_point = squared_error(predicted = point, observed))
]
}

# calculate scores on quantile format ----------------------------------------
Expand Down Expand Up @@ -156,7 +159,9 @@ score_quantile <- function(data,
}

# delete internal columns before returning result
res <- delete_columns(res, c("upper", "lower", "boundary", "point", "observed"))
res <- delete_columns(
res, c("upper", "lower", "boundary", "point", "observed")
)

return(res[])
}
7 changes: 5 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,14 @@ remove_scoringutils_class <- function(object) {
return(object)
}
# check if "scoringutils_" is in name of any class
if (any(grepl("scoringutils_", class(object)))) {
if (any(grepl("scoringutils_", class(object), fixed = TRUE))) {
stored_attributes <- get_scoringutils_attributes(object)

# remove all classes that contain "scoringutils_"
class(object) <- class(object)[!grepl("scoringutils_", class(object))]
class(object) <- class(object)[!grepl(
"scoringutils_", class(object),
fixed = TRUE
)]

# remove all scoringutils attributes
object <- strip_attributes(object, names(stored_attributes))
Expand Down
119 changes: 91 additions & 28 deletions R/utils_data_handling.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,50 +158,113 @@ range_long_to_quantile <- function(data,
}


#' @title Change Data from a Plain Quantile Format to a Long Range Format
#'
#' Transform From a Quantile Format to an Interval Format
#' @description
#' **Quantile format**
#' In a quantile format, a prediction is characterised by one or multiple
#' predicted values and the corresponding quantile levels. For example, a
#' prediction in a quantile format could be represented by the 0.05, 0.25, 0.5,
#' 0.75 and 0.95 quantiles of the predictive distribution.
#'
#' Transform data from a format that uses quantiles only to one that uses
#' interval ranges to denote quantiles.
#'
#' @param data a data.frame in quantile format
#' **Interval format**
#' In the interval format, two quantiles are assumed to form a prediction
#' interval. Prediction intervals need to be symmetric around the median and
#' are characterised by a lower and an upper bound. The lower bound is defined
#' by the lower quantile and the upper bound is defined by the upper quantile.
#' A 90% prediction interval, for example, covers 90% of the probability mass
#' and is defined by the 5% and 95% quantiles. A forecast could therefore
#' be characterised by one or multiple prediction intervals, e.g. the lower
#' and upper bounds of the 50% and 90% prediction intervals (corresponding to
#' the 0.25 and 0.75 as well as the 0.05 and 0.95 quantiles).
#' @param ... method arguments
quantile_to_interval <- function(...) {
# S3 generic: UseMethod() dispatches on the class of the first argument
# supplied through `...`.
UseMethod("quantile_to_interval")
}


#' @param dt a data.table with columns `quantile` and `predicted`
#' @param format the format of the output. Either "long" or "wide". If "long"
#' (the default), there will be a column `boundary` (with values either "upper"
#' or "lower") and a column `range` that contains the range of the interval.
#' If "wide", there will be a column `range` and two columns
#' `lower` and `upper` that contain the lower and upper bounds of the
#' prediction interval, respectively.
#' @param keep_quantile_col keep the quantile column in the final
#' output after transformation (default is FALSE)
#' @return a data.frame in a long interval range format
#' output after transformation (default is FALSE). This only works if
#' `format = "long"`. If `format = "wide"`, the quantile column will always be
#' dropped.
#' @return
#' *quantile_to_interval.data.frame*:
#' a data.frame in an interval format (either "long" or "wide"), with or
#' without a quantile column. Rows will not be reordered.
#' @importFrom data.table copy
#' @keywords internal

quantile_to_range_long <- function(data,
keep_quantile_col = TRUE) {
data <- data.table::as.data.table(data)
#' @export
#' @rdname quantile_to_interval
quantile_to_interval.data.frame <- function(dt,
format = "long",
keep_quantile_col = FALSE,
...) {
# Convert quantile-format forecasts (columns `quantile` and `predicted`) into
# an interval format with a `range` column and either a `boundary` column
# (long) or `lower`/`upper` columns (wide).
if (!is.data.table(dt)) {
dt <- data.table::as.data.table(dt)
} else {
# use copy to avoid modifying the caller's data.table by reference through
# the `:=` assignments below
dt <- copy(dt)
}

# NOTE(review): statements operating on `data` look like leftovers from the
# previous implementation; this function's argument is named `dt` -- confirm.
# Quantiles <= 0.5 are treated as lower bounds, all others as upper bounds.
# `range` is the width of the central interval in percent, e.g. the 0.05 and
# 0.95 quantiles both map to range 90. Values are rounded to 10 digits,
# presumably to remove floating-point noise so both bounds share one range.
data[, boundary := ifelse(quantile <= 0.5, "lower", "upper")]
data[, range := ifelse(
dt[, boundary := ifelse(quantile <= 0.5, "lower", "upper")]
dt[, range := ifelse(
boundary == "lower",
round((1 - 2 * quantile) * 100, 10),
round((2 * quantile - 1) * 100, 10)
)]

# add median quantile: the 0.5 quantile was labelled "lower" with range 0
# above, so its rows are duplicated as "upper" to give the 0% interval both
# bounds
median <- data[quantile == 0.5, ]
median <- dt[quantile == 0.5, ]
median[, boundary := "upper"]

data <- data.table::rbindlist(list(data, median))

dt <- data.table::rbindlist(list(dt, median))
# drop the quantile column unless the caller asked to keep it
if (!keep_quantile_col) {
data[, "quantile" := NULL]
dt[, quantile := NULL]
}

# if only point forecasts are scored, we only have NA values for range and
# boundary. In that instance we need to set the type of the columns
# explicitly to avoid future collisions.
data[, `:=`(
boundary = as.character(boundary),
range = as.numeric(range)
)]
# Wide format: spread `predicted` into one column per `boundary` value, so the
# quantile column is dropped regardless of `keep_quantile_col`.
# NOTE(review): the return value of delete_columns() is discarded; this only
# removes `quantile` if that helper modifies `dt` by reference -- confirm.
if (format == "wide") {
delete_columns(dt, "quantile")
dt <- dcast(dt, ... ~ boundary, value.var = "predicted")
}
# `[]` makes the returned data.table print after the `:=` calls
return(dt[])
}

return(data[])

#' @param observed a numeric vector of observed values of size n
#' @param predicted a numeric matrix of predicted values of size n x N (one
#' row per observation, one column per quantile level). If `observed` is a
#' single number, then `predicted` can be a vector of length N
#' @param quantile a numeric vector of quantile levels of size N
#' @return
#' *quantile_to_interval.numeric*:
#' a data.frame in a wide interval format with columns `forecast_id`,
#' `observed`, `lower`, `upper`, and `range`. The `forecast_id` column is a
#' unique identifier for each forecast. Rows will be reordered according to
#' `forecast_id` and `range`.
#' @export
#' @rdname quantile_to_interval
quantile_to_interval.numeric <- function(observed,
                                         predicted,
                                         quantile,
                                         ...) {
  assert_input_quantile(observed, predicted, quantile)

  n <- length(observed)
  N <- length(quantile)

  # Build a long table with one row per (forecast, quantile level).
  # `t(predicted)` flattens the matrix row by row, so the N predictions of
  # each forecast stay together and line up with the repeated ids,
  # observations and quantile levels. seq_len() is used instead of 1:n so a
  # zero-length input cannot produce the ids c(1, 0).
  dt <- data.table(
    forecast_id = rep(seq_len(n), each = N),
    observed = rep(observed, each = N),
    predicted = as.vector(t(predicted)),
    quantile = rep(quantile, times = n)
  )

  # Reuse the data.frame method for the conversion, then sort so the output
  # order is deterministic.
  out <- quantile_to_interval(dt, format = "wide")
  out <- out[order(forecast_id, range)]
  return(out)
}


Expand Down Expand Up @@ -240,7 +303,7 @@ sample_to_range_long <- function(data,
type = type
)

data <- quantile_to_range_long(data, keep_quantile_col = keep_quantile_col)
data <- quantile_to_interval(data, keep_quantile_col = keep_quantile_col)

return(data[])
}
2 changes: 2 additions & 0 deletions R/z_globalVariables.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ globalVariables(c(
"dss",
"existing",
"fill_col",
"forecast_id",
"hist",
"identifCol",
"Interval_Score",
"interval_range",
"overprediction",
"underprediction",
"quantile_coverage",
Expand Down
67 changes: 67 additions & 0 deletions man/quantile_to_interval.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9ba205a

Please sign in to comment.