22import pandas as pd
33import pytest
44from sklearn .ensemble import RandomForestClassifier
5+ from sklearn .exceptions import NotFittedError
56from sklearn .tree import DecisionTreeClassifier , ExtraTreeClassifier
67
78from boruta import BorutaPy
@@ -68,6 +69,62 @@ def test_dataframe_is_returned(Xy):
6869 assert isinstance (bt .transform (X_df , return_df = True ), pd .DataFrame )
6970
7071
72+ def test_selector_mixin_get_support_requires_fit ():
73+ bt = BorutaPy (RandomForestClassifier ())
74+ with pytest .raises (NotFittedError ):
75+ bt .get_support ()
76+
77+
78+ def test_selector_mixin_get_support_matches_mask (Xy ):
79+ X , y = Xy
80+ bt = BorutaPy (RandomForestClassifier ())
81+ bt .fit (X , y )
82+
83+ assert np .array_equal (bt .get_support (), bt .support_ )
84+ assert np .array_equal (bt .get_support (indices = True ),
85+ np .where (bt .support_ )[0 ])
86+
87+
88+ def test_selector_mixin_inverse_transform_restores_selected_features (Xy ):
89+ X , y = Xy
90+ bt = BorutaPy (RandomForestClassifier ())
91+ bt .fit (X , y )
92+
93+ X_selected = bt .transform (X )
94+ X_reconstructed = bt .inverse_transform (X_selected )
95+
96+ assert X_reconstructed .shape == X .shape
97+ assert np .allclose (X_reconstructed [:, bt .support_ ], X [:, bt .support_ ])
98+
99+ if (~ bt .support_ ).any ():
100+ assert np .allclose (X_reconstructed [:, ~ bt .support_ ], 0 )
101+
102+
103+ def test_selector_mixin_get_feature_names_out_requires_fit ():
104+ bt = BorutaPy (RandomForestClassifier ())
105+ with pytest .raises (NotFittedError ):
106+ bt .get_feature_names_out ()
107+
108+
109+ def test_selector_mixin_get_feature_names_out_returns_selected_names (Xy ):
110+ X , y = Xy
111+ bt = BorutaPy (RandomForestClassifier ())
112+ bt .fit (X , y )
113+
114+ expected_default = np .array ([f"x{ i } " for i in np .where (bt .support_ )[0 ]])
115+ assert np .array_equal (bt .get_feature_names_out (), expected_default )
116+
117+ custom_names = np .array ([f"feature_{ i } " for i in range (X .shape [1 ])])
118+ selected_names = bt .get_feature_names_out (custom_names )
119+ assert np .array_equal (selected_names , custom_names [bt .support_ ])
120+
121+ columns = [f"col_{ i } " for i in range (X .shape [1 ])]
122+ X_df = pd .DataFrame (X , columns = columns )
123+ bt_df = BorutaPy (RandomForestClassifier ())
124+ bt_df .fit (X_df , y )
125+ assert np .array_equal (bt_df .get_feature_names_out (), np .array (columns )[bt_df .support_ ])
126+
127+
71128@pytest .mark .parametrize ("tree" , [ExtraTreeClassifier (), DecisionTreeClassifier ()])
72129def test_boruta_with_decision_trees (tree , Xy ):
73130 msg = (
@@ -80,4 +137,4 @@ def test_boruta_with_decision_trees(tree, Xy):
80137 with pytest .raises (ValueError ) as record :
81138 bt .fit (X , y )
82139
83- assert str (record .value ) == msg
140+ assert str (record .value ) == msg
0 commit comments