Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions R/FilterEnsemble.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


#' @title Filter Ensemble
#'
#' @usage NULL
Expand Down Expand Up @@ -30,8 +28,14 @@
#' Required non-negative weights, one for each wrapped filter, with at least one strictly positive value.
#' Values are used as given when calculating the weighted mean. If named, names must match the wrapped filter ids.
#' * `rank_transform` :: `logical(1)`\cr
#' If `TRUE`, ranks of individual filter scores are used instead of the raw scores before
#' averaging. Initialized to `FALSE`.
#' If `TRUE`, ranks of individual filter scores are used instead of the raw scores before averaging.
#' Initialized to `FALSE`.
#' * `filter_score_transform` :: `function`\cr
#' Function applied to each wrapped filter's score vector after the optional `rank_transform` step
#' but before weighting and aggregation. Initialized to `identity`.
#' * `result_score_transform` :: `function`\cr
#' Function applied to each wrapped filter's weighted score vector, i.e. the scores (or ranks) after
#' `filter_score_transform` multiplied by the filter's weight, but before aggregation. Initialized to `identity`.
#'
#' Parameters of wrapped filters are available via `$param_set` and can be referenced using
#' the wrapped filter id followed by `"."`, e.g. `"variance.na.rm"`.
Expand All @@ -54,8 +58,16 @@
#'
#' @section Internals:
#' All wrapped filters are called with `nfeat` equal to the number of features to ensure that
#' complete score vectors are available for aggregation. Scores are combined per feature by
#' complete score vectors are available for aggregation.
#'
#' Scores are combined per feature by
#' computing the weighted (optionally rank-based) mean.
#'
#' Order of transformations:
#' 1. If `rank_transform` is `TRUE`, ranks are computed from the filter scores.
#' 2. `filter_score_transform` is applied to the scores (or ranks).
#' 3. `result_score_transform` is applied to the product of the transformed scores (or ranks) and the filter's weight.
#'
#'
#' @section References:
#' `r format_bib("binder_2020")`
Expand All @@ -71,6 +83,15 @@
#' flt$param_set$values$weights = c(variance = 0.5, auc = 0.5)
#' flt$calculate(task)
#' head(as.data.table(flt))
#'
#' # TODO: Example with agg_reciprocal_ranking
#'
#'
#' # TODO: remove this test example
#' filter = flt("ensemble", list(flt("anova"), flt("auc")), weights = c(0.5, 0.5))
#' filter$calculate(tsk("spam"))
#' head(as.data.table(filter))
#'
#' @export
FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
public = list(
Expand All @@ -96,7 +117,9 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
}, fnames),
tags = "required"
),
rank_transform = p_lgl(init = FALSE, tags = "required")
rank_transform = p_lgl(init = FALSE, tags = "required"),
filter_score_transform = p_uty(init = identity, tags = "required", custom_check = check_function),
result_score_transform = p_uty(init = identity, tags = "required", custom_check = check_function)
)

super$initialize(
Expand Down Expand Up @@ -171,8 +194,10 @@ FilterEnsemble = R6Class("FilterEnsemble", inherit = mlr3filters::Filter,
scores = pmap(list(private$.wrapped, weights), function(x, w) {
x$calculate(task, nfeat)
s = x$scores[fn]
# TODO: What does s look like?
if (pv$rank_transform) s = rank(s, na.last = "keep", ties.method = "average")
s * w
s = pv$filter_score_transform(s)
pv$result_score_transform(s * w)
})
scores_df = as.data.frame(scores)
combined = rowSums(scores_df, na.rm = TRUE)
Expand Down
2 changes: 1 addition & 1 deletion R/PipeOpFilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ PipeOpFilter = R6Class("PipeOpFilter",
filtercrit = c("nfeat", "frac", "cutoff", "permuted")
filtercrit = Filter(function(name) !is.null(private$.outer_param_set$values[[name]]), filtercrit)
if (length(filtercrit) != 1) {
stopf("Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given. Instead given: %s",
stopf("Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given. Instead given: %s",
if (length(filtercrit) == 0) "none" else str_collapse(filtercrit))
}
critvalue = private$.outer_param_set$values[[filtercrit]]
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_pipeop_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ test_that("PipeOpFilter", {

expect_equal(po$id, mlr3filters::FilterVariance$new()$id)

expect_error(po$train(list(task)), "Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given.*none")
expect_error(po$train(list(task)), "Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given.*none")

po$param_set$values = list(filter.nfeat = 1, filter.frac = 1, na.rm = TRUE)
expect_error(po$train(list(task)), "Exactly one of 'nfeat', 'frac', 'cutoff', or 'permuted' must be given.*nfeat, frac")
expect_error(po$train(list(task)), "Exactly one hyperparameter of 'filter.nfeat', 'filter.frac', 'filter.cutoff', or 'filter.permuted' must be given.*nfeat, frac")

po$param_set$values = list(filter.nfeat = 1, na.rm = TRUE)

Expand Down