Commit 8962612

[refactor] Getting dataset properties from the dataset object (#164)
* Use get_required_dataset_info of the dataset when needing required info for getting dataset requirements
* Fix flake
* Fix bug in getting dataset requirements
* Added doc string to explain dataset properties
* Update doc string in utils pipeline
1 parent 4493270 commit 8962612

7 files changed: +153 −112 lines changed

autoPyTorch/api/base_task.py

Lines changed: 16 additions & 12 deletions

@@ -196,14 +196,6 @@ def __init__(
             raise ValueError("Expected search space updates to be of instance"
                              " HyperparameterSearchSpaceUpdates got {}".format(type(self.search_space_updates)))
 
-    @abstractmethod
-    def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
-        """
-        given a pipeline type, this function returns the
-        dataset properties required by the dataset object
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline:
         """
@@ -267,7 +259,10 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace:
             return self.search_space
         elif dataset is not None:
             dataset_requirements = get_dataset_requirements(
-                info=self._get_required_dataset_properties(dataset))
+                info=dataset.get_required_dataset_info(),
+                include=self.include_components,
+                exclude=self.exclude_components,
+                search_space_updates=self.search_space_updates)
             return get_configuration_space(info=dataset.get_dataset_properties(dataset_requirements),
                                            include=self.include_components,
                                            exclude=self.exclude_components,
@@ -785,7 +780,10 @@ def _search(
         # Initialise information needed for the experiment
         experiment_task_name: str = 'runSearch'
         dataset_requirements = get_dataset_requirements(
-            info=self._get_required_dataset_properties(dataset))
+            info=dataset.get_required_dataset_info(),
+            include=self.include_components,
+            exclude=self.exclude_components,
+            search_space_updates=self.search_space_updates)
         self._dataset_requirements = dataset_requirements
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._stopwatch.start_task(experiment_task_name)
@@ -1049,7 +1047,10 @@ def refit(
         self._logger = self._get_logger(str(self.dataset_name))
 
         dataset_requirements = get_dataset_requirements(
-            info=self._get_required_dataset_properties(dataset))
+            info=dataset.get_required_dataset_info(),
+            include=self.include_components,
+            exclude=self.exclude_components,
+            search_space_updates=self.search_space_updates)
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._backend.save_datamanager(dataset)
 
@@ -1119,7 +1120,10 @@ def fit(self,
 
         # get dataset properties
         dataset_requirements = get_dataset_requirements(
-            info=self._get_required_dataset_properties(dataset))
+            info=dataset.get_required_dataset_info(),
+            include=self.include_components,
+            exclude=self.exclude_components,
+            search_space_updates=self.search_space_updates)
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._backend.save_datamanager(dataset)

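All four hunks in this file make the same substitution, so it is worth reading once in isolation. Below is a minimal sketch of the new pattern, assuming only the names visible in this diff; the helper function itself is hypothetical:

    from autoPyTorch.utils.pipeline import get_dataset_requirements

    def collect_dataset_properties(task, dataset):
        # Hypothetical helper: `task` stands in for a BaseTask instance,
        # `dataset` for a BaseDataset. The dataset now describes itself via
        # get_required_dataset_info(), and the include/exclude component
        # filters plus any search-space updates are forwarded so the computed
        # requirements match the components actually active in the pipeline.
        dataset_requirements = get_dataset_requirements(
            info=dataset.get_required_dataset_info(),
            include=task.include_components,
            exclude=task.exclude_components,
            search_space_updates=task.search_space_updates)
        return dataset.get_dataset_properties(dataset_requirements)
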
autoPyTorch/api/tabular_classification.py

Lines changed: 0 additions & 12 deletions

@@ -13,7 +13,6 @@
     TASK_TYPES_TO_STRING,
 )
 from autoPyTorch.data.tabular_validator import TabularInputValidator
-from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
     HoldoutValTypes,
@@ -97,17 +96,6 @@ def __init__(
             task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION],
         )
 
-    def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
-        if not isinstance(dataset, TabularDataset):
-            raise ValueError("Dataset is incompatible for the given task,: {}".format(
-                type(dataset)
-            ))
-        return {'task_type': dataset.task_type,
-                'output_type': dataset.output_type,
-                'issparse': dataset.issparse,
-                'numerical_columns': dataset.numerical_columns,
-                'categorical_columns': dataset.categorical_columns}
-
     def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline:
         return TabularClassificationPipeline(dataset_properties=dataset_properties)

autoPyTorch/api/tabular_regression.py

Lines changed: 0 additions & 12 deletions

@@ -13,7 +13,6 @@
     TASK_TYPES_TO_STRING
 )
 from autoPyTorch.data.tabular_validator import TabularInputValidator
-from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
     HoldoutValTypes,
@@ -89,17 +88,6 @@ def __init__(
             task_type=TASK_TYPES_TO_STRING[TABULAR_REGRESSION],
         )
 
-    def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
-        if not isinstance(dataset, TabularDataset):
-            raise ValueError("Dataset is incompatible for the given task,: {}".format(
-                type(dataset)
-            ))
-        return {'task_type': dataset.task_type,
-                'output_type': dataset.output_type,
-                'issparse': dataset.issparse,
-                'numerical_columns': dataset.numerical_columns,
-                'categorical_columns': dataset.categorical_columns}
-
    def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
        return TabularRegressionPipeline(dataset_properties=dataset_properties)

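Both API classes lose identical boilerplate: the five-key dict they assembled by hand is exactly what the dataset now reports about itself (see the next two files). A hedged one-liner of what replaces the deleted methods, where `dataset` is assumed to be the TabularDataset the API class constructs:

    # The per-task override is gone; the dataset reports its own info.
    info = dataset.get_required_dataset_info()
    # -> {'task_type': ..., 'output_type': ..., 'issparse': ...,
    #     'numerical_columns': [...], 'categorical_columns': [...]}
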
autoPyTorch/datasets/base_dataset.py

Lines changed: 13 additions & 11 deletions

@@ -348,11 +348,17 @@ def replace_data(self, X_train: BaseDatasetInputType,
 
     def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) -> Dict[str, Any]:
         """
-        Gets the dataset properties required in the fit dictionary
+        Gets the dataset properties required in the fit dictionary.
+        This depends on the components that are active in the
+        pipeline and returns the properties they need about the dataset.
+        Information on the required properties of each component
+        can be found in its documentation.
         Args:
             dataset_requirements (List[FitRequirement]): List of
                 fit requirements that the dataset properties must
-                contain.
+                contain. This is created using the `get_dataset_requirements
+                function
+                <https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/utils/pipeline.py#L25>`_.
 
         Returns:
             dataset_properties (Dict[str, Any]):
@@ -362,19 +368,15 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
         for dataset_requirement in dataset_requirements:
             dataset_properties[dataset_requirement.name] = getattr(self, dataset_requirement.name)
 
-        # Add task type, output type and issparse to dataset properties as
-        # they are not a dataset requirement in the pipeline
-        dataset_properties.update({'task_type': self.task_type,
-                                   'output_type': self.output_type,
-                                   'issparse': self.issparse,
-                                   'input_shape': self.input_shape,
-                                   'output_shape': self.output_shape
-                                   })
+        # Add the required dataset info to dataset properties as
+        # it might not be a dataset requirement in the pipeline
+        dataset_properties.update(self.get_required_dataset_info())
         return dataset_properties
 
     def get_required_dataset_info(self) -> Dict[str, Any]:
         """
-        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        Returns a dictionary containing required dataset
+        properties to instantiate a pipeline.
         """
         info = {'output_type': self.output_type,
                 'issparse': self.issparse}

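A compact restatement of the merge get_dataset_properties now performs, as a hedged sketch (the standalone function is hypothetical; the `.name` attribute on FitRequirement is taken from the loop in the hunk above):

    def sketch_get_dataset_properties(dataset, dataset_requirements):
        # First copy every dataset attribute that a fit requirement of the
        # active pipeline components names ...
        dataset_properties = {}
        for requirement in dataset_requirements:
            dataset_properties[requirement.name] = getattr(dataset, requirement.name)
        # ... then overlay the dataset's self-reported required info, so keys
        # such as 'output_type' and 'issparse' are present even when no
        # component explicitly asked for them.
        dataset_properties.update(dataset.get_required_dataset_info())
        return dataset_properties
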
autoPyTorch/datasets/tabular_dataset.py

Lines changed: 18 additions & 1 deletion

@@ -112,7 +112,24 @@ def __init__(self,
 
     def get_required_dataset_info(self) -> Dict[str, Any]:
         """
-        Returns a dictionary containing required dataset properties to instantiate a pipeline,
+        Returns a dictionary containing required dataset
+        properties to instantiate a pipeline.
+        For a Tabular Dataset this includes:
+        1. 'output_type': Enum indicating the type of the output for this problem.
+           We currently use `sklearn type_of_target
+           <https://scikit-learn.org/stable/modules/generated/sklearn.utils.multiclass.type_of_target.html>`_
+           to infer the output type from the data and encode it as an
+           Enum, documented in `autoPyTorch/constants.py
+           <https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/constants.py>`_.
+        2. 'issparse': a flag indicating whether the input is a sparse matrix.
+        3. 'numerical_columns': a list of the column numbers of the
+           numerical columns in the input dataset.
+        4. 'categorical_columns': a list of the column numbers of the
+           categorical columns in the input dataset.
+        5. 'task_type': Enum indicating the type of task. For tabular datasets
+           we currently support 'tabular_classification' and 'tabular_regression',
+           again encoded as an Enum documented in `autoPyTorch/constants.py
+           <https://github.com/automl/Auto-PyTorch/blob/refactor_development/autoPyTorch/constants.py>`_.
         """
         info = super().get_required_dataset_info()
         info.update({

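To make the docstring concrete, here is a plausible return value for a dense binary-classification dataset with columns [age, income, city], where only 'city' is categorical. Whether the values arrive as raw strings or as the Enum codes from autoPyTorch/constants.py is an assumption on our part, since the diff truncates before the info.update() body:

    # Hypothetical output of TabularDataset.get_required_dataset_info():
    info = {
        'output_type': 'binary',        # inferred via sklearn's type_of_target
        'issparse': False,              # dense input matrix
        'numerical_columns': [0, 1],    # 'age', 'income'
        'categorical_columns': [2],     # 'city'
        'task_type': 'tabular_classification',
    }
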
autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 17 additions & 17 deletions

@@ -31,7 +31,6 @@
     TABULAR_TASKS,
 )
 from autoPyTorch.datasets.base_dataset import BaseDataset
-from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.evaluation.utils import (
     VotingRegressorWrapper,
     convert_multioutput_multiclass_to_multilabel
@@ -71,6 +70,7 @@ class MyTraditionalTabularClassificationPipeline(BaseEstimator):
         An optional dictionary that is passed to the pipeline's steps. It complies
         a similar function as the kwargs
     """
+
     def __init__(self, config: str,
                  dataset_properties: Dict[str, Any],
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
@@ -141,6 +141,7 @@ class DummyClassificationPipeline(DummyClassifier):
         An optional dictionary that is passed to the pipeline's steps. It complies
         a similar function as the kwargs
     """
+
     def __init__(self, config: Configuration,
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
                  init_params: Optional[Dict] = None
@@ -208,6 +209,7 @@ class DummyRegressionPipeline(DummyRegressor):
         An optional dictionary that is passed to the pipeline's steps. It complies
         a similar function as the kwargs
     """
+
     def __init__(self, config: Configuration,
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
                  init_params: Optional[Dict] = None) -> None:
@@ -394,12 +396,9 @@ def __init__(self, backend: Backend,
             raise ValueError('disable_file_output should be either a bool or a list')
 
         self.pipeline_class: Optional[Union[BaseEstimator, BasePipeline]] = None
-        info: Dict[str, Any] = {'task_type': self.datamanager.task_type,
-                                'output_type': self.datamanager.output_type,
-                                'issparse': self.issparse}
         if self.task_type in REGRESSION_TASKS:
             if isinstance(self.configuration, int):
-                self.pipeline_class = DummyClassificationPipeline
+                self.pipeline_class = DummyRegressionPipeline
             elif isinstance(self.configuration, str):
                 raise ValueError("Only tabular classifications tasks "
                                  "are currently supported with traditional methods")
@@ -425,11 +424,12 @@ def __init__(self, backend: Backend,
         else:
             raise ValueError('task {} not available'.format(self.task_type))
         self.predict_function = self._predict_proba
-        if self.task_type in TABULAR_TASKS:
-            assert isinstance(self.datamanager, TabularDataset)
-            info.update({'numerical_columns': self.datamanager.numerical_columns,
-                         'categorical_columns': self.datamanager.categorical_columns})
-        self.dataset_properties = self.datamanager.get_dataset_properties(get_dataset_requirements(info))
+        self.dataset_properties = self.datamanager.get_dataset_properties(
+            get_dataset_requirements(info=self.datamanager.get_required_dataset_info(),
+                                     include=self.include,
+                                     exclude=self.exclude,
+                                     search_space_updates=self.search_space_updates
+                                     ))
 
         self.additional_metrics: Optional[List[autoPyTorchMetric]] = None
         if all_supported_metrics:
@@ -630,9 +630,9 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
         return None
 
     def calculate_auxiliary_losses(
-        self,
-        Y_valid_pred: np.ndarray,
-        Y_test_pred: np.ndarray,
+            self,
+            Y_valid_pred: np.ndarray,
+            Y_test_pred: np.ndarray,
     ) -> Tuple[Optional[float], Optional[float]]:
         """
         A helper function to calculate the performance estimate of the
@@ -670,10 +670,10 @@ def calculate_auxiliary_losses(
         return validation_loss, test_loss
 
     def file_output(
-        self,
-        Y_optimization_pred: np.ndarray,
-        Y_valid_pred: np.ndarray,
-        Y_test_pred: np.ndarray
+            self,
+            Y_optimization_pred: np.ndarray,
+            Y_valid_pred: np.ndarray,
+            Y_test_pred: np.ndarray
     ) -> Tuple[Optional[float], Dict]:
         """
         This method decides what file outputs are written to disk.

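Two things happen in this file: the evaluator now derives its info dict from the dataset (the same pattern as in base_task.py), and a real bug is fixed, since regression tasks configured with an int previously received the classification dummy. A hedged sketch of the corrected dispatch, reusing only names from the hunk (the helper itself is hypothetical):

    from autoPyTorch.constants import REGRESSION_TASKS
    from autoPyTorch.evaluation.abstract_evaluator import (
        DummyClassificationPipeline,
        DummyRegressionPipeline,
    )

    def pick_dummy_pipeline(task_type, configuration):
        # An int configuration requests a dummy baseline. Before this commit
        # the regression branch wrongly returned DummyClassificationPipeline.
        if task_type in REGRESSION_TASKS:
            if isinstance(configuration, int):
                return DummyRegressionPipeline
            if isinstance(configuration, str):
                raise ValueError("Only tabular classifications tasks "
                                 "are currently supported with traditional methods")
        if isinstance(configuration, int):
            return DummyClassificationPipeline
        raise NotImplementedError("other branches omitted from this sketch")
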