Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: offset column role in Task #1225

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ Authors@R:
comment = c(ORCID = "0000-0002-8115-0400")),
person("Sebastian", "Fischer", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0002-9609-3197")),
person("Lona", "Koers", , "[email protected]", role = "ctb")
person("Lona", "Koers", , "[email protected]", role = "ctb"),
person("John", "Zobolas", , "[email protected]", role = "ctb",
comment = c(ORCID = "0000-0002-3609-8674"))
)
Description: Efficient, object-oriented programming on the
building blocks of machine learning. Provides 'R6' objects for tasks,
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# mlr3 (development version)

* feat: add new `col_role` offset in `Task` and offset `Learner` property.
A warning is produced if a learner that doesn't support offsets is trained with a task that has an offset column.
* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
* BREAKING CHANGE: `Learner$predict_types` is read-only now.
Expand Down
53 changes: 48 additions & 5 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ Task = R6Class("Task",
}

# columns with these roles must be present in data
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order")
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order", "offset")
mandatory_cols = unlist(private$.col_roles[mandatory_roles], use.names = FALSE)
missing_cols = setdiff(mandatory_cols, data$colnames)
if (length(missing_cols)) {
Expand Down Expand Up @@ -896,6 +896,7 @@ Task = R6Class("Task",
#' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`).
#' * `"groups"`: The task comes with grouping/blocking information (role `"group"`).
#' * `"weights"`: The task comes with observation weights (role `"weight"`).
#' * `"offset"`: The task includes one or more offset columns specifying fixed adjustments for model training and possibly for prediction (role `"offset"`).
#' * `"ordered"`: The task has columns which define the row order (role `"order"`).
#'
#' Note that above listed properties are calculated from the `$col_roles` and may not be set explicitly.
Expand All @@ -907,6 +908,7 @@ Task = R6Class("Task",
if (length(col_roles$group)) "groups" else NULL,
if (length(col_roles$stratum)) "strata" else NULL,
if (length(col_roles$weight)) "weights" else NULL,
if (length(col_roles$offset)) "offset" else NULL,
if (length(col_roles$order)) "ordered" else NULL
)
} else {
Expand Down Expand Up @@ -951,6 +953,10 @@ Task = R6Class("Task",
#' Not more than a single column can be associated with this role.
#' * `"stratum"`: Stratification variables. Multiple discrete columns may have this role.
#' * `"weight"`: Observation weights. Not more than one numeric column may have this role.
#' * `"offset"`: Offset values specifying fixed adjustments for model training.
#' These values can be used to provide baseline predictions from an existing model for updating another model.
#' Some learners require an offset for each target class in a multiclass setting.
#' In this case, the offset columns must be named `"offset_{target_class_name}"`.
#'
#' `col_roles` is a named list whose elements are named by column role and each element is a `character()` vector of column names.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
Expand Down Expand Up @@ -1084,6 +1090,23 @@ Task = R6Class("Task",
setnames(data, c("row_id", "weight"))[]
},

#' @field offset ([data.table::data.table()])\cr
#' Provides the offset column(s) if the task has a column designated with the role `"offset"`.
#'
#' For regression or binary classification tasks, this returns a single-column offset.
#' For multiclass tasks, it may return multiple offset columns, one for each target class.
#'
#' If there are no columns with the `"offset"` role, `NULL` is returned.
offset = function(rhs) {
assert_has_backend(self)
assert_ro_binding(rhs)
offset_cols = private$.col_roles$offset
if (length(offset_cols) == 0L) {
return(NULL)
}

self$backend$data(private$.row_roles$use, offset_cols)
},

#' @field labels (named `character()`)\cr
#' Retrieve `labels` (prettier formated names) from columns.
Expand Down Expand Up @@ -1250,6 +1273,17 @@ task_check_col_roles.Task = function(task, new_roles, ...) {
}
}

# check offset
if (length(new_roles[["offset"]]) && any(fget(task$col_info, new_roles[["offset"]], "type", key = "id") %nin% c("numeric", "integer"))) {
stopf("Offset column(s) %s must be a numeric or integer column", paste0("'", new_roles[["offset"]], "'", collapse = ","))
}

if (any(task$missings(cols = new_roles[["offset"]]) > 0)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use a shorter circuit (something like mlr3misc::some)

missings = task$missings(cols = new_roles[["offset"]])
missings = names(missings[missings > 0])
stopf("Offset column(s) %s contain missing values", paste0("'", missings, "'", collapse = ","))
}

return(new_roles)
}

Expand All @@ -1266,16 +1300,25 @@ task_check_col_roles.TaskClassif = function(task, new_roles, ...) {
stopf("Target column(s) %s must be a factor or ordered factor", paste0("'", new_roles[["target"]], "'", collapse = ","))
}

if (length(new_roles[["offset"]]) > 1L && length(task$class_names) == 2L) {
stop("There may only be up to one column with role 'offset' for binary classification tasks")
}

if (length(new_roles[["offset"]]) > 1L) {
expected_names = paste0("offset_", task$class_names)
expect_subset(new_roles[["offset"]], expected_names, label = "col_roles")
}

NextMethod()
}

#' @rdname task_check_col_roles
#' @export
task_check_col_roles.TaskRegr = function(task, new_roles, ...) {

# check target
if (length(new_roles[["target"]]) > 1L) {
stopf("There may only be up to one column with role 'target'")
for (role in c("target", "offset")) {
if (length(new_roles[[role]]) > 1L) {
stopf("There may only be up to one column with role '%s'", role)
}
}

if (length(new_roles[["target"]]) && any(fget(task$col_info, new_roles[["target"]], "type", key = "id") %nin% c("numeric", "integer"))) {
Expand Down
5 changes: 5 additions & 0 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ assert_task_learner = function(task, learner, cols = NULL) {
}
}

if ("offset" %in% task$properties && "offset" %nin% learner$properties) {
warningf("Task '%s' has offset, but learner '%s' does not support this, so it will be ignored",
task$id, learner$id)
}

tmp = mlr_reflections$task_mandatory_properties[[task$task_type]]
if (length(tmp)) {
tmp = setdiff(intersect(task$properties, tmp), learner$properties)
Expand Down
8 changes: 4 additions & 4 deletions R/mlr_reflections.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ local({
"use"
)

tmp = c("feature", "target", "name", "order", "stratum", "group", "weight")
tmp = c("feature", "target", "name", "order", "stratum", "group", "weight", "offset")
mlr_reflections$task_col_roles = list(
regr = tmp,
classif = tmp,
unsupervised = c("feature", "name", "order")
)

tmp = c("strata", "groups", "weights")
tmp = c("strata", "groups", "weights", "offset")
mlr_reflections$task_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp,
Expand All @@ -114,11 +114,11 @@ local({

mlr_reflections$task_print_col_roles = list(
before = character(),
after = c("Order by" = "order", "Strata" = "stratum", "Groups" = "group", "Weights" = "weight")
after = c("Order by" = "order", "Strata" = "stratum", "Groups" = "group", "Weights" = "weight", "Offset" = "offset")
)

### Learner
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal")
tmp = c("featureless", "missings", "weights", "importance", "selected_features", "oob_error", "hotstart_forward", "hotstart_backward", "validation", "internal_tuning", "marshal", "offset")
mlr_reflections$learner_properties = list(
classif = c(tmp, "twoclass", "multiclass"),
regr = tmp
Expand Down
2 changes: 2 additions & 0 deletions man-roxygen/param_learner_properties.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#' The following properties are currently standardized and understood by learners in \CRANpkg{mlr3}:
#' * `"missings"`: The learner can handle missing values in the data.
#' * `"weights"`: The learner supports observation weights.
#' * `"offset"`: The learner can incorporate offset values to adjust predictions.
#' * `"importance"`: The learner supports extraction of importance scores, i.e. comes with an `$importance()` extractor function (see section on optional extractors in [Learner]).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@be-marc here hotstart_forward, hotstart_backward and featureless are not documented, but they should be perhaps?

#' * `"selected_features"`: The learner supports extraction of the set of selected features, i.e. comes with a `$selected_features()` extractor function (see section on optional extractors in [Learner]).
#' * `"oob_error"`: The learner supports extraction of estimated out of bag error, i.e. comes with a `oob_error()` extractor function (see section on optional extractors in [Learner]).
#' * `"validation"`: The learner can use a validation task during training.
#' * `"internal_tuning"`: The learner is able to internally optimize hyperparameters (those are also tagged with `"internal_tuning"`).
#' * `"marshal"`: To save learners with this property, you need to call `$marshal()` first.
#' If a learner is in a marshaled state, you call first need to call `$unmarshal()` to use its model, e.g. for prediction.
#'
1 change: 1 addition & 0 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/LearnerRegr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr3-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,18 @@ test_that("stratify works", {
})

test_that("groups/weights work", {
b = as_data_backend(data.table(x = runif(20), y = runif(20), w = runif(20), g = sample(letters[1:2], 20, replace = TRUE)))
b = as_data_backend(data.table(x = runif(20), y = runif(20), w = runif(20),
o = runif(20), g = sample(letters[1:2], 20, replace = TRUE)))
task = TaskRegr$new("test", b, target = "y")
task$set_row_roles(16:20, character())

expect_false("groups" %chin% task$properties)
expect_false("weights" %chin% task$properties)
expect_false("offset" %chin% task$properties)
expect_null(task$groups)
expect_null(task$weights)

# weight
task$col_roles$weight = "w"
expect_subset("weights", task$properties)
expect_data_table(task$weights, ncols = 2, nrows = 15)
Expand All @@ -265,6 +268,7 @@ test_that("groups/weights work", {
task$col_roles$weight = character()
expect_true("weights" %nin% task$properties)

# group
task$col_roles$group = "g"
expect_subset("groups", task$properties)
expect_data_table(task$groups, ncols = 2, nrows = 15)
Expand Down Expand Up @@ -726,3 +730,4 @@ test_that("warn when internal valid task has 0 obs", {
task = tsk("iris")
expect_warning({task$internal_valid_task = 151}, "has 0 observations")
})

50 changes: 50 additions & 0 deletions tests/testthat/test_TaskClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,53 @@ test_that("target is encoded as factor (#629)", {
dt$target = ordered(dt$target)
TaskClassif$new(id = "XX", backend = dt, target = "target")
})

test_that("offset column role works with binary tasks", {
task = tsk("pima")
expect_null(task$offset)

task$set_col_roles("age", "offset")
expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 1)
expect_subset("age", names(task$offset))

expect_error({
task$col_roles$offset = c("glucose", "diabetes")
}, "There may only be up to one column with role")

expect_error({
task$col_roles$offset = c("glucose")
}, "contain missing values")

expect_warning(lrn("classif.rpart")$train(task), "has offset")
})

test_that("offset column role works with multiclass tasks", {
task = tsk("penguins")
expect_null(task$offset)
task$set_col_roles("year", "offset")
expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 1)
expect_subset("year", names(task$offset))

expect_error({
task$col_roles$offset = "bill_length"
}, "contain missing values")

task = tsk("wine")

expect_error({
task$col_roles$offset = c("alcohol", "ash")
}, "Must be a subset of")

task = tsk("wine")
data = task$data()
set(data, j = "offset_1", value = runif(nrow(data)))
set(data, j = "offset_2", value = runif(nrow(data)))
task = as_task_classif(data, target = "type")
task$set_col_roles(c("offset_1", "offset_2"), "offset")

expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 2)
expect_subset(c("offset_1", "offset_2"), names(task$offset))
})
18 changes: 18 additions & 0 deletions tests/testthat/test_TaskRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,21 @@ test_that("$add_strata", {
task$add_strata(task$target_names, bins = 2)
expect_identical(task$strata$N, c(50L, 10L))
})

test_that("offset column role works", {
task = tsk("mtcars")
expect_null(task$offset)
task$set_col_roles("am", "offset")

expect_subset("offset", task$properties)
expect_data_table(task$offset, nrows = task$nrow, ncols = 1)
expect_subset("am", names(task$offset))

expect_error({
task$col_roles$offset = c("am", "gear")
}, "up to one")

task$col_roles$offset = character()
expect_true("offset" %nin% task$properties)
expect_null(task$offset)
})