Skip to content

Commit

Permalink
Lint files
Browse files Browse the repository at this point in the history
  • Loading branch information
nikosbosse committed Nov 2, 2023
1 parent b05db0d commit 308eae4
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 53 deletions.
23 changes: 2 additions & 21 deletions R/check-input-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ assert_not_null <- function(...) {
#'
#' @keywords internal
assert_equal_length <- function(...,
one_allowed = TRUE,
call_levels_up = 2) {
one_allowed = TRUE,
call_levels_up = 2) {
vars <- list(...)
lengths <- lengths(vars)

Expand Down Expand Up @@ -283,25 +283,6 @@ check_duplicates <- function(data, forecast_unit = NULL) {
}


# Function to check input for methods
# there should not be a name clash between a metric and a column name
# --> maybe this should be checked by the actual method that computes scores
# check whether any column name is a scoringutils metric
# clashing_colnames <- intersect(colnames(data), available_metrics())
# if (length(clashing_colnames) > 0) {
# clashing_colnames <- paste0('"', clashing_colnames, '"')
# warnings <- c(
# warnings,
# paste0(
# "At least one column in the data ",
# "(", toString(clashing_colnames), ") ",
# "corresponds to the name of a metric that will be computed by ",
# "scoringutils. Please check `available_metrics()`"
# )
# )
# }


#' Check column names are present in a data.frame
#' @description
#' The functions loops over the column names and checks whether they are
Expand Down
8 changes: 5 additions & 3 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ get_forecast_type <- function(data) {
} else if (test_forecast_type_is_point(data)) {
forecast_type <- "point"
} else {
stop("Checking `data`: input doesn't satisfy the criteria for any forecast type.",
"Are you missing a column `quantile` or `sample_id`?",
"Please check the vignette for additional info.")
stop(
"Checking `data`: input doesn't satisfy criteria for any forecast type.",
"Are you missing a column `quantile` or `sample_id`?",
"Please check the vignette for additional info."
)
}
conflict <- check_attribute_conflict(data, "forecast_type", forecast_type)
if (!is.logical(conflict)) {
Expand Down
3 changes: 0 additions & 3 deletions R/pairwise-comparisons.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ pairwise_comparison <- function(scores,
metric = "auto",
baseline = NULL,
...) {

# metric_names <- get_metrics(scores)
metric <- match.arg(metric, c("auto", available_metrics()))

if (!is.data.table(scores)) {
scores <- as.data.table(scores)
} else {
Expand Down
27 changes: 16 additions & 11 deletions R/score_quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ score_quantile <- function(data,
data <- remove_na_observed_predicted(data)

# make sure to have both quantile as well as range format --------------------
range_data <- quantile_to_interval( data,
range_data <- quantile_to_interval(
data,
keep_quantile_col = FALSE
)
# adds the range column to the quantile data set
quantile_data <- range_long_to_quantile(range_data,
quantile_data <- range_long_to_quantile(
range_data,
keep_range_col = TRUE
)

Expand Down Expand Up @@ -96,14 +98,15 @@ score_quantile <- function(data,

# compute absolute and squared error for point forecasts
# these are marked by an NA in range, and a numeric value for point
if (any(c("se_point, se_mean, ae_point", "ae_median", "absolute_error") %in% metrics)) {
if ("point" %in% colnames(res)) {
res[
is.na(range) & is.numeric(point),
`:=`(ae_point = abs_error(predicted = point, observed),
se_point = squared_error(predicted = point, observed))
]
}
compute_point <- any(
c("se_point, se_mean, ae_point", "ae_median", "absolute_error") %in% metrics
)
if (compute_point && "point" %in% colnames(res)) {
res[
is.na(range) & is.numeric(point),
`:=`(ae_point = abs_error(predicted = point, observed),
se_point = squared_error(predicted = point, observed))
]
}

# calculate scores on quantile format ----------------------------------------
Expand Down Expand Up @@ -156,7 +159,9 @@ score_quantile <- function(data,
}

# delete internal columns before returning result
res <- delete_columns(res, c("upper", "lower", "boundary", "point", "observed"))
res <- delete_columns(
res, c("upper", "lower", "boundary", "point", "observed")
)

return(res[])
}
7 changes: 5 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,14 @@ remove_scoringutils_class <- function(object) {
return(object)
}
# check if "scoringutils_" is in name of any class
if (any(grepl("scoringutils_", class(object)))) {
if (any(grepl("scoringutils_", class(object), fixed = TRUE))) {
stored_attributes <- get_scoringutils_attributes(object)

# remove all classes that contain "scoringutils_"
class(object) <- class(object)[!grepl("scoringutils_", class(object))]
class(object) <- class(object)[!grepl(
"scoringutils_", class(object),
fixed = TRUE
)]

# remove all scoringutils attributes
object <- strip_attributes(object, names(stored_attributes))
Expand Down
22 changes: 9 additions & 13 deletions R/utils_data_handling.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,10 @@ quantile_to_interval <- function(...) {
#' @importFrom data.table copy
#' @export
#' @rdname quantile_to_interval
quantile_to_interval.data.frame <- function(
dt,
format = "long",
keep_quantile_col = FALSE,
...
) {
quantile_to_interval.data.frame <- function(dt,
format = "long",
keep_quantile_col = FALSE,
...) {
if (!is.data.table(dt)) {
dt <- data.table::as.data.table(dt)
} else {
Expand All @@ -231,7 +229,7 @@ quantile_to_interval.data.frame <- function(

if (format == "wide") {
delete_columns(dt, "quantile")
dt <- dcast(dt, ... ~ boundary, value.var = c("predicted"))
dt <- dcast(dt, ... ~ boundary, value.var = "predicted")
}
return(dt[])
}
Expand All @@ -249,12 +247,10 @@ quantile_to_interval.data.frame <- function(
#' `forecast_id` and `range`.
#' @export
#' @rdname quantile_to_interval
quantile_to_interval.numeric <- function(
observed,
predicted,
quantile,
...
) {
quantile_to_interval.numeric <- function(observed,
predicted,
quantile,
...) {
assert_input_quantile(observed, predicted, quantile)

n <- length(observed)
Expand Down

0 comments on commit 308eae4

Please sign in to comment.