Skip to content

Commit

Permalink
Merge pull request #460 from epiforecasts/update-binary-input
Browse files Browse the repository at this point in the history
Update input formats for binary and point forecasts
  • Loading branch information
nikosbosse authored Nov 20, 2023
2 parents 1a730f4 + 67768c5 commit b778f8a
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 83 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_data_frame)
importFrom(checkmate,check_function)
Expand Down
70 changes: 54 additions & 16 deletions R/check-inputs-scoring-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,25 +161,21 @@ check_input_interval <- function(observed, lower, upper, range) {
#' that `predicted` represents the probability that the observed value is equal
#' to the highest factor level.
#' @param predicted Input to be checked. `predicted` should be a vector of
#' length n, holding probabilities. Values represent the probability that
#' length n, holding probabilities. Alternatively, `predicted` can be a matrix
#' of size n x 1. Values represent the probability that
#' the corresponding value in `observed` will be equal to the highest
#' available factor level.
#' @importFrom checkmate assert assert_factor
#' @inherit document_assert_functions return
#' @keywords check-inputs
assert_input_binary <- function(observed, predicted) {
if (length(observed) != length(predicted)) {
stop("`observed` and `predicted` need to be ",
"of same length when scoring binary forecasts")
}
assert_factor(observed, n.levels = 2)
levels <- levels(observed)
assert(
check_numeric_vector(predicted, min.len = 1, lower = 0, upper = 1)
)
assert_factor(observed, n.levels = 2, min.len = 1)
assert_numeric(predicted, lower = 0, upper = 1)
assert_dims_ok_point(observed, predicted)
return(invisible(NULL))
}


#' @title Check that inputs are correct for binary forecast
#' @inherit assert_input_binary params description
#' @inherit document_check_functions return
Expand All @@ -200,12 +196,9 @@ check_input_binary <- function(observed, predicted) {
#' @inherit document_assert_functions return
#' @keywords check-inputs
assert_input_point <- function(observed, predicted) {
assert(check_numeric_vector(observed, min.len = 1))
assert(check_numeric_vector(predicted, min.len = 1))
if (length(observed) != length(predicted)) {
stop("`observed` and `predicted` need to be ",
"of same length when scoring point forecasts")
}
assert(check_numeric(observed))
assert(check_numeric(predicted))
assert(check_dims_ok_point(observed, predicted))
return(invisible(NULL))
}

Expand All @@ -217,3 +210,48 @@ check_input_point <- function(observed, predicted) {
result <- check_try(assert_input_point(observed, predicted))
return(result)
}


#' @title Assert Inputs Have Matching Dimensions
#' @description Function assesses whether input dimensions match. In the
#' following, n is the number of observations / forecasts. Scalar values may
#' be repeated to match the length of the other input.
#' Allowed options are therefore
#' - `observed` is vector of length 1 or length n
#' - `predicted` is
#' - a vector of of length 1 or length n
#' - a matrix with n rows and 1 column
#' @inherit assert_input_binary
#' @inherit document_assert_functions return
#' @importFrom checkmate assert_vector check_matrix check_vector assert
#' @keywords check-inputs
assert_dims_ok_point <- function(observed, predicted) {
assert_vector(observed, min.len = 1)
n_obs <- length(observed)
assert(
check_vector(predicted, min.len = 1, strict = TRUE),
check_matrix(predicted, ncols = 1, nrows = n_obs)
)
dim_p <- dim(predicted)
if (!is.null(dim_p) && (length(dim_p) > 1) && (dim_p[2] > 1)) {
stop("`predicted` must be a vector or a matrix with one column. Found ",
dim(predicted)[2], " columns")
}
n_pred <- length(as.vector(predicted))
# check that both are either of length 1 or of equal length
if ((n_obs != 1) && (n_pred != 1) && (n_obs != n_pred)) {
stop("`observed` and `predicted` must either be of length 1 or ",
"of equal length. Found ", n_obs, " and ", n_pred)
}
return(invisible(NULL))
}


#' @title Check Inputs Have Matching Dimensions
#' @inherit assert_dims_ok_point params description
#' @inherit document_check_functions return
#' @keywords check-inputs
check_dims_ok_point <- function(observed, predicted) {
result <- check_try(assert_dims_ok_point(observed, predicted))
return(result)
}
40 changes: 40 additions & 0 deletions man/assert_dims_ok_point.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/assert_input_binary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 40 additions & 0 deletions man/check_dims_ok_point.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/check_input_binary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b778f8a

Please sign in to comment.