Skip to content

Commit

Permalink
add tests for breslow PipeOp + some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Dec 22, 2023
1 parent 9128d65 commit f72b4d3
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions tests/testthat/test_pipeop_distrcompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ test_that("PipeOpDistrCompositor - basic properties", {
expect_pipeop(PipeOpDistrCompositor$new())
})

task = tsk("rats")$filter(sample(300, 30))
set.seed(42)
task = tsk("rats")$filter(sample(300, 110))
cox_pred = lrn("surv.coxph")$train(task)$predict(task)

test_that("PipeOpDistrCompositor - overwrite = FALSE", {
gr = mlr3pipelines::ppl("distrcompositor", lrn("surv.kaplan", id = "k2"), overwrite = FALSE)
expect_silent(gr$train(task))
expect_equal(
gr$predict(task)[[1]]$data$distr,
lrn("surv.kaplan", id = "k2")$train(task)$predict(task)$data$distr)
lrn("surv.kaplan", id = "k2")$train(task)$predict(task)$data$distr
)

# breslow
gr = mlr3pipelines::ppl("distrcompositor", lrn("surv.coxph"),
estimator = "breslow", overwrite = FALSE)
expect_silent(gr$train(task))
expect_equal(
gr$predict(task)[[1]]$data$distr,
lrn("surv.coxph")$train(task)$predict(task)$data$distr
)
expect_equal(gr$predict(task)[[1]]$data$distr, cox_pred$data$distr)
})

test_that("PipeOpDistrCompositor - overwrite = TRUE", {
Expand All @@ -37,7 +37,7 @@ test_that("PipeOpDistrCompositor - overwrite = TRUE", {
overwrite = TRUE, graph_learner = TRUE)
expect_silent(gr$train(task))
surv_mat1 = gr$predict(task)$data$distr
surv_mat2 = lrn("surv.coxph")$train(task)$predict(task)$data$distr
surv_mat2 = cox_pred$data$distr
expect_false(all(surv_mat1 == surv_mat2)) # distr predictions changed (a bit)
})

Expand All @@ -47,3 +47,35 @@ test_that("no params", {
pod = mlr3pipelines::po("distrcompose", param_vals = list())
expect_silent(pod$predict(list(base = base, pred = pred)))
})

test_that("breslow PipeOp works", {
# learner is needed
expect_error(po("breslowcompose"), "is missing")

# learner needs to be of survival type
expect_error(po("breslowcompose", learner = lrn("classif.featureless")),
"must have task type")

# learner with lp predictions
learner = lrn("surv.coxph")
b1 = po("breslowcompose", learner = learner, breslow.overwrite = TRUE)
b2 = po("breslowcompose", learner = learner)

expect_pipeop(b1)
expect_pipeop(b2)
expect_equal(b1$id, learner$id)
expect_equal(b2$id, learner$id)
expect_true(b1$param_set$values$breslow.overwrite)
expect_false(b2$param_set$values$breslow.overwrite)

expect_silent(b1$train(list(task)))
expect_silent(b2$train(list(task)))
p1 = b1$predict(list(task))[[1L]]
p2 = b2$predict(list(task))[[1L]]

expect_equal(p1$lp, p2$lp)
surv_mat1 = p1$data$distr
surv_mat2 = p2$data$distr
expect_false(all(surv_mat1 == surv_mat2)) # distr predictions changed (a bit)
expect_true(all(surv_mat2 == cox_pred$data$distr)) # distr was not overwritten
})

0 comments on commit f72b4d3

Please sign in to comment.