Skip to content

Commit 2d2f6d1

Browse files
authored
[enhance] Increase the coverage (#336)
1 parent a1512d5 commit 2d2f6d1

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from ConfigSpace.configuration_space import ConfigurationSpace
2+
from ConfigSpace.hyperparameters import CategoricalHyperparameter, UniformIntegerHyperparameter
3+
4+
import pytest
5+
6+
from autoPyTorch.pipeline.components.base_component import ThirdPartyComponents, autoPyTorchComponent
7+
8+
9+
class DummyComponentRequiredFailuire(autoPyTorchComponent):
10+
_required_properties = {'required'}
11+
12+
def __init__(self, random_state=None):
13+
self.fitted = False
14+
self._cs_updates = {}
15+
16+
def fit(self, X, y):
17+
self.fitted = True
18+
return self
19+
20+
def get_properties(dataset_properties=None):
21+
return {"name": 'DummyComponentRequiredFailuire',
22+
"shortname": "Dummy"}
23+
24+
25+
class DummyComponentExtraPropFailuire(autoPyTorchComponent):
26+
27+
def __init__(self, random_state=None):
28+
self.fitted = False
29+
self._cs_updates = {}
30+
31+
def fit(self, X, y):
32+
self.fitted = True
33+
return self
34+
35+
def get_properties(dataset_properties=None):
36+
return {"name": 'DummyComponentExtraPropFailuire',
37+
"shortname": 'Dummy',
38+
"must_not_be_there": True}
39+
40+
41+
class DummyComponent(autoPyTorchComponent):
42+
def __init__(self, a=0, b='orange', random_state=None):
43+
self.a = a
44+
self.b = b
45+
self.fitted = False
46+
self.random_state = random_state
47+
self._cs_updates = {}
48+
49+
def get_hyperparameter_search_space(self, dataset_properties=None):
50+
cs = ConfigurationSpace()
51+
a = UniformIntegerHyperparameter('a', lower=10, upper=100, log=False)
52+
b = CategoricalHyperparameter('b', choices=['red', 'green', 'blue'])
53+
cs.add_hyperparameters([a, b])
54+
return cs
55+
56+
def fit(self, X, y):
57+
self.fitted = True
58+
return self
59+
60+
def get_properties(dataset_properties=None):
61+
return {"name": 'DummyComponent',
62+
"shortname": 'Dummy'}
63+
64+
65+
def test_third_party_component_failure():
66+
_addons = ThirdPartyComponents(autoPyTorchComponent)
67+
68+
with pytest.raises(ValueError, match=r"Property required not specified for .*"):
69+
_addons.add_component(DummyComponentRequiredFailuire)
70+
71+
with pytest.raises(ValueError, match=r"Property must_not_be_there must not be specified for algorithm .*"):
72+
_addons.add_component(DummyComponentExtraPropFailuire)
73+
74+
with pytest.raises(TypeError, match=r"add_component works only with a subclass of .*"):
75+
_addons.add_component(1)
76+
77+
78+
def test_set_hyperparameters_not_found_failure():
79+
dummy_component = DummyComponent()
80+
dummy_config_space = dummy_component.get_hyperparameter_search_space()
81+
success_configuration = dummy_config_space.sample_configuration()
82+
dummy_config_space.add_hyperparameter(CategoricalHyperparameter('c', choices=[1, 2]))
83+
failure_configuration = dummy_config_space.sample_configuration()
84+
with pytest.raises(ValueError, match=r"Cannot set hyperparameter c for autoPyTorch.pipeline "
85+
r"DummyComponent because the hyperparameter does not exist."):
86+
dummy_component.set_hyperparameters(failure_configuration)
87+
with pytest.raises(ValueError, match=r"Cannot set init param r for autoPyTorch.pipeline "
88+
r"DummyComponent because the init param does not exist."):
89+
dummy_component.set_hyperparameters(success_configuration, init_params={'r': 1})

test/test_pipeline/test_tabular_classification.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import pytest
1414

1515
import torch
16+
from torch.optim.lr_scheduler import _LRScheduler
1617

1718
from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import get_preprocess_transforms
19+
from autoPyTorch.pipeline.components.setup.lr_scheduler.NoScheduler import NoScheduler
1820
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
1921
from autoPyTorch.utils.common import FitRequirement
2022
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates, \
@@ -223,6 +225,7 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary_tabular):
223225
# No error when network is passed
224226
X = pipeline.named_steps['optimizer'].fit(X, None).transform(X)
225227
assert 'optimizer' in X
228+
assert isinstance(pipeline.named_steps['optimizer'].choice.get_optimizer(), torch.optim.Optimizer)
226229

227230
# Then fitting a optimizer should fail if no network:
228231
assert 'lr_scheduler' in pipeline.named_steps.keys()
@@ -234,7 +237,13 @@ def test_network_optimizer_lr_handshake(self, fit_dictionary_tabular):
234237

235238
# No error when network is passed
236239
X = pipeline.named_steps['lr_scheduler'].fit(X, None).transform(X)
237-
assert 'optimizer' in X
240+
assert 'lr_scheduler' in X
241+
if isinstance(pipeline.named_steps['lr_scheduler'].choice, NoScheduler):
242+
pytest.skip("This scheduler does not support `get_scheduler`")
243+
lr_scheduler = pipeline.named_steps['lr_scheduler'].choice.get_scheduler()
244+
if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
245+
pytest.skip("This scheduler is not a child of _LRScheduler")
246+
assert isinstance(lr_scheduler, _LRScheduler)
238247

239248
def test_get_fit_requirements(self, fit_dictionary_tabular):
240249
dataset_properties = {'numerical_columns': [], 'categorical_columns': [],

0 commit comments

Comments
 (0)