Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
850c506
remotes
mb706 Aug 9, 2025
326787b
first pass
mb706 Aug 12, 2025
753401a
tasks
mb706 Aug 12, 2025
9314474
progress
mb706 Aug 12, 2025
a574630
tests pass locally
mb706 Aug 13, 2025
ca6cead
error on deprecated
mb706 Aug 13, 2025
4699248
initialization without man / label
mb706 Aug 13, 2025
7919314
initialization without man / label II
mb706 Aug 13, 2025
186e845
Merge branch 'main' into common_baseclass
mb706 Aug 13, 2025
b134320
document
mb706 Aug 13, 2025
7076576
double mlr3misc dependency
mb706 Aug 13, 2025
21e768b
skip backwards compatibility check for now
mb706 Aug 13, 2025
f208959
Merge branch 'main' into common_baseclass
mb706 Aug 14, 2025
6e9ed76
temp fix
mb706 Aug 14, 2025
0ec6341
mlr3data dev version
mb706 Aug 14, 2025
524e28b
Merge branch 'main' into common_baseclass
mb706 Aug 14, 2025
ec2d9e7
diagnostics
mb706 Aug 14, 2025
dc692ed
Merge branch 'common_baseclass' of ssh://github.com/mlr-org/mlr3 into…
mb706 Aug 14, 2025
22796f4
wrong class name
mb706 Aug 14, 2025
d7aacfb
deprecation message rename
mb706 Aug 15, 2025
2a02f61
using autotest
mb706 Aug 16, 2025
6048f0f
Merge branch 'main' into common_baseclass
mb706 Aug 16, 2025
74e5f08
additional configuration
mb706 Aug 16, 2025
bb21daf
Merge branch 'common_baseclass' of ssh://github.com/mlr-org/mlr3 into…
mb706 Aug 16, 2025
7939dff
add 'when' to Learner additional_phash_input
mb706 Aug 16, 2025
7312761
rename autotest
mb706 Aug 16, 2025
b924423
additional_configuration
mb706 Aug 16, 2025
f02b64b
measure autotests
mb706 Aug 16, 2025
8ecdba3
autotests resampling
mb706 Aug 16, 2025
c3c975c
testgenerator autotests
mb706 Aug 16, 2025
54ada21
fix Measure regression
mb706 Aug 16, 2025
f799a70
some adjustments
mb706 Aug 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/dev-cmd-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ jobs:
fail-fast: false
matrix:
config:
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3misc'}
- {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/paradox'}

steps:
Expand Down
1 change: 0 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ linters: linters_with_defaults(
# the following setup changes/removes certain linters
assignment_linter = NULL, # do not force using <- for assignments
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
line_length_linter = line_length_linter(180L),
indentation_linter(indent = 2L, hanging_indent_style = "never")
Expand Down
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ Suggests:
rpart,
testthat (>= 3.2.0)
Remotes:
mlr-org/mlr3misc
mlr-org/mlr3misc@common_baseclass,
mlr-org/mlr3data@common_baseclass
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
Expand Down
136 changes: 19 additions & 117 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,8 @@
#' @template seealso_learner
#' @export
Learner = R6Class("Learner",
inherit = Mlr3Component,
public = list(
#' @template field_id
id = NULL,

#' @template field_label
label = NA_character_,

#' @field state (`NULL` | named `list()`)\cr
#' Current (internal) state of the learner.
#' Contains all information gathered during `train()` and `predict()`.
Expand All @@ -184,14 +179,6 @@ Learner = R6Class("Learner",
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
feature_types = NULL,

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = NULL,

#' @template field_packages
packages = NULL,

#' @template field_predict_sets
predict_sets = "test",

Expand All @@ -216,42 +203,35 @@ Learner = R6Class("Learner",
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
timeout = c(train = Inf, predict = Inf),

#' @template field_man
man = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' Note that this object is typically constructed via a derived classes, e.g. [LearnerClassif] or [LearnerRegr].
initialize = function(id, task_type, param_set = ps(), predict_types = character(), feature_types = character(),
properties = character(), packages = character(), label = NA_character_, man = NA_character_) {
initialize = function(id, task_type, param_set = ps(), predict_types = character(0), feature_types = character(0),
properties = character(0), packages = character(0), additional_configuration = character(0), label, man) {

if (!missing(label) || !missing(man)) {
mlr3component_deprecation_msg("label and man are deprecated for Learner construction and will be removed in the future.")
}

super$initialize(dict_entry = id, dict_shortaccess = "lrn",
param_set = param_set, packages = packages, properties = properties,
additional_configuration = c("predict_sets", "parallel_predict", "timeout", "use_weights", "predict_type", "selected_features_impute",
if ("validate" %in% names(self)) "validate", additional_configuration)
)

self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
private$.predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)
assert_subset(properties, mlr_reflections$learner_properties[[task_type]])

if ("weights" %in% self$properties) {
self$use_weights = "use"
} else {
self$use_weights = "error"
}
private$.param_set = param_set

check_packages_installed(packages, msg = sprintf("Package '%%s' required but not installed for Learner '%s'", id))
},

#' @description
#' Helper for print outputs.
#' @param ... (ignored).
format = function(...) {
sprintf("<%s:%s>", class(self)[1L], self$id)
},

#' @description
Expand Down Expand Up @@ -295,12 +275,6 @@ Learner = R6Class("Learner",
}
},

#' @description
#' Opens the corresponding help page referenced by field `$man`.
help = function() {
open_help(self$man)
},

#' @description
#' Train the learner on a set of observations of the provided `task`.
#' Mutates the learner by reference, i.e. stores the model alongside other information in field `$state`.
Expand Down Expand Up @@ -623,54 +597,6 @@ Learner = R6Class("Learner",
return(invisible(self))
},

#' @description
#' Sets parameter values and fields of the learner.
#' All arguments whose names match the name of a parameter of the [paradox::ParamSet] are set as parameters.
#' All remaining arguments are assumed to be regular fields.
#'
#' @param ... (named `any`)\cr
#' Named arguments to set parameter values and fields.
#' @param .values (named `any`)\cr
#' Named list of parameter values and fields.
#' @examples
#' learner = lrn("classif.rpart")
#' learner$configure(minsplit = 3, parallel_predict = FALSE)
#' learner$configure(.values = list(cp = 0.005))
configure = function(..., .values = list()) {
dots = list(...)
assert_list(dots, names = "unique")
assert_list(.values, names = "unique")
assert_disjunct(names(dots), names(.values))
new_values = insert_named(dots, .values)

# set params in ParamSet
if (length(new_values)) {
param_ids = self$param_set$ids()
ii = names(new_values) %in% param_ids
if (any(ii)) {
self$param_set$values = insert_named(self$param_set$values, new_values[ii])
new_values = new_values[!ii]
}
} else {
param_ids = character()
}

# remaining args go into fields
if (length(new_values)) {
ndots = names(new_values)
for (i in seq_along(new_values)) {
nn = ndots[[i]]
if (!exists(nn, envir = self, inherits = FALSE)) {
stopf("Cannot set argument '%s' for '%s' (not a parameter, not a field).%s",
nn, class(self)[1L], did_you_mean(nn, c(param_ids, setdiff(names(self), ".__enclos_env__")))) # nolint
}
self[[nn]] = new_values[[i]]
}
}

return(invisible(self))
},

#' @description
#' Returns the features selected by the model.
#' The field `selected_features_impute` controls the behavior if the learner does not support feature selection.
Expand Down Expand Up @@ -757,23 +683,6 @@ Learner = R6Class("Learner",
get_log_condition(self$state, "error")
},

#' @field hash (`character(1)`)\cr
#' Hash (unique identifier) for this object.
#' The hash is calculated based on the learner id, the parameter settings, the predict type, the fallback hash, the parallel predict setting, the validate setting, and the predict sets.
hash = function(rhs) {
assert_ro_binding(rhs)
calculate_hash(class(self), self$id, self$param_set$values, private$.predict_type,
self$fallback$hash, self$parallel_predict, get0("validate", self), self$predict_sets, private$.use_weights)
},

#' @field phash (`character(1)`)\cr
#' Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).
phash = function(rhs) {
assert_ro_binding(rhs)
calculate_hash(class(self), self$id, private$.predict_type,
self$fallback$hash, self$parallel_predict, get0("validate", self), private$.use_weights)
},

#' @field predict_type (`character(1)`)\cr
#' Stores the currently active predict type, e.g. `"response"`.
#' Must be an element of `$predict_types`.
Expand All @@ -791,14 +700,6 @@ Learner = R6Class("Learner",
private$.predict_type = rhs
},

#' @template field_param_set
param_set = function(rhs) {
if (!missing(rhs) && !identical(rhs, private$.param_set)) {
stopf("param_set is read-only.")
}
private$.param_set
},

#' @field fallback ([Learner])\cr
#' Returns the fallback learner set with `$encapsulate()`.
fallback = function(rhs) {
Expand Down Expand Up @@ -851,7 +752,6 @@ Learner = R6Class("Learner",
.fallback = NULL,
.predict_type = NULL,
.predict_types = NULL,
.param_set = NULL,
.hotstart_stack = NULL,
.selected_features_impute = "error",

Expand All @@ -871,18 +771,20 @@ Learner = R6Class("Learner",
}
},

.additional_phash_input = function() {
list(private$.predict_type, self$fallback$hash, self$parallel_predict, get0("validate", self), self$predict_sets, private$.use_weights, private$.when)
},

deep_clone = function(name, value) {
switch(name,
.param_set = value$clone(deep = TRUE),
.fallback = if (is.null(value)) NULL else value$clone(deep = TRUE),
state = {
if (!is.null(value$train_task)) {
value$train_task = value$train_task$clone(deep = TRUE)
}
value$log = copy(value$log)
value
},
value
super$deep_clone(name, value)
)
}
)
Expand Down
12 changes: 9 additions & 3 deletions R/LearnerClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,16 @@ LearnerClassif = R6Class("LearnerClassif", inherit = Learner,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), packages = character(), label = NA_character_, man = NA_character_) {
initialize = function(id, param_set = ps(), predict_types = "response",
feature_types = character(), properties = character(), packages = character(),
additional_configuration = character(0), label, man
) {
if (!missing(label) || !missing(man)) {
mlr3component_deprecation_msg("label and man are deprecated for Learner construction and will be removed in the future.")
}

super$initialize(id = id, task_type = "classif", param_set = param_set, predict_types = predict_types,
feature_types = feature_types, properties = properties, packages = packages,
label = label, man = man)
feature_types = feature_types, properties = properties, packages = packages, additional_configuration = additional_configuration)

if (getOption("mlr3.prob_as_default", FALSE) && "prob" %in% self$predict_types) {
self$predict_type = "prob"
Expand Down
4 changes: 1 addition & 3 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
param_set = param_set,
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
predict_types = c("response", "prob"),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "validation", "internal_tuning", "marshal", "weights"),
man = "mlr3::mlr_learners_classif.debug",
label = "Debug Learner for Classification"
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "validation", "internal_tuning", "marshal", "weights")
)
},
#' @description
Expand Down
4 changes: 1 addition & 3 deletions R/LearnerClassifFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
feature_types = mlr_reflections$task_feature_types,
predict_types = c("response", "prob"),
param_set = ps,
properties = c("featureless", "twoclass", "multiclass", "missings", "importance", "selected_features", "weights"),
label = "Featureless Classification Learner",
man = "mlr3::mlr_learners_classif.featureless",
properties = c("featureless", "twoclass", "multiclass", "missings", "importance", "selected_features", "weights")
)
},

Expand Down
4 changes: 1 addition & 3 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
predict_types = c("response", "prob"),
param_set = ps,
properties = c("twoclass", "multiclass", "weights", "missings", "importance", "selected_features"),
label = "Classification Tree",
man = "mlr3::mlr_learners_classif.rpart"
properties = c("twoclass", "multiclass", "weights", "missings", "importance", "selected_features")
)
},

Expand Down
21 changes: 14 additions & 7 deletions R/LearnerRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
#' - `"se"`: Predicts the standard error for each value of response for each observation in the test set.
#' - `"distr"`: Probability distribution as `VectorDistribution` object (requires package `distr6`, available via
#' repository \url{https://raphaels1.r-universe.dev}).
#' - `"quantiles"`: Predicts quantile estimates for each observation in the test set.
#' Set `$quantiles` to specify the quantiles to predict and `$quantile_response` to specify the response quantile.
#' See mlr3book [section](https://mlr3book.mlr-org.com/chapters/chapter13/beyond_regression_and_classification.html#sec-quantile-regression) on quantile regression for more details.
#' - `"quantiles"`: Predicts quantile estimates for each observation in the test set.
#' Set `$quantiles` to specify the quantiles to predict and `$quantile_response` to specify the response quantile.
#' See mlr3book [section](https://mlr3book.mlr-org.com/chapters/chapter13/beyond_regression_and_classification.html#sec-quantile-regression) on quantile regression for more details.
#'
#' Predefined learners can be found in the [dictionary][mlr3misc::Dictionary] [mlr_learners].
#' Essential regression learners can be found in this dictionary after loading \CRANpkg{mlr3learners}.
Expand Down Expand Up @@ -44,10 +44,17 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, task_type = "regr", param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), packages = character(), label = NA_character_, man = NA_character_) {
initialize = function(dict_entry, id = dict_entry, task_type = "regr",
param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), packages = character(),
additional_configuration = character(0), label, man
) {
if (!missing(label) || !missing(man)) {
mlr3component_deprecation_msg("label and man are deprecated for Learner construction and will be removed in the future.")
}

super$initialize(id = id, task_type = task_type, param_set = param_set, feature_types = feature_types,
predict_types = predict_types, properties = properties, packages = packages,
label = label, man = man)
additional_configuration = c("quantiles", "quantile_response", additional_configuration))
},

#' @description
Expand Down Expand Up @@ -136,7 +143,7 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
if ("quantiles" %nin% self$predict_types) {
stopf("Learner does not support predicting quantiles")
}
private$.quantiles = assert_numeric(rhs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L, sorted = TRUE, .var.name = "quantiles")
private$.quantiles = assert_numeric(rhs, lower = 0, upper = 1, any.missing = FALSE, min.len = 1L, sorted = TRUE, .var.name = "quantiles", null.ok = TRUE)

if (length(private$.quantiles) == 1) {
private$.quantile_response = private$.quantiles
Expand All @@ -154,7 +161,7 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
stopf("Learner does not support predicting quantiles")
}

private$.quantile_response = assert_number(rhs, lower = 0, upper = 1, .var.name = "response")
private$.quantile_response = assert_number(rhs, lower = 0, upper = 1, .var.name = "response", null.ok = TRUE)
private$.quantiles = sort(union(private$.quantiles, private$.quantile_response))
}
),
Expand Down
4 changes: 1 addition & 3 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
x = p_dbl(0, 1, tags = "train")
),
properties = c("missings", "weights"),
packages = "stats",
man = "mlr3::mlr_learners_regr.debug",
label = "Debug Learner for Regression"
packages = "stats"
)
},

Expand Down
4 changes: 1 addition & 3 deletions R/LearnerRegrFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr
predict_types = c("response", "se", "quantiles"),
param_set = ps,
properties = c("featureless", "missings", "importance", "selected_features", "weights"),
packages = "stats",
label = "Featureless Regression Learner",
man = "mlr3::mlr_learners_regr.featureless"
packages = "stats"
)
},

Expand Down
4 changes: 1 addition & 3 deletions R/LearnerRegrRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
predict_types = "response",
packages = "rpart",
param_set = ps,
properties = c("weights", "missings", "importance", "selected_features"),
label = "Regression Tree",
man = "mlr3::mlr_learners_regr.rpart"
properties = c("weights", "missings", "importance", "selected_features")
)
},

Expand Down
Loading
Loading