|
12 | 12 | import unittest.mock |
13 | 13 | import warnings |
14 | 14 | from abc import abstractmethod |
15 | | -from typing import Any, Callable, Dict, List, Optional, Union, cast |
| 15 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast |
16 | 16 |
|
17 | 17 | from ConfigSpace.configuration_space import Configuration, ConfigurationSpace |
18 | 18 |
|
@@ -223,9 +223,7 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline: |
223 | 223 | """ |
224 | 224 | raise NotImplementedError |
225 | 225 |
|
226 | | - def set_pipeline_config( |
227 | | - self, |
228 | | - **pipeline_config_kwargs: Any) -> None: |
| 226 | + def set_pipeline_config(self, **pipeline_config_kwargs: Any) -> None: |
229 | 227 | """ |
230 | 228 | Check whether arguments are valid and |
231 | 229 | then sets them to the current pipeline |
@@ -259,12 +257,6 @@ def get_pipeline_options(self) -> dict: |
259 | 257 | """ |
260 | 258 | return self.pipeline_options |
261 | 259 |
|
262 | | - # def set_search_space(self, search_space: ConfigurationSpace) -> None: |
263 | | - # """ |
264 | | - # Update the search space. |
265 | | - # """ |
266 | | - # raise NotImplementedError |
267 | | - # |
268 | 260 | def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace: |
269 | 261 | """ |
270 | 262 | Returns the current search space as ConfigurationSpace object. |
@@ -406,9 +398,9 @@ def _close_dask_client(self) -> None: |
406 | 398 | None |
407 | 399 | """ |
408 | 400 | if ( |
409 | | - hasattr(self, '_is_dask_client_internally_created') |
410 | | - and self._is_dask_client_internally_created |
411 | | - and self._dask_client |
| 401 | + hasattr(self, '_is_dask_client_internally_created') |
| 402 | + and self._is_dask_client_internally_created |
| 403 | + and self._dask_client |
412 | 404 | ): |
413 | 405 | self._dask_client.shutdown() |
414 | 406 | self._dask_client.close() |
@@ -661,10 +653,11 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: |
661 | 653 | f"Fitting {cls} took {runtime}s, performance:{cost}/{additional_info}") |
662 | 654 | configuration = additional_info['pipeline_configuration'] |
663 | 655 | origin = additional_info['configuration_origin'] |
| 656 | + additional_info.pop('pipeline_configuration') |
664 | 657 | run_history.add(config=configuration, cost=cost, |
665 | 658 | time=runtime, status=status, seed=self.seed, |
666 | 659 | starttime=starttime, endtime=starttime + runtime, |
667 | | - origin=origin) |
| 660 | + origin=origin, additional_info=additional_info) |
668 | 661 | else: |
669 | 662 | if additional_info.get('exitcode') == -6: |
670 | 663 | self._logger.error( |
@@ -710,6 +703,7 @@ def _search( |
710 | 703 | memory_limit: Optional[int] = 4096, |
711 | 704 | smac_scenario_args: Optional[Dict[str, Any]] = None, |
712 | 705 | get_smac_object_callback: Optional[Callable] = None, |
| 706 | + tae_func: Optional[Callable] = None, |
713 | 707 | all_supported_metrics: bool = True, |
714 | 708 | precision: int = 32, |
715 | 709 | disable_file_output: List = [], |
@@ -777,6 +771,10 @@ def _search( |
777 | 771 | instances, num_params, runhistory, seed and ta. This is |
778 | 772 | an advanced feature. Use only if you are familiar with |
779 | 773 | [SMAC](https://automl.github.io/SMAC3/master/index.html). |
| 774 | + tae_func (Optional[Callable]): |
| 775 | + TargetAlgorithm to be optimised. If None, `eval_function` |
| 776 | + available in autoPyTorch/evaluation/train_evaluator is used. |
| 777 | + Must be child class of AbstractEvaluator. |
780 | 778 | all_supported_metrics (bool), (default=True): if True, all |
781 | 779 | metrics supporting current task will be calculated |
782 | 780 | for each pipeline and results will be available via cv_results |
@@ -988,7 +986,7 @@ def _search( |
988 | 986 | ) |
989 | 987 | try: |
990 | 988 | run_history, self.trajectory, budget_type = \ |
991 | | - _proc_smac.run_smbo() |
| 989 | + _proc_smac.run_smbo(func=tae_func) |
992 | 990 | self.run_history.update(run_history, DataOrigin.INTERNAL) |
993 | 991 | trajectory_filename = os.path.join( |
994 | 992 | self._backend.get_smac_output_directory_for_run(self.seed), |
@@ -1042,10 +1040,10 @@ def _search( |
1042 | 1040 | return self |
1043 | 1041 |
|
1044 | 1042 | def refit( |
1045 | | - self, |
1046 | | - dataset: BaseDataset, |
1047 | | - budget_config: Dict[str, Union[int, str]] = {}, |
1048 | | - split_id: int = 0 |
| 1043 | + self, |
| 1044 | + dataset: BaseDataset, |
| 1045 | + budget_config: Dict[str, Union[int, str]] = {}, |
| 1046 | + split_id: int = 0 |
1049 | 1047 | ) -> "BaseTask": |
1050 | 1048 | """ |
1051 | 1049 | Refit all models found with fit to new data. |
@@ -1181,10 +1179,10 @@ def fit(self, |
1181 | 1179 | return pipeline |
1182 | 1180 |
|
1183 | 1181 | def predict( |
1184 | | - self, |
1185 | | - X_test: np.ndarray, |
1186 | | - batch_size: Optional[int] = None, |
1187 | | - n_jobs: int = 1 |
| 1182 | + self, |
| 1183 | + X_test: np.ndarray, |
| 1184 | + batch_size: Optional[int] = None, |
| 1185 | + n_jobs: int = 1 |
1188 | 1186 | ) -> np.ndarray: |
1189 | 1187 | """Generate the estimator predictions. |
1190 | 1188 | Generate the predictions based on the given examples from the test set. |
@@ -1234,9 +1232,9 @@ def predict( |
1234 | 1232 | return predictions |
1235 | 1233 |
|
1236 | 1234 | def score( |
1237 | | - self, |
1238 | | - y_pred: np.ndarray, |
1239 | | - y_test: Union[np.ndarray, pd.DataFrame] |
| 1235 | + self, |
| 1236 | + y_pred: np.ndarray, |
| 1237 | + y_test: Union[np.ndarray, pd.DataFrame] |
1240 | 1238 | ) -> Dict[str, float]: |
1241 | 1239 | """Calculate the score on the test set. |
1242 | 1240 | Calculate the evaluation measure on the test set. |
@@ -1277,17 +1275,37 @@ def __del__(self) -> None: |
1277 | 1275 | if hasattr(self, '_backend'): |
1278 | 1276 | self._backend.context.delete_directories(force=False) |
1279 | 1277 |
|
1280 | | - @typing.no_type_check |
1281 | 1278 | def get_incumbent_results( |
1282 | | - self |
1283 | | - ): |
1284 | | - pass |
| 1279 | + self, |
| 1280 | + include_traditional: bool = False |
| 1281 | + ) -> Tuple[Configuration, Dict[str, Union[int, str, float]]]: |
| 1282 | + """ |
| 1283 | + Get Incumbent config and the corresponding results |
| 1284 | + Args: |
| 1285 | + include_traditional: Whether to include results from traditional pipelines |
1285 | 1286 |
|
1286 | | - @typing.no_type_check |
1287 | | - def get_incumbent_config( |
1288 | | - self |
1289 | | - ): |
1290 | | - pass |
| 1287 | + Returns: |
| 1288 | +
|
| 1289 | + """ |
| 1290 | + assert self.run_history is not None, "No Run History found, search has not been called." |
| 1291 | + if self.run_history.empty(): |
| 1292 | + raise ValueError("Run History is empty. Something went wrong, " |
| 1293 | + "smac was not able to fit any model?") |
| 1294 | + |
| 1295 | + run_history_data = self.run_history.data |
| 1296 | + if not include_traditional: |
| 1297 | + # traditional pipelines are tagged with configuration_origin == 'traditional' in their additional info |
| 1298 | + run_history_data = dict( |
| 1299 | + filter(lambda elem: elem[1].additional_info is not None and elem[1]. |
| 1300 | + additional_info['configuration_origin'] != 'traditional', |
| 1301 | + run_history_data.items())) |
| 1302 | + run_history_data = dict( |
| 1303 | + filter(lambda elem: 'SUCCESS' in str(elem[1].status), run_history_data.items())) |
| 1304 | + sorted_runvalue_by_cost = sorted(run_history_data.items(), key=lambda item: item[1].cost) |
| 1305 | + incumbent_run_key, incumbent_run_value = sorted_runvalue_by_cost[0] |
| 1306 | + incumbent_config = self.run_history.ids_config[incumbent_run_key.config_id] |
| 1307 | + incumbent_results = incumbent_run_value.additional_info |
| 1308 | + return incumbent_config, incumbent_results |
1291 | 1309 |
|
1292 | 1310 | def get_models_with_weights(self) -> List: |
1293 | 1311 | if self.models_ is None or len(self.models_) == 0 or \ |
|
0 commit comments