Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
* feat: `benchmark_grid()` will now throw a warning if you mix different predict types in the
design (#1273).
* feat: Converting a `BenchmarkResult` to a `data.table` now includes the `task_id`, `learner_id`, and `resampling_id` columns (#1275).
* fix: Instantiating (repeated) CV on a task with fewer observations than folds
(or, for grouped tasks, fewer groups than folds) now fails with an error.

# mlr3 0.23.0

Expand Down
4 changes: 4 additions & 0 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ Resampling = R6Class("Resampling",
#' the object in its previous state.
instantiate = function(task) {
task = assert_task(as_task(task))
private$.check(task)
strata = task$strata
groups = task$groups

Expand Down Expand Up @@ -257,6 +258,9 @@ Resampling = R6Class("Resampling",
.id = NULL,
.hash = NULL,
.groups = NULL,
# Validation hook invoked by `instantiate()` as `private$.check(task)`, right
# after the task has been asserted and before its strata/groups are read.
# The base implementation accepts every task and simply yields `TRUE`; the
# visible call site does not use the return value, so a subclass signals an
# incompatible task by raising an error (see the override in ResamplingCV).
#
# @param task The task about to be used for instantiation.
# @return `TRUE`.
.check = function(task) {
TRUE
},

.get_set = function(getter, i) {
if (!self$is_instantiated) {
Expand Down
15 changes: 14 additions & 1 deletion R/ResamplingCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,25 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling,

private = list(
# Draws the fold assignment for the given row ids.
#
# Every id gets a fold label in `1..folds`: an index sequence over `ids` is
# reduced modulo the number of folds, shifted by one, and then passed through
# `shuffle()` (presumably a random permutation -- confirm against mlr3misc).
#
# @param ids Row ids to distribute over the folds.
# @param ... Ignored.
# @return A `data.table` with columns `row_id` and `fold`, keyed by `fold`.
.sample = function(ids, ...) {
  pvs = self$param_set$get_values()
  # Only one `fold` column may be supplied: the superseded
  # `self$param_set$values$folds` variant is dropped in favour of `pvs`.
  data.table(
    row_id = ids,
    fold = shuffle(seq_along0(ids) %% as.integer(pvs$folds) + 1L),
    key = "fold"
  )
},
# Rejects tasks that are too small for the configured number of folds.
#
# Raises an error via `stopf()` when
#   * the task is grouped and has fewer distinct groups than folds, or
#   * the task has fewer rows than folds.
# Otherwise nothing is signalled.
#
# @param task The task about to be used for instantiation.
.check = function(task) {
  pvs = self$param_set$get_values()
  # Read the groups table once instead of once per use.
  groups = task$groups
  if (!is.null(groups)) {
    # uniqueN() (data.table) counts the distinct group labels; for a vector it
    # is equivalent to length(unique(x)).
    n_groups = uniqueN(groups$group)
    if (n_groups < pvs$folds) {
      stopf("Cannot instantiate ResamplingCV with %i folds on a grouped task with %i groups.", pvs$folds, n_groups)
    }
  }
  if (task$nrow < pvs$folds) {
    stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", pvs$folds, task$nrow)
  }
},

.get_train = function(i) {
self$instance[!list(i), "row_id", on = "fold"][[1L]]
Expand Down
14 changes: 13 additions & 1 deletion R/ResamplingRepeatedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,25 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling,

private = list(
# Draws the fold assignment for every repetition.
#
# For each repetition `i` in `1..repeats`, every id gets a fold label in
# `1..folds` (index sequence modulo `folds`, plus one, passed through
# `shuffle()` -- presumably a random permutation, confirm against mlr3misc).
# The per-repetition tables are combined by `map_dtr()`.
#
# @param ids Row ids to distribute over the folds.
# @param ... Ignored.
# @return A table with columns `row_id`, `rep` and `fold`.
.sample = function(ids, ...) {
  # Single read of the parameter values; the earlier `$values` read was
  # immediately overwritten and has been removed.
  pv = self$param_set$get_values()
  n = length(ids)
  folds = as.integer(pv$folds)
  map_dtr(seq_len(pv$repeats), function(i) {
    data.table(row_id = ids, rep = i, fold = shuffle(seq_len0(n) %% folds + 1L))
  })
},
# Rejects tasks that are too small for the configured number of folds.
#
# Raises an error via `stopf()` when
#   * the task is grouped and has fewer distinct groups than folds, or
#   * the task has fewer rows than folds.
# Otherwise nothing is signalled. The `repeats` parameter is not inspected.
#
# @param task The task about to be used for instantiation.
.check = function(task) {
  pvs = self$param_set$get_values()
  # Read the groups table once instead of once per use.
  groups = task$groups
  if (!is.null(groups)) {
    # uniqueN() (data.table) counts the distinct group labels; for a vector it
    # is equivalent to length(unique(x)).
    n_groups = uniqueN(groups$group)
    if (n_groups < pvs$folds) {
      stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a grouped task with %i groups.", pvs$folds, n_groups)
    }
  }
  if (task$nrow < pvs$folds) {
    stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pvs$folds, task$nrow)
  }
},

.get_train = function(i) {
i = as.integer(i) - 1L
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test_Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,17 @@ test_that("task_row_hash in Resampling works correctly", {
resampling$instantiate(task)
expect_identical(resampling$task_row_hash, task$row_hash)
})

test_that("folds must be <= task size", {
  task = tsk("iris")

  # More folds than rows: both CV variants must refuse to instantiate.
  cv = rsmp("cv", folds = 151)
  rep_cv = rsmp("repeated_cv", folds = 151)
  expect_error(cv$instantiate(task), "Cannot instantiate ResamplingCV with 151 folds on a task with 150 rows")
  expect_error(rep_cv$instantiate(task), "Cannot instantiate ResamplingRepeatedCV with 151 folds on a task with 150 rows")

  # More folds than groups: grouping by Species, then asking for 4 folds,
  # must fail for both variants as well.
  task$col_roles$group = "Species"
  for (resampling in list(cv, rep_cv)) {
    resampling$param_set$set_values(folds = 4L)
  }
  for (resampling in list(cv, rep_cv)) {
    expect_error(resampling$instantiate(task), "on a grouped task with 3 groups")
  }
})