Skip to content

Commit dd7b020

Browse files
authored
Swap TransformerMixin inheritance by sklearn’s SelectorMixin (#145)
* Adapt BorutaPy to SelectorMixin interface * Test full SelectorMixin API
1 parent 0135a91 commit dd7b020

File tree

2 files changed

+74
-3
lines changed

2 files changed

+74
-3
lines changed

boruta/boruta_py.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import numpy as np
1313
import scipy as sp
1414
from sklearn.utils import check_random_state, check_X_y
15-
from sklearn.base import TransformerMixin, BaseEstimator
15+
from sklearn.base import BaseEstimator
16+
from sklearn.feature_selection import SelectorMixin
17+
from sklearn.utils.validation import check_is_fitted
1618
import warnings
1719

1820

19-
class BorutaPy(BaseEstimator, TransformerMixin):
21+
class BorutaPy(BaseEstimator, SelectorMixin):
2022
"""
2123
Improved Python implementation of the Boruta R package.
2224
@@ -287,11 +289,19 @@ def _fit(self, X, y):
287289
# check input params
288290
self._check_params(X, y)
289291

292+
feature_names = getattr(X, "columns", None)
293+
if feature_names is not None:
294+
self.feature_names_in_ = np.asarray(feature_names, dtype=object)
295+
else:
296+
self.feature_names_in_ = None
297+
290298
if not isinstance(X, np.ndarray):
291299
X = self._validate_pandas_input(X)
292300
if not isinstance(y, np.ndarray):
293301
y = self._validate_pandas_input(y)
294302

303+
self.n_features_in_ = X.shape[1]
304+
295305
self.random_state = check_random_state(self.random_state)
296306

297307
early_stopping = False
@@ -465,6 +475,10 @@ def _set_n_estimators(self, n_estimators):
465475
)
466476
return self
467477

478+
def _get_support_mask(self):
479+
check_is_fitted(self, 'support_')
480+
return self.support_
481+
468482
def _get_tree_num(self, n_feat):
469483
depth = None
470484
try:

boruta/test/test_boruta.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pandas as pd
33
import pytest
44
from sklearn.ensemble import RandomForestClassifier
5+
from sklearn.exceptions import NotFittedError
56
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
67

78
from boruta import BorutaPy
@@ -68,6 +69,62 @@ def test_dataframe_is_returned(Xy):
6869
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)
6970

7071

72+
def test_selector_mixin_get_support_requires_fit():
73+
bt = BorutaPy(RandomForestClassifier())
74+
with pytest.raises(NotFittedError):
75+
bt.get_support()
76+
77+
78+
def test_selector_mixin_get_support_matches_mask(Xy):
79+
X, y = Xy
80+
bt = BorutaPy(RandomForestClassifier())
81+
bt.fit(X, y)
82+
83+
assert np.array_equal(bt.get_support(), bt.support_)
84+
assert np.array_equal(bt.get_support(indices=True),
85+
np.where(bt.support_)[0])
86+
87+
88+
def test_selector_mixin_inverse_transform_restores_selected_features(Xy):
89+
X, y = Xy
90+
bt = BorutaPy(RandomForestClassifier())
91+
bt.fit(X, y)
92+
93+
X_selected = bt.transform(X)
94+
X_reconstructed = bt.inverse_transform(X_selected)
95+
96+
assert X_reconstructed.shape == X.shape
97+
assert np.allclose(X_reconstructed[:, bt.support_], X[:, bt.support_])
98+
99+
if (~bt.support_).any():
100+
assert np.allclose(X_reconstructed[:, ~bt.support_], 0)
101+
102+
103+
def test_selector_mixin_get_feature_names_out_requires_fit():
104+
bt = BorutaPy(RandomForestClassifier())
105+
with pytest.raises(NotFittedError):
106+
bt.get_feature_names_out()
107+
108+
109+
def test_selector_mixin_get_feature_names_out_returns_selected_names(Xy):
110+
X, y = Xy
111+
bt = BorutaPy(RandomForestClassifier())
112+
bt.fit(X, y)
113+
114+
expected_default = np.array([f"x{i}" for i in np.where(bt.support_)[0]])
115+
assert np.array_equal(bt.get_feature_names_out(), expected_default)
116+
117+
custom_names = np.array([f"feature_{i}" for i in range(X.shape[1])])
118+
selected_names = bt.get_feature_names_out(custom_names)
119+
assert np.array_equal(selected_names, custom_names[bt.support_])
120+
121+
columns = [f"col_{i}" for i in range(X.shape[1])]
122+
X_df = pd.DataFrame(X, columns=columns)
123+
bt_df = BorutaPy(RandomForestClassifier())
124+
bt_df.fit(X_df, y)
125+
assert np.array_equal(bt_df.get_feature_names_out(), np.array(columns)[bt_df.support_])
126+
127+
71128
@pytest.mark.parametrize("tree", [ExtraTreeClassifier(), DecisionTreeClassifier()])
72129
def test_boruta_with_decision_trees(tree, Xy):
73130
msg = (
@@ -80,4 +137,4 @@ def test_boruta_with_decision_trees(tree, Xy):
80137
with pytest.raises(ValueError) as record:
81138
bt.fit(X, y)
82139

83-
assert str(record.value) == msg
140+
assert str(record.value) == msg

0 commit comments

Comments
 (0)