Skip to content

Commit

Permalink
some partial progress on supporting ordinal forecasts; need to addres…
Browse files Browse the repository at this point in the history
…s numerical issues
  • Loading branch information
elray1 committed Dec 10, 2024
1 parent fbd8700 commit 2e6d532
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 39 deletions.
61 changes: 51 additions & 10 deletions R/transform_pmf_model_out.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ transform_pmf_model_out <- function(model_out_tbl, oracle_output, output_type_id

# validate or set output_type_id_order
if (!is.null(output_type_id_order)) {
cli::cli_abort(
"ordinal variables are not yet supported."
)
output_type_id_order <- validate_output_type_id_order(output_type_id_order, model_out_tbl)
is_ordinal <- TRUE
} else {
is_ordinal <- FALSE
Expand All @@ -53,13 +51,56 @@ transform_pmf_model_out <- function(model_out_tbl, oracle_output, output_type_id
dplyr::ungroup() |>
dplyr::select(-dplyr::all_of("oracle_value"))

forecast_pmf <- scoringutils::as_forecast_nominal(
data,
forecast_unit = c("model", task_id_cols),
observed = "observation",
predicted = "value",
predicted_label = "output_type_id"
)
if (is_ordinal) {
forecast_pmf <- scoringutils::as_forecast_ordinal(
data,
forecast_unit = c("model", task_id_cols),
observed = "observation",
predicted = "value",
predicted_label = "output_type_id"
)
} else {
forecast_pmf <- scoringutils::as_forecast_nominal(
data,
forecast_unit = c("model", task_id_cols),
observed = "observation",
predicted = "value",
predicted_label = "output_type_id"
)
}

return(forecast_pmf)
}


#' Validate `output_type_id_order` for ordinal variables:
#' - Must be a vector
#' - Must be (set-)equal to the set of all unique `output_type_id` values in `model_out_tbl`
validate_output_type_id_order <- function(output_type_id_order, model_out_tbl) {
if (!is.vector(output_type_id_order)) {
cli::cli_abort("`output_type_id_order` must be a vector.")
}

present_levels <- unique(model_out_tbl$output_type_id)
extra_order_levels <- setdiff(output_type_id_order, present_levels)
missing_order_levels <- setdiff(present_levels, output_type_id_order)
if (length(extra_order_levels) > 0 || length(missing_order_levels) > 0) {
cli::cli_abort(
c(
"`output_type_id_order` must align with the set of all unique `output_type_id` values in `model_out_tbl`.",
ifelse(
length(extra_order_levels) == 0, NULL,
"The following levels were present in `output_type_id_order` but not in `model_out_tbl`:
{.val {extra_order_levels}}."
),
ifelse(
length(missing_order_levels) == 0, NULL,
"The following levels were present in `model_out_tbl` but not in `output_type_id_order`:
{.val {missing_order_levels}};"
)
)
)
}

return(output_type_id_order)
}
9 changes: 4 additions & 5 deletions tests/testthat/helper-pmf_test_oracle_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ pmf_test_oracle_output <- function() {
18282, 18289, 18289, 18296, 18296), class = "Date"), output_type_id = c("cat",
"cat", "cat", "cat", "cat", "cat", "cat", "cat", "dog", "dog",
"dog", "dog", "dog", "dog", "dog", "dog", "bird", "bird", "bird",
"bird", "bird", "bird", "bird", "bird"), observation = c(1, NA,
NA, NA, 1, NA, NA, 1, NA, NA, 1, 1, NA, 1, NA, NA, NA, 1, NA,
NA, NA, NA, 1, NA), oracle_value = c(1, 0, 0, 0, 1, 0, 0, 1,
0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)), row.names = c(NA,
-24L), class = c("tbl_df", "tbl", "data.frame"))
"bird", "bird", "bird", "bird", "bird"), oracle_value = c(1,
0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
1, 0)), row.names = c(NA, -24L), class = c("tbl_df", "tbl", "data.frame"
))
}
# nolint end
16 changes: 0 additions & 16 deletions tests/testthat/helper-pmf_test_target_observations.R

This file was deleted.

70 changes: 70 additions & 0 deletions tests/testthat/test-score_model_out.R
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,76 @@ test_that("score_model_out succeeds with valid inputs: nominal pmf output_type,
})


test_that("score_model_out succeeds with valid inputs: ordinal pmf output_type, default metrics, custom by", {
# Forecast data from HubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html>
forecast_outputs <- hubex_forecast_outputs()
forecast_oracle_output <- hubex_forecast_oracle_output()

act_scores <- score_model_out(
model_out_tbl = forecast_outputs |>
dplyr::filter(.data[["output_type"]] == "pmf"), #|>
# dplyr::group_by(model_id, reference_date, target, horizon, location, target_end_date) |>

Check warning on line 375 in tests/testthat/test-score_model_out.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-score_model_out.R,line=375,col=6,[indentation_linter] Indentation should be 4 spaces but is 6 spaces.
# dplyr::mutate(value = value / sum(value)) |>
# dplyr::ungroup(),
oracle_output = forecast_oracle_output,
by = c("model_id", "location"),
output_type_id_order = c("low", "moderate", "high", "very high")
)

exp_log_scores <- forecast_outputs |>
dplyr::filter(.data[["output_type"]] == "pmf") |>
dplyr::left_join(
forecast_oracle_output |>
dplyr::filter(.data[["output_type"]] == "pmf") |>
dplyr::select(-dplyr::all_of(c("output_type"))),
by = c("location", "target_end_date", "target", "output_type_id")
) |>
dplyr::filter(.data[["oracle_value"]] == 1) |>
dplyr::mutate(
log_score = -1 * log(.data[["value"]])
) |>
dplyr::group_by(dplyr::across(dplyr::all_of(
c("model_id", "location")
))) |>
dplyr::summarize(
log_score = mean(.data[["log_score"]]),
.groups = "drop"
)

exp_rps_scores <- forecast_outputs |>
dplyr::filter(.data[["output_type"]] == "pmf") |>
dplyr::left_join(
forecast_oracle_output |>
dplyr::filter(.data[["output_type"]] == "pmf") |>
dplyr::select(-dplyr::all_of(c("output_type"))),
by = c("location", "target_end_date", "target", "output_type_id")
) |>
dplyr::group_by(dplyr::across(dplyr::all_of(
c("model_id", "location", "reference_date", "horizon", "target_end_date", "target")
))) |>
dplyr::mutate(
log_score = -1 * log(.data[["value"]])
) |>
dplyr::group_by(dplyr::across(dplyr::all_of(
c("model_id", "location")
))) |>
dplyr::summarize(
log_score = mean(.data[["log_score"]]),
.groups = "drop"
)

# same column names, number of rows, and score values
expect_equal(colnames(act_scores), colnames(exp_scores))
expect_equal(nrow(act_scores), nrow(exp_scores))
merged_scores <- dplyr::full_join(
act_scores, exp_scores,
by = c("model_id", "location")
)
expect_equal(nrow(act_scores), nrow(merged_scores))
expect_equal(merged_scores$ae_point.x, merged_scores$ae_point.y)
})


test_that("score_model_out errors when model_out_tbl has multiple output_types", {
# Forecast data from HubExamples: <https://hubverse-org.github.io/hubExamples/reference/forecast_data.html>
forecast_outputs <- hubex_forecast_outputs()
Expand Down
34 changes: 28 additions & 6 deletions tests/testthat/test-transform_pmf_model_out.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# By construction, model_out_tbl and oracle_output are valid inputs to transform_pmf_model_out,
# and exp_forecast is the expected return value from transform_pmf_model_out

test_that("transform_pmf_model_out succeeds with valid inputs", {
test_that("transform_pmf_model_out succeeds with valid inputs -- nominal", {
model_out_tbl <- pmf_test_model_out_tbl()
oracle_output <- pmf_test_oracle_output()
exp_forecast <- pmf_test_exp_forecast()
Expand All @@ -20,21 +20,43 @@ test_that("transform_pmf_model_out succeeds with valid inputs", {

expect_equal(act_forecast, exp_forecast)
})
test_that("output_type_id_order is unsupported (fix when supported)", {


test_that("transform_pmf_model_out succeeds with valid inputs -- ordinal", {
model_out_tbl <- pmf_test_model_out_tbl()
oracle_output <- pmf_test_oracle_output()
exp_forecast <- pmf_test_exp_forecast() |> dplyr::mutate(
predicted_label = factor(predicted_label, levels = c("cat", "dog", "bird"), ordered = TRUE),
observed = factor(observed, levels = c("cat", "dog", "bird"), ordered = TRUE)
)
class(exp_forecast) <- c("forecast_ordinal", "forecast", "data.table", "data.frame")


act_forecast <- transform_pmf_model_out(
model_out_tbl = model_out_tbl,
oracle_output = oracle_output,
output_type_id_order = c("cat", "dog", "bird")
)

expect_equal(act_forecast, exp_forecast)
})


test_that("transform_pmf_model_out throws an error with invalid output_type_id_order", {
model_out_tbl <- pmf_test_model_out_tbl()
oracle_output <- pmf_test_oracle_output()
exp_forecast <- pmf_test_exp_forecast()

expect_error(
act_forecast <- transform_pmf_model_out(
model_out_tbl = model_out_tbl,
oracle_output = oracle_output,
output_type_id_order = "excellence"
),
regexp = "not yet supported"
output_type_id_order = c("cat", "bird", "platypus"),
"`output_type_id_order` must align with the set of all unique `output_type_id` values in `model_out_tbl`."
)
)
})


test_that("transform_pmf_model_out doesn't depend on specific column names for task id variables", {
model_out_tbl <- pmf_test_model_out_tbl() |>
dplyr::rename(loc = location, date = target_date)
Expand Down
5 changes: 3 additions & 2 deletions tests/testthat/testdata/make_pmf_test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ observed_categories <- data.frame(
oracle_output <- model_out_tbl |>
dplyr::distinct(location, target_date, output_type_id) |>
dplyr::left_join(observed_categories, by = c("location", "target_date", "output_type_id")) |>
dplyr::mutate(oracle_value = ifelse(is.na(observation), 0, 1))
dplyr::mutate(oracle_value = ifelse(is.na(observation), 0, 1)) |>
dplyr::select(-observation)

# check that observations sum to 1
oracle_output |>
group_by(location, target_date) |>
summarize(tot = sum(observation)) |>
summarize(tot = sum(oracle_value)) |>
pull(tot)

# create expected forecast output
Expand Down

0 comments on commit 2e6d532

Please sign in to comment.