
Commit

Fix code after changes in scoringutils
nikosbosse committed Sep 16, 2024
1 parent a6cb8a8 commit 1dcec1d
Showing 4 changed files with 43 additions and 22 deletions.
46 changes: 29 additions & 17 deletions R/score_model_out.R
@@ -15,16 +15,16 @@
#'
#' @details If `metrics` is `NULL` (the default), this function chooses
#' appropriate metrics based on the `output_type` contained in the `model_out_tbl`:
-#' \itemize{
-#' \item For `output_type == "quantile"`, we use the default metrics provided by
-#' `scoringutils::metrics_quantile()`: `r names(scoringutils::metrics_quantile())`
-#' \item For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
+#'
+#' - For `output_type == "quantile"`, we use the default metrics provided by
+#' `scoringutils`:
+#' `r names(scoringutils::get_metrics(scoringutils::example_quantile))`
+#' - For `output_type == "pmf"` and `output_type_id_order` is `NULL` (indicating
#' that the predicted variable is a nominal variable), we use the default metric
-#' provided by `scoringutils::metrics_nominal()`,
-#' `r names(scoringutils::metrics_nominal())`
-#' \item For `output_type == "median"`, we use "ae_point"
-#' \item For `output_type == "mean"`, we use "se_point"
-#' }
+#' provided by `scoringutils`:
+#' `r names(scoringutils::get_metrics(scoringutils::example_nominal))`
+#' - For `output_type == "median"`, we use "ae_point"
+#' - For `output_type == "mean"`, we use "se_point"
#'
#' Alternatively, a character vector of scoring metrics can be provided. In this
#' case, the following options are supported:
@@ -46,6 +46,9 @@
#' - `output_type == "pmf"`:
#'   - "log_score": log score
#'
+#' See [scoringutils::get_metrics()] for more details on the default metrics
+#' used by `scoringutils`.
+#'
#' @examplesIf requireNamespace("hubExamples", quietly = TRUE)
#' # compute WIS and interval coverage rates at 80% and 90% levels based on
#' # quantile forecasts, summarized by the mean score for each model
@@ -143,10 +146,10 @@ get_metrics <- function(metrics, output_type, output_type_id_order) {
#' @noRd
get_metrics_default <- function(output_type, output_type_id_order) {
  metrics <- switch(output_type,
-    quantile = scoringutils::metrics_quantile(),
-    pmf = scoringutils::metrics_nominal(),
-    mean = scoringutils::metrics_point(select = "se_point"),
-    median = scoringutils::metrics_point(select = "ae_point"),
+    quantile = scoringutils::get_metrics(scoringutils::example_quantile),
+    pmf = scoringutils::get_metrics(scoringutils::example_nominal),
+    mean = scoringutils::get_metrics(scoringutils::example_point, select = "se_point"),
+    median = scoringutils::get_metrics(scoringutils::example_point, select = "ae_point"),
    NULL # default
  )
  if (is.null(metrics)) {
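
For reference, the hunk above swaps the retired per-type helpers (`scoringutils::metrics_quantile()` and friends) for the newer `get_metrics()` generic, which is called on an example forecast object and returns a named list of scoring functions. A minimal sketch of the new API, assuming a scoringutils version that exports `get_metrics()` (the exact metric names returned depend on the version):

```r
library(scoringutils)

# get_metrics() dispatches on the class of a forecast object and returns
# a named list of default metric functions for that forecast type.
quantile_metrics <- get_metrics(example_quantile)
names(quantile_metrics) # e.g. "wis" and interval coverage metrics; the
                        # exact set depends on the scoringutils version

# The `select` argument keeps only the named metrics, which is how the
# switch() above narrows the point-forecast defaults to a single metric.
names(get_metrics(example_point, select = "se_point")) # "se_point"
```
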
@@ -199,7 +202,10 @@ get_metrics_character <- function(metrics, output_type) {
    )
    names(interval_metric_fns) <- interval_metrics

-    other_metric_fns <- scoringutils::metrics_quantile(select = other_metrics)
+    other_metric_fns <- scoringutils::get_metrics(
+      scoringutils::example_quantile,
+      select = other_metrics
+    )

    metric_fns <- c(other_metric_fns, interval_metric_fns)[metrics]
    metrics <- metric_fns
@@ -208,13 +214,19 @@ get_metrics_character <- function(metrics, output_type) {
    invalid_metrics <- metrics[!metrics %in% valid_metrics]
    error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

-    metrics <- scoringutils::metrics_nominal(select = metrics)
+    metrics <- scoringutils::get_metrics(
+      scoringutils::example_nominal,
+      select = metrics
+    )
  } else if (output_type %in% c("median", "mean")) {
    valid_metrics <- c("ae_point", "se_point")
    invalid_metrics <- metrics[!metrics %in% valid_metrics]
    error_if_invalid_metrics(valid_metrics, invalid_metrics, output_type)

-    metrics <- scoringutils::metrics_point(select = metrics)
+    metrics <- scoringutils::get_metrics(
+      scoringutils::example_point,
+      select = metrics
+    )
  } else {
    # we have already validated `output_type`, so this case should not be
    # triggered; this case is just double checking in case we add something new
@@ -231,7 +243,7 @@ error_if_invalid_metrics <- function(valid_metrics, invalid_metrics, output_type
  if (n > 0) {
    cli::cli_abort(
      c(
-        "`metrics` had {n} unsupported metric{?s} for
+        "`metrics` had {n} unsupported metric{?s} for
        {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
        supported metrics include {.val {valid_metrics}}.",
        comment
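
A side note on the error message in this last hunk (a whitespace-only change): it relies on cli's glue-style interpolation, where `{n}` substitutes a value, `{?s}` pluralizes based on the most recent quantity, and `{.arg}`/`{.val}`/`{.strong}` apply inline markup. A standalone sketch with hypothetical values, not taken from the package:

```r
library(cli)

n <- 2
invalid_metrics <- c("crps", "dss")        # hypothetical unsupported names
valid_metrics <- c("ae_point", "se_point") # supported for point forecasts
output_type <- "mean"

# "{?s}" renders as "s" here because the preceding quantity {n} is 2.
cli_abort(
  "`metrics` had {n} unsupported metric{?s} for
  {.arg output_type} {.val {output_type}}: {.strong {.val {invalid_metrics}}};
  supported metrics include {.val {valid_metrics}}."
)
```
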
3 changes: 2 additions & 1 deletion tests/testthat/test-score_model_out.R
@@ -441,7 +441,7 @@ test_that("score_model_out errors when invalid metrics are requested", {
    score_model_out(
      model_out_tbl = forecast_outputs |> dplyr::filter(.data[["output_type"]] == "mean"),
      target_observations = forecast_target_observations,
-      metrics = scoringutils::metrics_point(),
+      metrics = scoringutils::get_metrics(scoringutils::example_point),
      by = c("model_id", "location")
    ),
    regexp = "`metrics` must be either `NULL` or a character vector of supported metrics."
@@ -464,3 +464,4 @@ test_that("score_model_out errors when an unsupported output_type is provided",
    regexp = "only supports the following types"
  )
})
+
GitHub Actions / lint: file=tests/testthat/test-score_model_out.R,line=467,col=1,[trailing_blank_lines_linter] Trailing blank lines are superfluous.
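
The updated test above still passes a list of metric functions to trigger the error; per the documentation in `R/score_model_out.R`, `metrics` must be `NULL` or a character vector. A hedged sketch of a passing call, assuming this is the hubEvals package and that the example objects come from hubExamples as in the test file:

```r
library(hubEvals)    # assumption: the package this commit belongs to
library(hubExamples) # assumption: source of the example forecast data

# With metrics = NULL, score_model_out() picks defaults by output_type;
# for mean forecasts a character vector like "se_point" is also valid.
scores <- score_model_out(
  model_out_tbl = forecast_outputs |>
    dplyr::filter(.data[["output_type"]] == "mean"),
  target_observations = forecast_target_observations,
  metrics = "se_point",
  by = c("model_id", "location")
)
```
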
9 changes: 7 additions & 2 deletions tests/testthat/test-transform_point_model_out.R
@@ -140,6 +140,11 @@ test_that("hubExamples data set is transformed correctly", {
    reference_date = as.Date(reference_date, "%Y-%m-%d"),
    target_end_date = as.Date(target_end_date, "%Y-%m-%d")
  )
-  class(exp_forecast) <- c("forecast_point", "forecast", "data.table", "data.frame")
-  expect_equal(act_forecast, exp_forecast)
+  expect_s3_class(
+    act_forecast,
+    c("forecast_point", "forecast", "data.table", "data.frame")
+  )
+  expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast))
})
+
GitHub Actions / lint: file=tests/testthat/test-transform_point_model_out.R,line=149,col=1,[trailing_blank_lines_linter] Trailing blank lines are superfluous.
GitHub Actions / lint: file=tests/testthat/test-transform_point_model_out.R,line=150,col=1,[trailing_blank_lines_linter] Trailing blank lines are superfluous.
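
The new assertions in this file (and in the quantile test below) follow a pattern worth isolating: check the S3 class with `expect_s3_class()` and compare contents after `as.data.frame()`, so the test no longer depends on the exact class attribute that scoringutils sets. A self-contained illustration with toy objects, not the package's real fixtures:

```r
library(testthat)
library(data.table)

act <- data.table(value = c(1, 2))
class(act) <- c("forecast_point", "forecast", "data.table", "data.frame")
exp <- data.table(value = c(1, 2))

test_that("class and content are checked independently", {
  # Inheritance holds even if the class vector gains or reorders entries
  # beyond the one asserted here.
  expect_s3_class(act, "forecast_point")
  # as.data.frame() strips the bespoke classes before comparison, so only
  # the underlying data must match.
  expect_equal(as.data.frame(act), as.data.frame(exp))
})
```
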
7 changes: 5 additions & 2 deletions tests/testthat/test-transform_quantile_model_out.R
@@ -89,6 +89,9 @@ test_that("hubExamples data set is transformed correctly", {
    reference_date = as.Date(reference_date, "%Y-%m-%d"),
    target_end_date = as.Date(target_end_date, "%Y-%m-%d")
  )
-  class(exp_forecast) <- c("forecast", "forecast_quantile", "data.table", "data.frame")
-  expect_equal(act_forecast, exp_forecast, ignore_attr = "class")
+  expect_s3_class(
+    act_forecast,
+    c("forecast_quantile", "forecast", "data.table", "data.frame")
+  )
+  expect_equal(as.data.frame(act_forecast), as.data.frame(exp_forecast))
})
