Skip to content

Commit

Permalink
Merge branch 'develop' into expose-functions2
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse authored Jan 2, 2024
2 parents de70ac0 + d04a82c commit 5efd52d
Show file tree
Hide file tree
Showing 114 changed files with 1,005 additions and 1,239 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
^Meta$
^_pkgdown\.yml$
^inst/manuscript/manuscript_cache$
^inst/manuscript/.trackdown$
^\.lintr$
^docs$
^\.devcontainer$
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ inst/manuscript/manuscript.blg
inst/manuscript/manuscript.pdf
inst/manuscript/manuscript.tex
inst/manuscript/manuscript_files/
inst/manuscript/.trackdown
docs
..bfg-report/
.DS_Store
Expand Down
2 changes: 1 addition & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ linters: linters_with_tags(
exclusions: c(
list.files("tests", recursive = TRUE, full.names = TRUE),
list.files("inst", recursive = TRUE, full.names = TRUE),
"vignettes/metric-details.Rmd"
list.files("vignettes", pattern = ".R$", full.names = TRUE)
)
exclude: "# nolint"
35 changes: 17 additions & 18 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,56 +1,53 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,scoringutils_available_forecasts)
S3method(as_forecast,default)
S3method(print,scoringutils_check)
S3method(quantile_to_interval,data.frame)
S3method(quantile_to_interval,numeric)
S3method(score,default)
S3method(score,scoringutils_binary)
S3method(score,scoringutils_point)
S3method(score,scoringutils_quantile)
S3method(score,scoringutils_sample)
S3method(validate,default)
S3method(validate,scoringutils_binary)
S3method(validate,scoringutils_point)
S3method(validate,scoringutils_quantile)
S3method(validate,scoringutils_sample)
S3method(score,forecast_binary)
S3method(score,forecast_point)
S3method(score,forecast_quantile)
S3method(score,forecast_sample)
S3method(validate_forecast,forecast_binary)
S3method(validate_forecast,forecast_point)
S3method(validate_forecast,forecast_quantile)
S3method(validate_forecast,forecast_sample)
export(abs_error)
export(add_coverage)
export(add_pairwise_comparison)
export(ae_median_quantile)
export(ae_median_sample)
export(avail_forecasts)
export(available_forecasts)
export(as_forecast)
export(available_metrics)
export(bias_quantile)
export(bias_range)
export(bias_sample)
export(brier_score)
export(correlation)
export(crps_sample)
export(dispersion)
export(dss_sample)
export(get_duplicate_forecasts)
export(get_forecast_counts)
export(get_forecast_type)
export(get_forecast_unit)
export(interval_coverage_deviation_quantile)
export(interval_coverage_dev_quantile)
export(interval_coverage_quantile)
export(interval_coverage_sample)
export(interval_score)
export(log_shift)
export(logs_binary)
export(logs_sample)
export(mad_sample)
export(make_NA)
export(make_na)
export(merge_pred_and_obs)
export(new_scoringutils)
export(new_forecast)
export(overprediction)
export(pairwise_comparison)
export(pit)
export(pit_sample)
export(plot_avail_forecasts)
export(plot_correlation)
export(plot_forecast_counts)
export(plot_heatmap)
export(plot_interval_coverage)
export(plot_pairwise_comparison)
Expand All @@ -61,6 +58,7 @@ export(plot_ranges)
export(plot_score_table)
export(plot_wis)
export(quantile_score)
export(quantile_to_interval)
export(run_safely)
export(sample_to_quantile)
export(score)
Expand All @@ -72,7 +70,7 @@ export(summarize_scores)
export(theme_scoringutils)
export(transform_forecasts)
export(underprediction)
export(validate)
export(validate_forecast)
export(validate_general)
export(wis)
importFrom(Metrics,ae)
Expand All @@ -87,6 +85,7 @@ importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_string)
importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_data_frame)
Expand Down
16 changes: 10 additions & 6 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The update introduces breaking changes. If you want to keep using the older vers

## Package updates
- In `score()`, required columns "true_value" and "prediction" were renamed and replaced by required columns "observed" and "predicted". Scoring functions now also use the function arguments "observed" and "predicted" everywhere consistently.
- The overall scoring workflow was updated. `score()` is now a generic function that dispatches the correct method based on the forecast type. forecast types currently supported are "binary", "point", "sample" and "quantile" with corresponding classes "forecast_binary", "forecast_point", "forecast_sample" and "forecast_quantile". An object of class `forecast_*` can be created using the function `as_forecast()`, which also replaces the previous function `check_forecasts()` (see more information below).
- Scoring functions received a consistent interface and input checks:
- metrics for binary forecasts:
- `observed`: factor with exactly 2 levels
Expand All @@ -20,15 +21,18 @@ The update introduces breaking changes. If you want to keep using the older vers
- `observed`: numeric, either a scalar or a vector
- `predicted`: numeric, a vector (if `observed` is a scalar) or a matrix (if `observed` is a vector)
- `quantile`: numeric, a vector with quantile-levels. Can alternatively be a matrix of the same shape as `predicted`.
- `check_forecasts()` was replaced by a new function `validate()`. `validate()` validates the input and in that sense fulfills the purpose of `check_forecasts()`. It has different methods: `validate.default()` assigns the input a class based on their forecast type. Other methods validate the input specifically for the various forecast types.
- `check_forecasts()` was replaced by a different workflow. There now is a function, `as_forecast()`, that determines forecast type of the data, constructs a forecasting object and validates it using the function `validate_forecast()` (a generic that dispatches the correct method based on the forecast type). Objects of class `forecast_binary`, `forecast_point`, `forecast_sample` and `forecast_quantile` have print methods that fulfill the functionality of `check_forecasts()`.
- The functionality for computing pairwise comparisons was now split from `summarise_scores()`. Instead of doing pairwise comparisons as part of summarising scores, a new function, `add_pairwise_comparison()`, was introduced that takes summarised scores as an input and adds pairwise comparisons to it.
- `add_coverage()` was reworked completely. It's new purpose is now to add coverage information to the raw forecast data (essentially fulfilling some of the functionality that was previously covered by `score_quantile()`)
- Support for the interval format was mostly dropped (see PR #525 by @nikosbosse and reviewed by @seabbs)
- The function `bias_range()` was removed (users should now use `bias_quantile()` instead)
- The function `interval_score()` was made an internal function rather than being exported to users. We recommend using `wis()` instead.
- The function `find_duplicates()` was renamed to `get_duplicate_forecasts()`
- Changes to `avail_forecasts()` and `plot_avail_forecasts()`:
- The function `avail_forecasts()` was renamed to `available_forecasts()` for consistency with `available_metrics()`. The old function, `avail_forecasts()` is still available as an alias, but will be removed in the future.
- For clarity, the output column in `avail_forecasts()` was renamed from "Number forecasts" to "count".
- `available_forecasts()` now also displays combinations where there are 0 forecasts, instead of silently dropping corresponding rows.
- `plot_avail_forecasts()` has been deprecated in favour of an S3 method for `plot()`. An alias is still available, but will be removed in the future.
- The function `avail_forecasts()` was renamed to `get_forecast_counts()`. This represents a change in the naming convention where we aim to name functions that provide the user with additional useful information about the data with a prefix "get_". Sees Issue #403 and #521 and PR #511 by @nikosbosse and reviewed by @seabbs for details.
- For clarity, the output column in `get_forecast_counts()` was renamed from "Number forecasts" to "count".
- `get_forecast_counts()` now also displays combinations where there are 0 forecasts, instead of silently dropping corresponding rows.
- `plot_avail_forecasts()` was renamed `plot_forecast_counts()` in line with the change in the function name. The `x` argument no longer has a default value, as the value will depend on the data provided by the user.
- The deprecated `..density..` was replaced with `after_stat(density)` in ggplot calls.
- Files ending in ".Rda" were renamed to ".rds" where appropriate when used together with `saveRDS()` or `readRDS()`.
- `score()` now calls `na.omit()` on the data, instead of only removing rows with missing values in the columns `observed` and `predicted`. This is because `NA` values in other columns can also mess up e.g. grouping of forecasts according to the unit of a single forecast.
Expand Down Expand Up @@ -190,7 +194,7 @@ to a function `summarise_scores()`
- New function `check_forecasts()` to analyse input data before scoring
- New function `correlation()` to compute correlations between different metrics
- New function `add_coverage()` to add coverage for specific central prediction intervals.
- New function `available_forecasts()` allows to visualise the number of available forecasts.
- New function `avail_forecasts()` allows to visualise the number of available forecasts.
- New function `find_duplicates()` to find duplicate forecasts which cause an error.
- All plotting functions were renamed to begin with `plot_`. Arguments were
simplified.
Expand Down
10 changes: 3 additions & 7 deletions R/add_coverage.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,13 @@
#' @export
add_coverage <- function(data) {
stored_attributes <- get_scoringutils_attributes(data)
data <- validate(data)
data <- as_forecast(data)
forecast_unit <- get_forecast_unit(data)
data_cols <- colnames(data) # store so we can reset column order later

# what happens if quantiles are not symmetric around the median?
# should things error? Also write tests for that.
interval_data <- quantile_to_interval(data, format = "wide")
interval_data[, interval_coverage := ifelse(
observed <= upper & observed >= lower,
TRUE,
FALSE)
interval_data[,
interval_coverage := (observed <= upper) & (observed >= lower)
][, c("lower", "upper", "observed") := NULL]

data[, range := get_range_from_quantile(quantile)]
Expand Down
26 changes: 4 additions & 22 deletions R/available_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
#' @examples
#' data.table::setDTthreads(1) # only needed to avoid issues on CRAN
#'
#' available_forecasts(example_quantile,
#' get_forecast_counts(example_quantile,
#' by = c("model", "target_type")
#' )
available_forecasts <- function(data,
get_forecast_counts <- function(data,
by = NULL,
collapse = c("quantile", "sample_id")) {

data <- validate(data)
data <- as_forecast(data)
forecast_unit <- attr(data, "forecast_unit")
data <- na.omit(data)

Expand All @@ -58,7 +58,7 @@ available_forecasts <- function(data,
data <- data[data[, .I[1], by = collapse_by]$V1]

# count number of rows = number of forecasts
out <- data[, .(`count` = .N), by = by]
out <- data[, .(count = .N), by = by]

# make sure that all combinations in "by" are included in the output (with
# count = 0). To achieve that, take the unique values in data and expand grid
Expand All @@ -70,23 +70,5 @@ available_forecasts <- function(data,
out <- merge(out, out_empty, by = by, all.y = TRUE)
out[, count := nafill(count, fill = 0)]

class(out) <- c("scoringutils_available_forecasts", class(out))

return(out[])
}

#' @title Count Number of Available Forecasts `r lifecycle::badge("deprecated")`
#' @details `r lifecycle::badge("deprecated")` Deprecated in 1.2.2. Use
#' [available_forecasts()] instead.
#' @inherit available_forecasts
#' @keywords check-forecasts
#' @export
avail_forecasts <- function(data,
by = NULL,
collapse = c("quantile", "sample")) {
lifecycle::deprecate_warn(
"1.2.2", "avail_forecasts()",
"available_forecasts()"
)
available_forecasts(data, by, collapse)
}
40 changes: 17 additions & 23 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @inheritDotParams checkmate::check_numeric
#' @importFrom checkmate check_atomic_vector check_numeric
#' @inherit document_check_functions return
#' @keywords internal
#' @keywords internal_input_check
check_numeric_vector <- function(x, ...) {
# check functions must return TRUE on success
# and a custom error message otherwise
Expand Down Expand Up @@ -36,7 +36,7 @@ check_numeric_vector <- function(x, ...) {
#'
#' @return None. Function errors if quantiles are invalid.
#'
#' @keywords internal
#' @keywords internal_input_check
check_quantiles <- function(quantiles, name = "quantiles", range = c(0, 1)) {
if (any(quantiles < range[1]) || any(quantiles > range[2])) {
stop(name, " must be between ", range[1], " and ", range[2])
Expand All @@ -57,7 +57,7 @@ check_quantiles <- function(quantiles, name = "quantiles", range = c(0, 1)) {
#' @param expr an expression to be evaluated
#' @importFrom checkmate assert assert_numeric check_matrix
#' @inherit document_check_functions return
#' @keywords internal
#' @keywords internal_input_check
check_try <- function(expr) {
result <- try(expr, silent = TRUE)
if (is.null(result)) {
Expand All @@ -79,7 +79,7 @@ check_try <- function(expr) {
#' @return The function returns `NULL`, but throws an error if the variable is
#' missing.
#'
#' @keywords internal
#' @keywords internal_input_check
assert_not_null <- function(...) {
vars <- list(...)
varnames <- names(vars)
Expand Down Expand Up @@ -112,7 +112,7 @@ assert_not_null <- function(...) {
#' within another checking function.
#' @inherit document_assert_functions return
#'
#' @keywords internal
#' @keywords internal_input_check
assert_equal_length <- function(...,
one_allowed = TRUE,
call_levels_up = 2) {
Expand Down Expand Up @@ -161,7 +161,7 @@ assert_equal_length <- function(...,
#' @param attribute The name of the attribute to check
#' @param expected The expected value of the attribute
#' @inherit document_check_functions return
#' @keywords internal
#' @keywords internal_input_check
check_attribute_conflict <- function(object, attribute, expected) {
existing <- attr(object, attribute)
if (is.vector(existing) && is.vector(expected)) {
Expand All @@ -175,7 +175,7 @@ check_attribute_conflict <- function(object, attribute, expected) {
"from what's expected based on the data.\n",
"Existing: ", toString(existing), "\n",
"Expected: ", toString(expected), "\n",
"Running `validate()` again might solve the problem"
"Running `as_forecast()` again might solve the problem"
)
return(msg)
}
Expand All @@ -188,8 +188,9 @@ check_attribute_conflict <- function(object, attribute, expected) {
#' @description
#' Check whether the data.table has a column called `model`.
#' If not, a column called `model` is added with the value `Unspecified model`.
#' @inheritParams score
#' @return The data.table with a column called `model`
#' @keywords internal
#' @keywords internal_input_check
assure_model_column <- function(data) {
if (!("model" %in% colnames(data))) {
message(
Expand All @@ -208,7 +209,7 @@ assure_model_column <- function(data) {
#' returns TRUE and a string with an error message otherwise.
#' @param forecast_unit Character vector denoting the unit of a single forecast.
#' @inherit document_check_functions params return
#' @keywords internal
#' @keywords internal_input_check
check_number_per_forecast <- function(data, forecast_unit) {
# check whether there are the same number of quantiles, samples --------------
data[, scoringutils_InternalNumCheck := length(predicted), by = forecast_unit]
Expand All @@ -235,7 +236,7 @@ check_number_per_forecast <- function(data, forecast_unit) {
#' an error message, otherwise it returns TRUE.
#' @inherit document_check_functions params return
#'
#' @keywords internal
#' @keywords internal_input_check
check_no_NA_present <- function(data, columns) {
for (x in columns){
if (anyNA(data[[x]])) {
Expand All @@ -253,20 +254,13 @@ check_no_NA_present <- function(data, columns) {
}




# print stuff
diagnose <- function(data) {

}

#' Check that there are no duplicate forecasts
#'
#' @description
#' Runs [get_duplicate_forecasts()] and returns a message if an issue is encountered
#' @inheritParams get_duplicate_forecasts
#' @inherit document_check_functions return
#' @keywords internal
#' @keywords internal_input_check
check_duplicates <- function(data, forecast_unit = NULL) {
check_duplicates <- get_duplicate_forecasts(data, forecast_unit = forecast_unit)

Expand All @@ -290,7 +284,7 @@ check_duplicates <- function(data, forecast_unit = NULL) {
#' and returns a message with the first issue encountered.
#' @inherit document_check_functions params return
#' @importFrom checkmate assert_character
#' @keywords check-inputs
#' @keywords internal_input_check
check_columns_present <- function(data, columns) {
if (is.null(columns)) {
return(TRUE)
Expand Down Expand Up @@ -322,7 +316,7 @@ check_columns_present <- function(data, columns) {
#' are present, the function returns TRUE.
#' @inheritParams document_check_functions
#' @return Returns TRUE if all columns are present and FALSE otherwise
#' @keywords internal
#' @keywords internal_input_check
test_columns_present <- function(data, columns) {
check <- check_columns_present(data, columns)
return(is.logical(check))
Expand All @@ -334,7 +328,7 @@ test_columns_present <- function(data, columns) {
#' more columns are present, the function returns FALSE.
#' @inheritParams document_check_functions
#' @return Returns TRUE if none of the columns are present and FALSE otherwise
#' @keywords internal
#' @keywords internal_input_check
test_columns_not_present <- function(data, columns) {
if (any(columns %in% colnames(data))) {
return(FALSE)
Expand All @@ -349,7 +343,7 @@ test_columns_not_present <- function(data, columns) {
#' "quantile" and "sample_id" is present.
#' @inherit document_check_functions params return
#' @importFrom checkmate check_data_frame
#' @keywords check-inputs
#' @keywords internal_input_check
check_data_columns <- function(data) {
is_data <- check_data_frame(data, min.rows = 1)
if (!is.logical(is_data)) {
Expand All @@ -374,7 +368,7 @@ check_data_columns <- function(data) {
#' @param object An object to be checked
#' @param attribute name of an attribute to be checked
#' @inherit document_check_functions return
#' @keywords check-inputs
#' @keywords internal_input_check
check_has_attribute <- function(object, attribute) {
if (is.null(attr(object, attribute))) {
return(
Expand Down
Loading

0 comments on commit 5efd52d

Please sign in to comment.