2121import os
2222import platform
2323import sys
24- import tempfile
2524import time
2625import types
2726import uuid
3736import sklearn .utils
3837from ConfigSpace .configuration_space import Configuration , ConfigurationSpace
3938from ConfigSpace .read_and_write import json as cs_json
40- from dask .distributed import Client , LocalCluster
39+ from dask .distributed import Client
4140from scipy .sparse import spmatrix
4241from sklearn .base import BaseEstimator
4342from sklearn .dummy import DummyClassifier , DummyRegressor
105104from autosklearn .pipeline .components .regression import RegressorChoice
106105from autosklearn .smbo import AutoMLSMBO
107106from autosklearn .util import RE_PATTERN , pipeline
107+ from autosklearn .util .dask import Dask , LocalDask , UserDask
108108from autosklearn .util .data import (
109109 DatasetCompressionSpec ,
110110 default_dataset_compression_arg ,
120120 warnings_to ,
121121)
122122from autosklearn .util .parallel import preload_modules
123- from autosklearn .util .single_thread_client import SingleThreadedClient
124123from autosklearn .util .smac_wrap import SMACCallback , SmacRunCallback
125124from autosklearn .util .stopwatch import StopWatch
126125
@@ -299,21 +298,22 @@ def __init__(
299298 self ._initial_configurations_via_metalearning = (
300299 initial_configurations_via_metalearning
301300 )
301+ self ._n_jobs = n_jobs
302302
303303 self ._scoring_functions = scoring_functions or []
304304 self ._resampling_strategy_arguments = resampling_strategy_arguments or {}
305+ self ._multiprocessing_context = "forkserver"
305306
306307 # Single core, local runs should use fork to prevent the __main__ requirements
307308 # in examples. Nevertheless, multi-process runs have spawn as requirement to
308309 # reduce the possibility of a deadlock
309- if n_jobs == 1 and dask_client is None :
310- self ._multiprocessing_context = "fork"
311- self ._dask_client = SingleThreadedClient ()
312- self ._n_jobs = 1
310+ self ._dask : Dask
311+ if dask_client is not None :
312+ self ._dask = UserDask (client = dask_client )
313313 else :
314- self ._multiprocessing_context = "forkserver"
315- self . _dask_client = dask_client
316- self ._n_jobs = n_jobs
314+ self ._dask = LocalDask ( n_jobs = n_jobs )
315+ if n_jobs == 1 :
316+ self ._multiprocessing_context = "fork"
317317
318318 # Create the backend
319319 self ._backend : Backend = create (
@@ -350,38 +350,6 @@ def __init__(
350350 self .num_run = 0
351351 self .fitted = False
352352
def _create_dask_client(self) -> None:
    """Create an internally-managed local dask cluster and client.

    Marks the client as internally created so that ``_close_dask_client``
    knows it is responsible for tearing it down.
    """
    self._is_dask_client_internally_created = True
    cluster = LocalCluster(
        n_workers=self._n_jobs,
        processes=False,
        threads_per_worker=1,
        # Keep worker files in the temp directory: deleting workers takes
        # longer than deleting backend directories, and removing the backend
        # first would make the client fail to close the workers cleanly.
        local_directory=tempfile.gettempdir(),
        # Memory is handled by the pynisher, not by the dask worker/nanny.
        memory_limit=0,
    )
    # Heartbeat every 10s.
    self._dask_client = Client(cluster, heartbeat_interval=10000)
373-
374- def _close_dask_client (self , force : bool = False ) -> None :
375- if getattr (self , "_dask_client" , None ) is not None and (
376- force or getattr (self , "_is_dask_client_internally_created" , False )
377- ):
378- self ._dask_client .shutdown ()
379- self ._dask_client .close ()
380- del self ._dask_client
381- self ._dask_client = None
382- self ._is_dask_client_internally_created = False
383- del self ._is_dask_client_internally_created
384-
385353 def _get_logger (self , name : str ) -> PicklableClientLogger :
386354 logger_name = "AutoML(%d):%s" % (self ._seed , name )
387355
@@ -747,17 +715,6 @@ def fit(
747715 "autosklearn.metrics.Scorer."
748716 )
749717
750- # If no dask client was provided, we create one, so that we can
751- # start a ensemble process in parallel to smbo optimize
752- if self ._dask_client is None and (
753- self ._ensemble_class is not None
754- or self ._n_jobs is not None
755- and self ._n_jobs > 1
756- ):
757- self ._create_dask_client ()
758- else :
759- self ._is_dask_client_internally_created = False
760-
761718 self ._dataset_name = dataset_name
762719 self ._stopwatch .start (self ._dataset_name )
763720
@@ -902,70 +859,85 @@ def fit(
902859 )
903860
904861 n_meta_configs = self ._initial_configurations_via_metalearning
905- _proc_smac = AutoMLSMBO (
906- config_space = self .configuration_space ,
907- dataset_name = self ._dataset_name ,
908- backend = self ._backend ,
909- total_walltime_limit = time_left ,
910- func_eval_time_limit = per_run_time_limit ,
911- memory_limit = self ._memory_limit ,
912- data_memory_limit = self ._data_memory_limit ,
913- stopwatch = self ._stopwatch ,
914- n_jobs = self ._n_jobs ,
915- dask_client = self ._dask_client ,
916- start_num_run = self .num_run ,
917- num_metalearning_cfgs = n_meta_configs ,
918- config_file = configspace_path ,
919- seed = self ._seed ,
920- metadata_directory = self ._metadata_directory ,
921- metrics = self ._metrics ,
922- resampling_strategy = self ._resampling_strategy ,
923- resampling_strategy_args = self ._resampling_strategy_arguments ,
924- include = self ._include ,
925- exclude = self ._exclude ,
926- disable_file_output = self ._disable_evaluator_output ,
927- get_smac_object_callback = self ._get_smac_object_callback ,
928- smac_scenario_args = self ._smac_scenario_args ,
929- scoring_functions = self ._scoring_functions ,
930- port = self ._logger_port ,
931- pynisher_context = self ._multiprocessing_context ,
932- ensemble_callback = proc_ensemble ,
933- trials_callback = self ._get_trials_callback ,
934- )
862+ with self ._dask as dask_client :
863+ resamp_args = self ._resampling_strategy_arguments
864+ _proc_smac = AutoMLSMBO (
865+ config_space = self .configuration_space ,
866+ dataset_name = self ._dataset_name ,
867+ backend = self ._backend ,
868+ total_walltime_limit = time_left ,
869+ func_eval_time_limit = per_run_time_limit ,
870+ memory_limit = self ._memory_limit ,
871+ data_memory_limit = self ._data_memory_limit ,
872+ stopwatch = self ._stopwatch ,
873+ n_jobs = self ._n_jobs ,
874+ dask_client = dask_client ,
875+ start_num_run = self .num_run ,
876+ num_metalearning_cfgs = n_meta_configs ,
877+ config_file = configspace_path ,
878+ seed = self ._seed ,
879+ metadata_directory = self ._metadata_directory ,
880+ metrics = self ._metrics ,
881+ resampling_strategy = self ._resampling_strategy ,
882+ resampling_strategy_args = resamp_args ,
883+ include = self ._include ,
884+ exclude = self ._exclude ,
885+ disable_file_output = self ._disable_evaluator_output ,
886+ get_smac_object_callback = self ._get_smac_object_callback ,
887+ smac_scenario_args = self ._smac_scenario_args ,
888+ scoring_functions = self ._scoring_functions ,
889+ port = self ._logger_port ,
890+ pynisher_context = self ._multiprocessing_context ,
891+ ensemble_callback = proc_ensemble ,
892+ trials_callback = self ._get_trials_callback ,
893+ )
935894
936- (
937- self .runhistory_ ,
938- self .trajectory_ ,
939- self ._budget_type ,
940- ) = _proc_smac .run_smbo ()
941- trajectory_filename = os .path .join (
942- self ._backend .get_smac_output_directory_for_run (self ._seed ),
943- "trajectory.json" ,
944- )
945- saveable_trajectory = [
946- list (entry [:2 ]) + [entry [2 ].get_dictionary ()] + list (entry [3 :])
947- for entry in self .trajectory_
948- ]
949- with open (trajectory_filename , "w" ) as fh :
950- json .dump (saveable_trajectory , fh )
951-
952- self ._logger .info ("Starting shutdown..." )
953- # Wait until the ensemble process is finished to avoid shutting down
954- # while the ensemble builder tries to access the data
955- if proc_ensemble is not None :
956- self .ensemble_performance_history = list (proc_ensemble .history )
957-
958- if len (proc_ensemble .futures ) > 0 :
959- # Now we wait for the future to return as it cannot be cancelled
960- # while it is running: https://stackoverflow.com/a/49203129
961- self ._logger .info (
962- "Ensemble script still running, waiting for it to finish."
963- )
964- result = proc_ensemble .futures .pop ().result ()
965- if result :
966- ensemble_history , _ = result
967- self .ensemble_performance_history .extend (ensemble_history )
968- self ._logger .info ("Ensemble script finished, continue shutdown." )
895+ (
896+ self .runhistory_ ,
897+ self .trajectory_ ,
898+ self ._budget_type ,
899+ ) = _proc_smac .run_smbo ()
900+
901+ trajectory_filename = os .path .join (
902+ self ._backend .get_smac_output_directory_for_run (self ._seed ),
903+ "trajectory.json" ,
904+ )
905+ saveable_trajectory = [
906+ list (entry [:2 ])
907+ + [entry [2 ].get_dictionary ()]
908+ + list (entry [3 :])
909+ for entry in self .trajectory_
910+ ]
911+ with open (trajectory_filename , "w" ) as fh :
912+ json .dump (saveable_trajectory , fh )
913+
914+ self ._logger .info ("Starting shutdown..." )
915+ # Wait until the ensemble process is finished to avoid shutting
916+ # down while the ensemble builder tries to access the data
917+ if proc_ensemble is not None :
918+ self .ensemble_performance_history = list (
919+ proc_ensemble .history
920+ )
921+
922+ if len (proc_ensemble .futures ) > 0 :
923+ # Now we wait for the future to return as it cannot be
924+ # cancelled while it is running
925+ # * https://stackoverflow.com/a/49203129
926+ self ._logger .info (
927+ "Ensemble script still running,"
928+ " waiting for it to finish."
929+ )
930+ result = proc_ensemble .futures .pop ().result ()
931+
932+ if result :
933+ ensemble_history , _ = result
934+ self .ensemble_performance_history .extend (
935+ ensemble_history
936+ )
937+
938+ self ._logger .info (
939+ "Ensemble script finished, continue shutdown."
940+ )
969941
970942 # save the ensemble performance history file
971943 if len (self .ensemble_performance_history ) > 0 :
@@ -1054,7 +1026,7 @@ def _log_fit_setup(self) -> None:
10541026 self ._logger .debug (
10551027 " multiprocessing_context: %s" , str (self ._multiprocessing_context )
10561028 )
1057- self ._logger .debug (" dask_client: %s" , str (self ._dask_client ))
1029+ self ._logger .debug (" dask_client: %s" , str (self ._dask ))
10581030 self ._logger .debug (" precision: %s" , str (self .precision ))
10591031 self ._logger .debug (
10601032 " disable_evaluator_output: %s" , str (self ._disable_evaluator_output )
@@ -1090,7 +1062,6 @@ def __sklearn_is_fitted__(self) -> bool:
10901062
10911063 def _fit_cleanup (self ) -> None :
10921064 self ._logger .info ("Closing the dask infrastructure" )
1093- self ._close_dask_client ()
10941065 self ._logger .info ("Finished closing the dask infrastructure" )
10951066
10961067 # Clean up the logger
@@ -1555,48 +1526,48 @@ def fit_ensemble(
15551526 # Make sure that input is valid
15561527 y = self .InputValidator .target_validator .transform (y )
15571528
1558- # Create a client if needed
1559- if self ._dask_client is None :
1560- self ._create_dask_client ()
1561- else :
1562- self ._is_dask_client_internally_created = False
1563-
15641529 metrics = metrics if metrics is not None else self ._metrics
15651530 if not isinstance (metrics , Sequence ):
15661531 metrics = [metrics ]
15671532
15681533 # Use the current thread to start the ensemble builder process
15691534 # The function ensemble_builder_process will internally create a ensemble
15701535 # builder in the provide dask client
1571- manager = EnsembleBuilderManager (
1572- start_time = time .time (),
1573- time_left_for_ensembles = self ._time_for_task ,
1574- backend = copy .deepcopy (self ._backend ),
1575- dataset_name = dataset_name if dataset_name else self ._dataset_name ,
1576- task = task if task else self ._task ,
1577- metrics = metrics if metrics is not None else self ._metrics ,
1578- ensemble_class = (
1579- ensemble_class if ensemble_class is not None else self ._ensemble_class
1580- ),
1581- ensemble_kwargs = (
1582- ensemble_kwargs
1583- if ensemble_kwargs is not None
1584- else self ._ensemble_kwargs
1585- ),
1586- ensemble_nbest = ensemble_nbest if ensemble_nbest else self ._ensemble_nbest ,
1587- max_models_on_disc = self ._max_models_on_disc ,
1588- seed = self ._seed ,
1589- precision = precision if precision else self .precision ,
1590- max_iterations = 1 ,
1591- read_at_most = None ,
1592- memory_limit = self ._memory_limit ,
1593- random_state = self ._seed ,
1594- logger_port = self ._logger_port ,
1595- pynisher_context = self ._multiprocessing_context ,
1596- )
1597- manager .build_ensemble (self ._dask_client )
1598- future = manager .futures .pop ()
1599- result = future .result ()
1536+ with self ._dask as dask_client :
1537+ manager = EnsembleBuilderManager (
1538+ start_time = time .time (),
1539+ time_left_for_ensembles = self ._time_for_task ,
1540+ backend = copy .deepcopy (self ._backend ),
1541+ dataset_name = dataset_name if dataset_name else self ._dataset_name ,
1542+ task = task if task else self ._task ,
1543+ metrics = metrics if metrics is not None else self ._metrics ,
1544+ ensemble_class = (
1545+ ensemble_class
1546+ if ensemble_class is not None
1547+ else self ._ensemble_class
1548+ ),
1549+ ensemble_kwargs = (
1550+ ensemble_kwargs
1551+ if ensemble_kwargs is not None
1552+ else self ._ensemble_kwargs
1553+ ),
1554+ ensemble_nbest = ensemble_nbest
1555+ if ensemble_nbest
1556+ else self ._ensemble_nbest ,
1557+ max_models_on_disc = self ._max_models_on_disc ,
1558+ seed = self ._seed ,
1559+ precision = precision if precision else self .precision ,
1560+ max_iterations = 1 ,
1561+ read_at_most = None ,
1562+ memory_limit = self ._memory_limit ,
1563+ random_state = self ._seed ,
1564+ logger_port = self ._logger_port ,
1565+ pynisher_context = self ._multiprocessing_context ,
1566+ )
1567+ manager .build_ensemble (dask_client )
1568+ future = manager .futures .pop ()
1569+ result = future .result ()
1570+
16001571 if result is None :
16011572 raise ValueError (
16021573 "Error building the ensemble - please check the log file and command "
@@ -1606,7 +1577,6 @@ def fit_ensemble(
16061577 self ._ensemble_class = ensemble_class
16071578
16081579 self ._load_models ()
1609- self ._close_dask_client ()
16101580 return self
16111581
16121582 def _load_models (self ):
@@ -2295,7 +2265,7 @@ def _create_search_space(
22952265
22962266 def __getstate__ (self ) -> dict [str , Any ]:
22972267 # Cannot serialize a client!
2298- self ._dask_client = None
2268+ self ._dask = None
22992269 self .logging_server = None
23002270 self .stop_logging_server = None
23012271 return self .__dict__
def __del__(self) -> None:
    """Finalizer: tear down the logging infrastructure."""
    # NOTE(review): reconstructed from a diff hunk header — signature taken
    # from the @@ context line; confirm against the full file.
    self._clean_logger()
23092277
23102278class AutoMLClassifier (AutoML ):
23112279
# (page-scrape residue from the commit view: "0 commit comments" — not part of the source file)