
Commit f0c2aa0

add test for comparator and other improvements based on PR comments
1 parent dc01cd3 commit f0c2aa0

6 files changed: +79 −32 lines

autoPyTorch/api/base_task.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,6 @@
 import tempfile
 import time
 import typing
-from typing_extensions import runtime
 import unittest.mock
 import warnings
 from abc import abstractmethod
@@ -751,13 +750,14 @@ def run_traditional_ml(
             self,
             current_task_name: str,
             runtime_limit: int,
-            func_eval_time_limit_secs: int) -> None:
+            func_eval_time_limit_secs: int
+    ) -> None:
         """
         This function can be used to run the suite of traditional machine
-        learning models during the current task (for e.g, ensemble fit, search)
+        learning models during the current task (for e.g, ensemble fit, search)

         Args:
-            current_task_name (str): name of the current task,
+            current_task_name (str): name of the current task,
             runtime_limit (int): time limit for fitting traditional models,
             func_eval_time_limit_secs (int): Time limit
                 for a single call to the machine learning model.

autoPyTorch/data/base_target_validator.py

Lines changed: 4 additions & 2 deletions
@@ -43,8 +43,10 @@ class BaseTargetValidator(BaseEstimator):
     """
     def __init__(self,
                  is_classification: bool = False,
-                 logger: Optional[Union[PicklableClientLogger, logging.Logger
-                                        ]] = None,
+                 logger: Optional[Union[PicklableClientLogger,
+                                        logging.Logger
+                                        ]
+                                  ] = None,
                  ) -> None:
         self.is_classification = is_classification


autoPyTorch/data/tabular_feature_validator.py

Lines changed: 24 additions & 10 deletions
@@ -78,6 +78,29 @@ def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:

 class TabularFeatureValidator(BaseFeatureValidator):

+    @staticmethod
+    def _comparator(cmp1: str, cmp2: str) -> int:
+        """Order so that categorical columns come right and numerical columns come left
+
+        Args:
+            cmp1 (str): First variable to compare
+            cmp2 (str): Second variable to compare
+
+        Raises:
+            ValueError: if the values of the variables to compare
+                are not in 'categorical' or 'numerical'
+
+        Returns:
+            int: either [0, -1, 1]
+        """
+        choices = ['categorical', 'numerical']
+        if cmp1 not in choices or cmp2 not in choices:
+            raise ValueError('The comparator for the column order only accepts {}, '
+                             'but got {} and {}'.format(choices, cmp1, cmp2))
+
+        idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
+        return idx1 - idx2
+
     def _fit(
         self,
         X: SUPPORTED_FEAT_TYPES,
@@ -130,19 +153,10 @@ def _fit(
         # The column transformer reorders the feature types
         # therefore, we need to change the order of columns as well
         # This means categorical columns are shifted to the right
-        def comparator(cmp1: str, cmp2: str) -> int:
-            """ Order so that categorical columns come right and numerical columns come left """
-            choices = ['categorical', 'numerical']
-            if cmp1 not in choices or cmp2 not in choices:
-                raise ValueError('The comparator for the column order only accepts {}, '
-                                 'but got {} and {}'.format(choices, cmp1, cmp2))
-
-            idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
-            return idx1 - idx2

         self.feat_type = sorted(
             feat_type,
-            key=functools.cmp_to_key(comparator)
+            key=functools.cmp_to_key(self._comparator)
         )

         # differently to categorical_columns and numerical_columns,
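
Note: the new staticmethod is the same comparator that previously lived inside _fit, now reusable and unit-testable. As a minimal standalone sketch of how functools.cmp_to_key uses it (the feat_type list below is illustrative, not part of the commit):

import functools

from autoPyTorch.data.tabular_feature_validator import TabularFeatureValidator

# The comparator returns a negative value when its first argument should sort
# earlier; 'categorical' has index 0 in choices, so categorical entries end up
# ahead of numerical ones after sorting.
assert TabularFeatureValidator._comparator('categorical', 'numerical') == -1
assert TabularFeatureValidator._comparator('numerical', 'categorical') == 1
assert TabularFeatureValidator._comparator('numerical', 'numerical') == 0

feat_type = ['numerical', 'categorical', 'numerical']  # illustrative input
ordered = sorted(feat_type, key=functools.cmp_to_key(TabularFeatureValidator._comparator))
print(ordered)  # ['categorical', 'numerical', 'numerical']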

autoPyTorch/pipeline/base_pipeline.py

Lines changed: 16 additions & 15 deletions
@@ -400,6 +400,7 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                 raise ValueError("Unknown node name. Expected update node name to be in {} "
                                  "got {}".format(self.named_steps.keys(), update.node_name))
             node = self.named_steps[update.node_name]
+            node_name = node.__class__.__name__
             # if node is a choice module
             if hasattr(node, 'get_components'):
                 split_hyperparameter = update.hyperparameter.split(':')
@@ -429,16 +430,16 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                 if choice not in components.keys():
                     raise ValueError("Unknown component choice for node {}. "
                                      "Expected update hyperparameter "
-                                     "to be in {}, but got {}".format(node.__class__.__name__,
-                                                                      components.keys(), choice))
+                                     "to be in {}, but got {}".format(node_name,
+                                                                      components.keys(), choice))
                 # check if the component whose hyperparameter
                 # needs to be updated is in components of the
                 # choice module
                 elif split_hyperparameter[0] not in components.keys():
                     raise ValueError("Unknown component choice for node {}. "
                                      "Expected update component "
-                                     "to be in {}, but got {}".format(node.__class__.__name__,
-                                                                      components.keys(), split_hyperparameter[0]))
+                                     "to be in {}, but got {}".format(node_name,
+                                                                      components.keys(), split_hyperparameter[0]))
                 else:
                     # check if hyperparameter is in the search space of the component
                     component = components[split_hyperparameter[0]]
@@ -451,15 +452,15 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                             component.get_hyperparameter_search_space(
                                 dataset_properties=self.dataset_properties).get_hyperparameter_names()]):
                         continue
+                    component_hyperparameters = component.get_hyperparameter_search_space(
+                        dataset_properties=self.dataset_properties).get_hyperparameter_names()
                     raise ValueError("Unknown hyperparameter for component {} of node {}."
                                      " Expected update hyperparameter "
                                      "to be in {}, but got {}.".format(component.__name__,
-                                                                       node.__class__.__name__,
-                                                                       component.get_hyperparameter_search_space(
-                                                                           dataset_properties=self.dataset_properties
-                                                                       ).get_hyperparameter_names(),
-                                                                       split_hyperparameter[1]
-                                                                       )
+                                                                       node_name,
+                                                                       component_hyperparameters,
+                                                                       split_hyperparameter[1]
+                                                                       )
                                      )
             else:
                 if update.hyperparameter not in node.get_hyperparameter_search_space(
@@ -468,13 +469,13 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
                         node.get_hyperparameter_search_space(
                             dataset_properties=self.dataset_properties).get_hyperparameter_names()]):
                     continue
+                node_hyperparameters = node.get_hyperparameter_search_space(
+                    dataset_properties=self.dataset_properties).get_hyperparameter_names()
                 raise ValueError("Unknown hyperparameter for node {}. "
                                  "Expected update hyperparameter "
-                                 "to be in {}, but got {}".format(node.__class__.__name__,
-                                                                  node.
-                                                                  get_hyperparameter_search_space(
-                                                                      dataset_properties=self.dataset_properties).
-                                                                  get_hyperparameter_names(), update.hyperparameter))
+                                 "to be in {}, but got {}".format(node_name,
+                                                                  node_hyperparameters,
+                                                                  update.hyperparameter))

     def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]]
                             ) -> List[Tuple[str, autoPyTorchChoice]]:
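
Note: the refactoring above hoists the repeated node.__class__.__name__ and get_hyperparameter_names() expressions into the locals node_name, component_hyperparameters and node_hyperparameters, so each ValueError stays on a single readable format call. A rough sketch of the pattern with hypothetical names (ExampleNode, valid_names and requested are made up for illustration and are not part of autoPyTorch):

# Hoist repeated lookups into locals, then build the error message in one format call.
class ExampleNode:  # hypothetical stand-in for a pipeline node
    pass

node = ExampleNode()
valid_names = ['learning_rate', 'batch_size']  # stands in for get_hyperparameter_names()
requested = 'momentum'  # an update that does not exist in the search space

node_name = node.__class__.__name__
if requested not in valid_names:
    raise ValueError("Unknown hyperparameter for node {}. "
                     "Expected update hyperparameter "
                     "to be in {}, but got {}".format(node_name, valid_names, requested))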

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
         beta = 1.0
         lam = self.random_state.beta(beta, beta)
         batch_size, n_columns = np.shape(X)
-        # shuffled_indices: Shuffled version of torch.arange(batch_size)
+        # shuffled_indices: Shuffled version of torch.arange(batch_size)
         shuffled_indices = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)

         r = self.random_state.rand(1)
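
Note: this hunk appears to only adjust whitespace in the comment; the surrounding line is the interesting part, since it builds the row-shuffling indices on the same device as the batch. A minimal standalone sketch of that pattern (the 8x4 tensor below is illustrative):

import torch

X = torch.rand(8, 4)  # illustrative batch of 8 rows
batch_size = X.shape[0]
# Shuffled version of torch.arange(batch_size), placed on the GPU only if X lives there
shuffled_indices = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
X_shuffled = X[shuffled_indices]  # rows permuted, as used for CutMix-style row mixing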

test/test_data/test_feature_validator.py

Lines changed: 30 additions & 0 deletions
@@ -1,4 +1,5 @@
 import copy
+import functools

 import numpy as np

@@ -331,6 +332,11 @@ def feature_validator_remove_nan_catcolumns(df_train: pd.DataFrame, df_test: pd.
 def test_feature_validator_remove_nan_catcolumns():
     """
     Make sure categorical columns that have only nan values are removed.
+    The ans arrays contain the final output after calling transform on
+    datasets, this includes fitting and transforming a column transformer
+    containing simple imputation for both categorical and numerical
+    columns, scaling for numerical columns and one hot encoding for
+    categorical columns.
     """
     # First case, there exist null columns (B and C) in the train set
     # and a same column (C) are not all null for the test set.
@@ -396,6 +402,7 @@ def test_feature_validator_remove_nan_catcolumns():
     ans_test = np.array([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.float64)
     feature_validator_remove_nan_catcolumns(df_train, df_test, ans_train, ans_test)

+
 def test_features_unsupported_calls_are_raised():
     """
     Makes sure we raise a proper message to the user,
@@ -664,3 +671,26 @@ def test_feature_validator_imbalanced_data():
     transformed_X_test = validator.transform(X_test)
     transformed_X_test = pd.DataFrame(transformed_X_test)
     assert not len(validator.all_nan_columns)
+
+
+def test_comparator():
+    numerical = 'numerical'
+    categorical = 'categorical'
+
+    validator = TabularFeatureValidator
+
+    feat_type = [numerical, categorical] * 10
+    ans = [categorical] * 10 + [numerical] * 10
+    feat_type = sorted(
+        feat_type,
+        key=functools.cmp_to_key(validator._comparator)
+    )
+    assert ans == feat_type
+
+    feat_type = [numerical] * 10 + [categorical] * 10
+    ans = [categorical] * 10 + [numerical] * 10
+    feat_type = sorted(
+        feat_type,
+        key=functools.cmp_to_key(validator._comparator)
+    )
+    assert ans == feat_type
