diff --git a/R/score.R b/R/score.R index 2516d609..dc7199c7 100644 --- a/R/score.R +++ b/R/score.R @@ -349,11 +349,14 @@ validate_scores <- function(scores) { return(invisible(NULL)) } -##' @method `[` scores -##' @export +#' @method `[` scores +#' @importFrom data.table setattr +#' @export `[.scores` <- function(x, ...) { ret <- NextMethod() - if (is.data.frame(ret)) { + if (is.data.table(ret)) { + setattr(ret, "metrics", attr(x, "metrics")) + } else if (is.data.frame(ret)) { attr(ret, "metrics") <- attr(x, "metrics") } return(ret) diff --git a/tests/testthat/test-score.R b/tests/testthat/test-score.R index 2b41c76a..fb25f004 100644 --- a/tests/testthat/test-score.R +++ b/tests/testthat/test-score.R @@ -61,6 +61,14 @@ test_that("function throws an error if data is not a forecast object", { # expect_warning(suppressMessages(score(forecast = data))) # }) +test_that("Manipulating scores objects with .[ works as expected", { + expect_no_condition(scores_point[1:10]) + + expect_no_condition(scores_point[, .(model, ae_point)]) + + ex <- score(example_quantile) + expect_no_condition(ex[, extra_col := "something"]) +}) # test binary case -------------------------------------------------------------