Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,9 @@ import(palmerpenguins)
import(paradox)
importFrom(R6,R6Class)
importFrom(R6,is.R6)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(future,nbrOfWorkers)
importFrom(future,plan)
importFrom(graphics,plot)
importFrom(mlr3misc,clbk)
importFrom(mlr3misc,clbks)
importFrom(mlr3misc,mlr_callbacks)
importFrom(parallelly,availableCores)
importFrom(stats,contr.treatment)
importFrom(stats,model.frame)
Expand Down
10 changes: 9 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# mlr3 (development version)

* feat: Add `mirai` support for parallelization and encapsulation.
## New Features:

* `Task` got method `$materialize_view()` which can save memory after subsetting a task.
* Better input validation for:
  * `Learner` fields.
* Various improvements to the documentation and logging output, including
examples for methods.
* Measure "oob_error" now works even without storing models during resampling.
* Added `mirai` support for parallelization and encapsulation.

# mlr3 1.1.0

Expand Down
110 changes: 73 additions & 37 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,46 +176,16 @@ Learner = R6Class("Learner",
#' This is an internal data structure which may change in the future.
state = NULL,

#' @template field_task_type
task_type = NULL,

#' @field feature_types (`character()`)\cr
#' Stores the feature types the learner can handle, e.g. `"logical"`, `"numeric"`, or `"factor"`.
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
feature_types = NULL,

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = NULL,


#' @template field_packages
packages = NULL,

#' @template field_predict_sets
predict_sets = "test",

#' @field parallel_predict (`logical(1)`)\cr
#' If set to `TRUE`, use \CRANpkg{future} to calculate predictions in parallel (default: `FALSE`).
#' The row ids of the `task` will be split into [future::nbrOfWorkers()] chunks,
#' and predictions are evaluated according to the active [future::plan()].
#' This currently only works for methods `Learner$predict()` and `Learner$predict_newdata()`,
#' and has no effect during [resample()] or [benchmark()] where you have other means
#' to parallelize.
#'
#' Note that the recorded time required for prediction
#' is not properly defined and depends on the parallelization backend.
parallel_predict = FALSE,

#' @field timeout (named `numeric(2)`)\cr
#' Timeout for the learner's train and predict steps, in seconds.
#' This works differently for different encapsulation methods, see
#' [mlr3misc::encapsulate()].
#' Default is `c(train = Inf, predict = Inf)`.
#' Also see the section on error handling in the mlr3book:
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
timeout = c(train = Inf, predict = Inf),

#' @template field_man
man = NULL,

Expand All @@ -228,12 +198,12 @@ Learner = R6Class("Learner",

self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.task_type = assert_choice(task_type, mlr_reflections$task_types$type)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
private$.properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

Expand Down Expand Up @@ -490,10 +460,10 @@ Learner = R6Class("Learner",
}

prevci = task$col_info
task$backend = newdata
task$col_info = col_info(task$backend)
task$col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")]
task$col_info$fix_factor_levels[is.na(task$col_info$fix_factor_levels)] = FALSE
task$.__enclos_env__$private$.backend = newdata
task$.__enclos_env__$private$.col_info = col_info(task$backend)
task$.__enclos_env__$private$.col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")]
task$.__enclos_env__$private$.col_info$fix_factor_levels[is.na(task$.__enclos_env__$private$.col_info$fix_factor_levels)] = FALSE
task$row_roles$use = task$backend$rownames
task_col_roles = task$col_roles
update_col_roles = FALSE
Expand Down Expand Up @@ -680,6 +650,67 @@ Learner = R6Class("Learner",
),

active = list(
#' @template field_task_type
task_type = function(rhs) {
  # Active binding for the learner's task type.
  # Read access (no `rhs`) returns the value stored in `private$.task_type`.
  # Write access is deprecated: it calls `warn_deprecated()` and then stores
  # `rhs` unchanged (no validation is applied here).
  if (!missing(rhs)) {
    warn_deprecated("task_type will soon be read-only.")
    # Write to the task type slot. The previous code assigned to
    # `private$.properties`, which clobbered the learner's properties and
    # left the task type untouched.
    private$.task_type = rhs
  }
  private$.task_type
},

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = function(rhs) {
  # Getter: without an argument, hand back the stored properties.
  if (missing(rhs)) {
    return(private$.properties)
  }
  # Setter (deprecated): calls warn_deprecated(), then stores `rhs` as given.
  warn_deprecated("properties will soon be read-only.")
  private$.properties = rhs
  private$.properties
},

#' @template field_predict_sets
predict_sets = function(rhs) {
  if (!missing(rhs)) {
    # Setter: `rhs` must be a subset of `mlr_reflections$predict_sets`;
    # the validated value is stored unchanged.
    private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets)
  } else {
    # Getter: return the currently configured predict sets.
    private$.predict_sets
  }
},

#' @field parallel_predict (`logical(1)`)\cr
#' If set to `TRUE`, use \CRANpkg{future} to calculate predictions in parallel (default: `FALSE`).
#' The row ids of the `task` will be split into [future::nbrOfWorkers()] chunks,
#' and predictions are evaluated according to the active [future::plan()].
#' This currently only works for methods `Learner$predict()` and `Learner$predict_newdata()`,
#' and has no effect during [resample()] or [benchmark()] where you have other means
#' to parallelize.
#'
#' Note that the recorded time required for prediction
#' is not properly defined and depends on the parallelization backend.
parallel_predict = function(rhs) {
  if (!missing(rhs)) {
    # Setter: the new value has to pass assert_flag() before it is stored.
    private$.parallel_predict = assert_flag(rhs)
  } else {
    # Getter: return the stored flag.
    private$.parallel_predict
  }
},

#' @field timeout (named `numeric(2)`)\cr
#' Timeout for the learner's train and predict steps, in seconds.
#' This works differently for different encapsulation methods, see
#' [mlr3misc::encapsulate()].
#' Default is `c(train = Inf, predict = Inf)`.
#' Also see the section on error handling in the mlr3book:
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
timeout = function(rhs) {
  if (!missing(rhs)) {
    # Setter: names are checked first (must be a permutation of
    # "train" and "predict"), then the values (two non-missing,
    # non-negative numerics).
    stages = c("train", "predict")
    assert_permutation(names(rhs), stages)
    private$.timeout = assert_numeric(rhs, lower = 0, any.missing = FALSE, len = 2L)
  } else {
    # Getter: return the stored timeouts.
    private$.timeout
  }
},

#' @field use_weights (`character(1)`)\cr
#' How weights should be handled.
#' Settings are `"use"`, `"ignore"`, and `"error"`.
Expand Down Expand Up @@ -840,6 +871,11 @@ Learner = R6Class("Learner",
),

private = list(
.predict_sets = "test",
.task_type = NULL,
.properties = NULL,
.parallel_predict = FALSE,
.timeout = c(train = Inf, predict = Inf),
.use_weights = NULL,
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
Expand Down
71 changes: 50 additions & 21 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,7 @@ Resampling = R6Class("Resampling",
#' `$train_set()` and `$test_set()`.
instance = NULL,

#' @field task_hash (`character(1)`)\cr
#' The hash of the [Task] which was passed to `r$instantiate()`.
task_hash = NA_character_,

#' @field task_row_hash (`character(1)`)\cr
#' The hash of the row ids of the [Task] which was passed to `r$instantiate()`.
task_row_hash = NA_character_,

#' @field task_nrow (`integer(1)`)\cr
#' The number of observations of the [Task] which was passed to `r$instantiate()`.
#'
task_nrow = NA_integer_,

#' @field duplicated_ids (`logical(1)`)\cr
#' If `TRUE`, duplicated rows can occur within a single training set or within a single test set.
#' E.g., this is `TRUE` for Bootstrap, and `FALSE` for cross-validation.
#' Only used internally.
duplicated_ids = NULL,

#' @template field_man
man = NULL,
Expand All @@ -139,7 +122,7 @@ Resampling = R6Class("Resampling",
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$param_set = assert_param_set(param_set)
self$duplicated_ids = assert_flag(duplicated_ids)
private$.duplicated_ids = assert_flag(duplicated_ids)
self$man = assert_string(man, na.ok = TRUE)
},

Expand Down Expand Up @@ -188,9 +171,9 @@ Resampling = R6Class("Resampling",
task = assert_task(as_task(task))
private$.hash = NULL
self$instance = private$.get_instance(task)
self$task_hash = task$hash
self$task_row_hash = task$row_hash
self$task_nrow = task$nrow
private$.task_hash = task$hash
private$.task_row_hash = task$row_hash
private$.task_nrow = task$nrow
invisible(self)
},

Expand Down Expand Up @@ -256,6 +239,48 @@ Resampling = R6Class("Resampling",
}

private$.hash
},

#' @field task_hash (`character(1)`)\cr
#' The hash of the [Task] which was passed to `r$instantiate()`.
task_hash = function(rhs) {
  # Getter: without an argument, return the stored task hash.
  if (missing(rhs)) {
    return(private$.task_hash)
  }
  # Setter (deprecated): calls warn_deprecated(), then stores `rhs` as given.
  warn_deprecated("task_hash will soon be read-only.")
  private$.task_hash = rhs
  private$.task_hash
},

#' @field task_row_hash (`character(1)`)\cr
#' The hash of the row ids of the [Task] which was passed to `r$instantiate()`.
task_row_hash = function(rhs) {
  # Getter: without an argument, return the stored row hash.
  if (missing(rhs)) {
    return(private$.task_row_hash)
  }
  # Setter (deprecated): calls warn_deprecated(), then stores `rhs` as given.
  warn_deprecated("task_row_hash will soon be read-only.")
  private$.task_row_hash = rhs
  private$.task_row_hash
},

#' @field task_nrow (`integer(1)`)\cr
#' The number of observations of the [Task] which was passed to `r$instantiate()`.
task_nrow = function(rhs) {
  # Getter: without an argument, return the stored number of rows.
  if (missing(rhs)) {
    return(private$.task_nrow)
  }
  # Setter (deprecated): calls warn_deprecated(), then stores `rhs` as given.
  warn_deprecated("task_nrow will soon be read-only.")
  private$.task_nrow = rhs
  private$.task_nrow
},

#' @field duplicated_ids (`logical(1)`)\cr
#' If `TRUE`, duplicated rows can occur within a single training set or within a single test set.
#' E.g., this is `TRUE` for Bootstrap, and `FALSE` for cross-validation.
#' Only used internally.
duplicated_ids = function(rhs) {
  # Getter: without an argument, return the stored flag.
  if (missing(rhs)) {
    return(private$.duplicated_ids)
  }
  # Setter (deprecated): calls warn_deprecated(), then stores `rhs` as given.
  warn_deprecated("duplicated_ids will soon be read-only.")
  private$.duplicated_ids = rhs
  private$.duplicated_ids
}
),

Expand All @@ -264,6 +289,10 @@ Resampling = R6Class("Resampling",
.id = NULL,
.hash = NULL,
.groups = NULL,
.task_hash = NA_character_,
.task_row_hash = NA_character_,
.task_nrow = NA_integer_,
.duplicated_ids = NULL,

.get_instance = function(task) {
strata = task$strata
Expand Down
6 changes: 3 additions & 3 deletions R/ResamplingCustom.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ ResamplingCustom = R6Class("ResamplingCustom", inherit = Resampling,
assert_subset(unlist(train_sets, use.names = FALSE), task$row_ids)
assert_subset(unlist(test_sets, use.names = FALSE), task$row_ids)
self$instance = list(train = train_sets, test = test_sets)
self$task_hash = task$hash
self$task_nrow = task$nrow
self$task_row_hash = task$row_hash
private$.task_hash = task$hash
private$.task_nrow = task$nrow
private$.task_row_hash = task$row_hash
invisible(self)
}
),
Expand Down
6 changes: 3 additions & 3 deletions R/ResamplingCustomCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ ResamplingCustomCV = R6Class("ResamplingCustomCV", inherit = Resampling,
}

self$instance = split(task$row_ids, f, drop = TRUE)
self$task_hash = task$hash
self$task_nrow = task$nrow
self$task_row_hash = task$row_hash
private$.task_hash = task$hash
private$.task_nrow = task$nrow
private$.task_row_hash = task$row_hash
invisible(self)
}
),
Expand Down
Loading