2929from smac .stats .stats import Stats
3030from smac .tae import StatusType
3131
32+ from autoPyTorch .api .results_manager import ResultsManager , SearchResults
3233from autoPyTorch .automl_common .common .utils .backend import Backend , create
3334from autoPyTorch .constants import (
3435 REGRESSION_TASKS ,
@@ -192,12 +193,13 @@ def __init__(
192193 self .search_space : Optional [ConfigurationSpace ] = None
193194 self ._dataset_requirements : Optional [List [FitRequirement ]] = None
194195 self ._metric : Optional [autoPyTorchMetric ] = None
196+ self ._scoring_functions : Optional [List [autoPyTorchMetric ]] = None
195197 self ._logger : Optional [PicklableClientLogger ] = None
196- self .run_history : RunHistory = RunHistory ()
197- self .trajectory : Optional [List ] = None
198198 self .dataset_name : Optional [str ] = None
199199 self .cv_models_ : Dict = {}
200200
201+ self ._results_manager = ResultsManager ()
202+
201203 # By default try to use the TCP logging port or get a new port
202204 self ._logger_port = logging .handlers .DEFAULT_TCP_LOGGING_PORT
203205
@@ -240,6 +242,18 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline:
240242 """
241243 raise NotImplementedError
242244
245+ @property
246+ def run_history (self ) -> RunHistory :
247+ return self ._results_manager .run_history
248+
249+ @property
250+ def ensemble_performance_history (self ) -> List [Dict [str , Any ]]:
251+ return self ._results_manager .ensemble_performance_history
252+
253+ @property
254+ def trajectory (self ) -> Optional [List ]:
255+ return self ._results_manager .trajectory
256+
243257 def set_pipeline_config (self , ** pipeline_config_kwargs : Any ) -> None :
244258 """
245259 Check whether arguments are valid and
@@ -883,6 +897,12 @@ def _search(
883897
884898 self .pipeline_options ['optimize_metric' ] = optimize_metric
885899
900+ if all_supported_metrics :
901+ self ._scoring_functions = get_metrics (dataset_properties = dataset_properties ,
902+ all_supported_metrics = True )
903+ else :
904+ self ._scoring_functions = [self ._metric ]
905+
886906 self .search_space = self .get_search_space (dataset )
887907
888908 # Incorporate budget to pipeline config
@@ -1037,12 +1057,14 @@ def _search(
10371057 pynisher_context = self ._multiprocessing_context ,
10381058 )
10391059 try :
1040- run_history , self .trajectory , budget_type = \
1060+ run_history , self ._results_manager . trajectory , budget_type = \
10411061 _proc_smac .run_smbo (func = tae_func )
10421062 self .run_history .update (run_history , DataOrigin .INTERNAL )
10431063 trajectory_filename = os .path .join (
10441064 self ._backend .get_smac_output_directory_for_run (self .seed ),
10451065 'trajectory.json' )
1066+
1067+ assert self .trajectory is not None # mypy check
10461068 saveable_trajectory = \
10471069 [list (entry [:2 ]) + [entry [2 ].get_dictionary ()] + list (entry [3 :])
10481070 for entry in self .trajectory ]
@@ -1059,7 +1081,7 @@ def _search(
10591081 self ._logger .info ("Starting Shutdown" )
10601082
10611083 if proc_ensemble is not None :
1062- self .ensemble_performance_history = list (proc_ensemble .history )
1084+ self ._results_manager . ensemble_performance_history = list (proc_ensemble .history )
10631085
10641086 if len (proc_ensemble .futures ) > 0 :
10651087 # Also add ensemble runs that did not finish within smac time
@@ -1068,7 +1090,7 @@ def _search(
10681090 result = proc_ensemble .futures .pop ().result ()
10691091 if result :
10701092 ensemble_history , _ , _ , _ = result
1071- self .ensemble_performance_history .extend (ensemble_history )
1093+ self ._results_manager . ensemble_performance_history .extend (ensemble_history )
10721094 self ._logger .info ("Ensemble script finished, continue shutdown." )
10731095
10741096 # save the ensemble performance history file
@@ -1356,28 +1378,12 @@ def get_incumbent_results(
13561378 The incumbent configuration
13571379 Dict[str, Union[int, str, float]]:
13581380 Additional information about the run of the incumbent configuration.
1359-
13601381 """
1361- assert self .run_history is not None , "No Run History found, search has not been called."
1362- if self .run_history .empty ():
1363- raise ValueError ("Run History is empty. Something went wrong, "
1364- "smac was not able to fit any model?" )
1365-
1366- run_history_data = self .run_history .data
1367- if not include_traditional :
1368- # traditional classifiers have trainer_configuration in their additional info
1369- run_history_data = dict (
1370- filter (lambda elem : elem [1 ].status == StatusType .SUCCESS and elem [1 ].
1371- additional_info is not None and elem [1 ].
1372- additional_info ['configuration_origin' ] != 'traditional' ,
1373- run_history_data .items ()))
1374- run_history_data = dict (
1375- filter (lambda elem : 'SUCCESS' in str (elem [1 ].status ), run_history_data .items ()))
1376- sorted_runvalue_by_cost = sorted (run_history_data .items (), key = lambda item : item [1 ].cost )
1377- incumbent_run_key , incumbent_run_value = sorted_runvalue_by_cost [0 ]
1378- incumbent_config = self .run_history .ids_config [incumbent_run_key .config_id ]
1379- incumbent_results = incumbent_run_value .additional_info
1380- return incumbent_config , incumbent_results
1382+
1383+ if self ._metric is None :
1384+ raise RuntimeError ("`search_results` is only available after a search has finished." )
1385+
1386+ return self ._results_manager .get_incumbent_results (metric = self ._metric , include_traditional = include_traditional )
13811387
13821388 def get_models_with_weights (self ) -> List :
13831389 if self .models_ is None or len (self .models_ ) == 0 or \
@@ -1417,3 +1423,43 @@ def _print_debug_info_to_log(self) -> None:
14171423 self ._logger .debug (' multiprocessing_context: %s' , str (self ._multiprocessing_context ))
14181424 for key , value in vars (self ).items ():
14191425 self ._logger .debug (f"\t { key } ->{ value } " )
1426+
1427+ def get_search_results (self ) -> SearchResults :
1428+ """
1429+ Get the interface to obtain the search results easily.
1430+ """
1431+ if self ._scoring_functions is None or self ._metric is None :
1432+ raise RuntimeError ("`search_results` is only available after a search has finished." )
1433+
1434+ return self ._results_manager .get_search_results (
1435+ metric = self ._metric ,
1436+ scoring_functions = self ._scoring_functions
1437+ )
1438+
1439+ def sprint_statistics (self ) -> str :
1440+ """
1441+ Returns a formatted string with statistics about the SMAC search.
1442+
1443+ These statistics include:
1444+
1445+ 1. Optimisation Metric
1446+ 2. Best Optimisation score achieved by individual pipelines
1447+ 3. Total number of target algorithm runs
1448+ 4. Total number of successful target algorithm runs
1449+ 5. Total number of crashed target algorithm runs
1450+ 6. Total number of target algorithm runs that exceeded the time limit
1451+ 7. Total number of target algorithm runs that exceeded the memory limit
1452+
1453+ Returns:
1454+ (str):
1455+ Formatted string with statistics
1456+ """
1457+ if self ._scoring_functions is None or self ._metric is None :
1458+ raise RuntimeError ("`search_results` is only available after a search has finished." )
1459+
1460+ assert self .dataset_name is not None # mypy check
1461+ return self ._results_manager .sprint_statistics (
1462+ dataset_name = self .dataset_name ,
1463+ scoring_functions = self ._scoring_functions ,
1464+ metric = self ._metric
1465+ )
0 commit comments