From 308eae4b75fcfc3bbdc8bded46019d4bccfe8e2f Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Thu, 2 Nov 2023 13:38:12 +0100 Subject: [PATCH] Lint files --- R/check-input-helpers.R | 23 ++--------------------- R/get_-functions.R | 8 +++++--- R/pairwise-comparisons.R | 3 --- R/score_quantile.R | 27 ++++++++++++++++----------- R/utils.R | 7 +++++-- R/utils_data_handling.R | 22 +++++++++------------- 6 files changed, 37 insertions(+), 53 deletions(-) diff --git a/R/check-input-helpers.R b/R/check-input-helpers.R index 37c8c8a22..cfaa24b2c 100644 --- a/R/check-input-helpers.R +++ b/R/check-input-helpers.R @@ -114,8 +114,8 @@ assert_not_null <- function(...) { #' #' @keywords internal assert_equal_length <- function(..., - one_allowed = TRUE, - call_levels_up = 2) { + one_allowed = TRUE, + call_levels_up = 2) { vars <- list(...) lengths <- lengths(vars) @@ -283,25 +283,6 @@ check_duplicates <- function(data, forecast_unit = NULL) { } -# Function to check input for methods -# there should not be a name clash between a metric and a column name -# --> maybe this should be checked by the actual method that computes scores -# check whether any column name is a scoringutils metric -# clashing_colnames <- intersect(colnames(data), available_metrics()) -# if (length(clashing_colnames) > 0) { -# clashing_colnames <- paste0('"', clashing_colnames, '"') -# warnings <- c( -# warnings, -# paste0( -# "At least one column in the data ", -# "(", toString(clashing_colnames), ") ", -# "corresponds to the name of a metric that will be computed by ", -# "scoringutils. Please check `available_metrics()`" -# ) -# ) -# } - - #' Check column names are present in a data.frame #' @description #' The functions loops over the column names and checks whether they are diff --git a/R/get_-functions.R b/R/get_-functions.R index f9b88a971..22aaa47a9 100644 --- a/R/get_-functions.R +++ b/R/get_-functions.R @@ -23,9 +23,11 @@ get_forecast_type <- function(data) { } else if (test_forecast_type_is_point(data)) { forecast_type <- "point" } else { - stop("Checking `data`: input doesn't satisfy the criteria for any forecast type.", - "Are you missing a column `quantile` or `sample_id`?", - "Please check the vignette for additional info.") + stop( + "Checking `data`: input doesn't satisfy criteria for any forecast type.", + "Are you missing a column `quantile` or `sample_id`?", + "Please check the vignette for additional info." + ) } conflict <- check_attribute_conflict(data, "forecast_type", forecast_type) if (!is.logical(conflict)) { diff --git a/R/pairwise-comparisons.R b/R/pairwise-comparisons.R index d842ede72..f23316254 100644 --- a/R/pairwise-comparisons.R +++ b/R/pairwise-comparisons.R @@ -65,10 +65,7 @@ pairwise_comparison <- function(scores, metric = "auto", baseline = NULL, ...) { - - # metric_names <- get_metrics(scores) metric <- match.arg(metric, c("auto", available_metrics())) - if (!is.data.table(scores)) { scores <- as.data.table(scores) } else { diff --git a/R/score_quantile.R b/R/score_quantile.R index c7b6d893d..f2da97f04 100644 --- a/R/score_quantile.R +++ b/R/score_quantile.R @@ -32,11 +32,13 @@ score_quantile <- function(data, data <- remove_na_observed_predicted(data) # make sure to have both quantile as well as range format -------------------- - range_data <- quantile_to_interval( data, + range_data <- quantile_to_interval( + data, keep_quantile_col = FALSE ) # adds the range column to the quantile data set - quantile_data <- range_long_to_quantile(range_data, + quantile_data <- range_long_to_quantile( + range_data, keep_range_col = TRUE ) @@ -96,14 +98,15 @@ score_quantile <- function(data, # compute absolute and squared error for point forecasts # these are marked by an NA in range, and a numeric value for point - if (any(c("se_point, se_mean, ae_point", "ae_median", "absolute_error") %in% metrics)) { - if ("point" %in% colnames(res)) { - res[ - is.na(range) & is.numeric(point), - `:=`(ae_point = abs_error(predicted = point, observed), - se_point = squared_error(predicted = point, observed)) - ] - } + compute_point <- any( + c("se_point, se_mean, ae_point", "ae_median", "absolute_error") %in% metrics + ) + if (compute_point && "point" %in% colnames(res)) { + res[ + is.na(range) & is.numeric(point), + `:=`(ae_point = abs_error(predicted = point, observed), + se_point = squared_error(predicted = point, observed)) + ] } # calculate scores on quantile format ---------------------------------------- @@ -156,7 +159,9 @@ score_quantile <- function(data, } # delete internal columns before returning result - res <- delete_columns(res, c("upper", "lower", "boundary", "point", "observed")) + res <- delete_columns( + res, c("upper", "lower", "boundary", "point", "observed") + ) return(res[]) } diff --git a/R/utils.R b/R/utils.R index cdb882093..c0c0a34e1 100644 --- a/R/utils.R +++ b/R/utils.R @@ -184,11 +184,14 @@ remove_scoringutils_class <- function(object) { return(object) } # check if "scoringutils_" is in name of any class - if (any(grepl("scoringutils_", class(object)))) { + if (any(grepl("scoringutils_", class(object), fixed = TRUE))) { stored_attributes <- get_scoringutils_attributes(object) # remove all classes that contain "scoringutils_" - class(object) <- class(object)[!grepl("scoringutils_", class(object))] + class(object) <- class(object)[!grepl( + "scoringutils_", class(object), + fixed = TRUE + )] # remove all scoringutils attributes object <- strip_attributes(object, names(stored_attributes)) diff --git a/R/utils_data_handling.R b/R/utils_data_handling.R index bd9152324..d0677ca34 100644 --- a/R/utils_data_handling.R +++ b/R/utils_data_handling.R @@ -200,12 +200,10 @@ quantile_to_interval <- function(...) { #' @importFrom data.table copy #' @export #' @rdname quantile_to_interval -quantile_to_interval.data.frame <- function( - dt, - format = "long", - keep_quantile_col = FALSE, - ... -) { +quantile_to_interval.data.frame <- function(dt, + format = "long", + keep_quantile_col = FALSE, + ...) { if (!is.data.table(dt)) { dt <- data.table::as.data.table(dt) } else { @@ -231,7 +229,7 @@ quantile_to_interval.data.frame <- function( if (format == "wide") { delete_columns(dt, "quantile") - dt <- dcast(dt, ... ~ boundary, value.var = c("predicted")) + dt <- dcast(dt, ... ~ boundary, value.var = "predicted") } return(dt[]) } @@ -249,12 +247,10 @@ quantile_to_interval.data.frame <- function( #' `forecast_id` and `range`. #' @export #' @rdname quantile_to_interval -quantile_to_interval.numeric <- function( - observed, - predicted, - quantile, - ... -) { +quantile_to_interval.numeric <- function(observed, + predicted, + quantile, + ...) { assert_input_quantile(observed, predicted, quantile) n <- length(observed)