diff --git a/R/FilterEnsemble.R b/R/FilterEnsemble.R index af2688d52..282a06322 100644 --- a/R/FilterEnsemble.R +++ b/R/FilterEnsemble.R @@ -1,5 +1,3 @@ - - #' @title Filter Ensemble #' #' @usage NULL @@ -30,8 +28,14 @@ #' Required non-negative weights, one for each wrapped filter, with at least one strictly positive value. #' Values are used as given when calculating the weighted mean. If named, names must match the wrapped filter ids. #' * `rank_transform` :: `logical(1)`\cr -#' If `TRUE`, ranks of individual filter scores are used instead of the raw scores before -#' averaging. Initialized to `FALSE`. +#' If `TRUE`, ranks of individual filter scores are used instead of the raw scores before averaging. +#' Initialized to `FALSE`. +#' * `filter_score_transform` :: `function`\cr +#' Function applied to the score vector of each wrapped filter, after the optional `rank_transform` step +#' but before weighting and aggregation. Initialized to `identity`. +#' * `result_score_transform` :: `function`\cr +#' Function applied to the weighted score vector of each wrapped filter, i.e. after `rank_transform`, `filter_score_transform` +#' and multiplication by the filter's weight, but before aggregation across filters. Initialized to `identity`. #' #' Parameters of wrapped filters are available via `$param_set` and can be referenced using #' the wrapped filter id followed by `"."`, e.g. `"variance.na.rm"`. @@ -54,8 +58,16 @@ #' #' @section Internals: #' All wrapped filters are called with `nfeat` equal to the number of features to ensure that -#' complete score vectors are available for aggregation. Scores are combined per feature by +#' complete score vectors are available for aggregation. +#' +#' Scores are combined per feature by #' computing the weighted (optionally rank-based) mean. +#' +#' Order of transformations: +#' 1. If `rank_transform` is `TRUE`, ranks are computed from the filter scores. +#' 2. `filter_score_transform` is applied to the scores (or ranks). +#' 3. 
`result_score_transform` is applied, separately for each wrapped filter, to the product of the transformed scores/ranks and that filter's weight, before the per-feature aggregation. +#' #' #' @section References: #' `r format_bib("binder_2020")` @@ -71,6 +83,15 @@ #' flt$param_set$values$weights = c(variance = 0.5, auc = 0.5) #' flt$calculate(task) #' head(as.data.table(flt)) +#' +#' # TODO: Example with agg_reciprocal_ranking +#' +#' +#' # TODO: remove this test example +#' filter = flt("ensemble", list(flt("anova"), flt("auc")), weights = c(0.5, 0.5)) +#' filter$calculate(tsk("spam")) +#' head(as.data.table(filter)) +#' #' @export FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter, public = list( @@ -96,7 +117,9 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter, }, fnames), tags = "required" ), - rank_transform = p_lgl(init = FALSE, tags = "required") + rank_transform = p_lgl(init = FALSE, tags = "required"), + filter_score_transform = p_uty(init = identity, tags = "required", custom_check = check_function), + result_score_transform = p_uty(init = identity, tags = "required", custom_check = check_function) ) super$initialize( @@ -171,8 +194,10 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter, scores = pmap(list(private$.wrapped, weights), function(x, w) { x$calculate(task, nfeat) s = x$scores[fn] + # TODO: What does s look like? 
if (pv$rank_transform) s = rank(s, na.last = "keep", ties.method = "average") - s * w + s = pv$filter_score_transform(s) + pv$result_score_transform(s * w) }) scores_df = as.data.frame(scores) combined = rowSums(scores_df, na.rm = TRUE) diff --git a/R/PipeOpFilter.R b/R/PipeOpFilter.R index b5e8fd88d..a8f15ebed 100644 --- a/R/PipeOpFilter.R +++ b/R/PipeOpFilter.R @@ -136,7 +136,7 @@ PipeOpFilter = R6Class("PipeOpFilter", filtercrit = c("nfeat", "frac", "cutoff", "permuted") filtercrit = Filter(function(name) !is.null(private$.outer_param_set$values[[name]]), filtercrit) if (length(filtercrit) != 1) { - stopf("Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given. Instead given: %s", + stopf("Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given. Instead given: %s", if (length(filtercrit) == 0) "none" else str_collapse(filtercrit)) } critvalue = private$.outer_param_set$values[[filtercrit]] diff --git a/tests/testthat/test_pipeop_filter.R b/tests/testthat/test_pipeop_filter.R index b4ecee96a..d394234f7 100644 --- a/tests/testthat/test_pipeop_filter.R +++ b/tests/testthat/test_pipeop_filter.R @@ -16,10 +16,10 @@ test_that("PipeOpFilter", { expect_equal(po$id, mlr3filters::FilterVariance$new()$id) - expect_error(po$train(list(task)), "Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given.*none") + expect_error(po$train(list(task)), "Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given.*none") po$param_set$values = list(filter.nfeat = 1, filter.frac = 1, na.rm = TRUE) - expect_error(po$train(list(task)), "Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given.*nfeat, frac") + expect_error(po$train(list(task)), "Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given.*nfeat, frac") po$param_set$values = list(filter.nfeat = 1, na.rm = TRUE)