Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix code after changes in scoringutils #52

Merged
merged 4 commits into from
Sep 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions R/score_model_out.R
Original file line number Diff line number Diff line change
@@ -15,13 +15,15 @@
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
#'
#' \itemize{
#' \item For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils::metrics_quantile()`: `r names(scoringutils::metrics_quantile())`
#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' \item For `output_type == "quantile"`, we use the default metrics provided by
#' `scoringutils`:
#' `r names(scoringutils::get_metrics(scoringutils::example_quantile))`
#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metric
#' provided by `scoringutils::metrics_nominal()`,
#' `r names(scoringutils::metrics_nominal())`
#' provided by `scoringutils`:
#' `r names(scoringutils::get_metrics(scoringutils::example_nominal))`
#' \item For `output_type == "median"`, we use "ae_point"
#' \item For `output_type == "mean"`, we use "se_point"
#' }
@@ -46,6 +48,9 @@
#' - `output_type == "pmf"`:
#' - "log_score": log score
#'
#' See [scoringutils::get_metrics()] for more details on the default metrics
#' used by `scoringutils`.
#'
#' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
#' # compute WIS and interval coverage rates at 80% and 90% levels based on
#' # quantile forecasts, summarized by the mean score for each model
@@ -143,10 +148,10 @@ get_metrics <- function(metrics, output_type, output_type_id_order) {
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
metrics <- switch(output_type,
quantile = scoringutils::metrics_quantile(),
pmf = scoringutils::metrics_nominal(),
mean = scoringutils::metrics_point(select = "se_point"),
median = scoringutils::metrics_point(select = "ae_point"),
quantile = scoringutils::get_metrics(scoringutils::example_quantile),
pmf = scoringutils::get_metrics(scoringutils::example_nominal),
mean = scoringutils::get_metrics(scoringutils::example_point, select = "se_point"),
median = scoringutils::get_metrics(scoringutils::example_point, select = "ae_point"),
NULL # default
)
if (is.null(metrics)) {
@@ -199,7 +204,10 @@ get_metrics_character <- function(metrics, output_type) {
)
names(interval_metric_fns) <- interval_metrics

other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)
other_metric_fns <- scoringutils::get_metrics(
scoringutils::example_quantile,
select = other_metrics
)

metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
metrics <- metric_fns
@@ -208,13 +216,19 @@ get_metrics_character <- function(metrics, output_type) {
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_nominal(select = metrics)
metrics <- scoringutils::get_metrics(
scoringutils::example_nominal,
select = metrics
)
} else if (output_type %in% c("median", "mean")) {
valid_metrics <- c("ae_point", "se_point")
invalid_metrics <- metrics[!metrics %in% valid_metrics]
error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

metrics <- scoringutils::metrics_point(select = metrics)
metrics <- scoringutils::get_metrics(
scoringutils::example_point,
select = metrics
)
} else {
# we have already validated `output_type`, so this case should not be
# triggered; this case is just double checking in case we add something new
@@ -231,7 +245,7 @@ error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type
if (n > 0) {
cli::cli_abort(
c(
"`metrics` had {n} unsupported metric{?s} for
"`metrics` had {n} unsupported metric{?s} for
{.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
supported metrics include {.val {valid_metrics}}.",
comment
9 changes: 7 additions & 2 deletions man/score_model_out.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-score_model_out.R
Original file line number Diff line number Diff line change
@@ -441,7 +441,7 @@ test_that("score_model_out errors when invalid metrics are requested", {
score_model_out(
model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"),
target_observations = forecast_target_observations,
metrics = scoringutils::metrics_point(),
metrics = scoringutils::get_metrics(scoringutils::example_point),
by = c("model_id", "location")
),
regexp = "`metrics` must be either `NULL` or a character vector of supported metrics."
7 changes: 5 additions & 2 deletions tests/testthat/test-transform_point_model_out.R
Original file line number Diff line number Diff line change
@@ -140,6 +140,9 @@ test_that("hubExamples data set is transformed correctly", {
reference_date = as.Date(reference_date, "%Y-%m-%d"),
target_end_date = as.Date(target_end_date, "%Y-%m-%d")
)
class(exp_forecast) <- c("forecast_point", "forecast", "data.table", "data.frame")
expect_equal(act_forecast, exp_forecast)
expect_s3_class(
act_forecast,
c("forecast_point", "forecast", "data.table", "data.frame")
)
expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast))
})
7 changes: 5 additions & 2 deletions tests/testthat/test-transform_quantile_model_out.R
Original file line number Diff line number Diff line change
@@ -89,6 +89,9 @@ test_that("hubExamples data set is transformed correctly", {
reference_date = as.Date(reference_date, "%Y-%m-%d"),
target_end_date = as.Date(target_end_date, "%Y-%m-%d")
)
class(exp_forecast) <- c("forecast", "forecast_quantile", "data.table", "data.frame")
expect_equal(act_forecast, exp_forecast, ignore_attr = "class")
expect_s3_class(
act_forecast,
c("forecast_quantile", "forecast", "data.table", "data.frame")
)
expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast))
})