Skip to content

Commit a66361b

Browse files
committed
fix-1532-_ERROR_-asyncio.exceptions.CancelledError (#1540)
* Create PR
* Abstract out dask client types
* Fix _ issue
* Extend scope of dask_client in automl.py
* Add docstring to dask module
* Indent result addition
* Add basic tests for Dask wrappers
1 parent a2e63c8 commit a66361b

File tree

5 files changed

+347
-160
lines changed

5 files changed

+347
-160
lines changed

autosklearn/automl.py

Lines changed: 125 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import os
2222
import platform
2323
import sys
24-
import tempfile
2524
import time
2625
import types
2726
import uuid
@@ -37,7 +36,7 @@
3736
import sklearn.utils
3837
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
3938
from ConfigSpace.read_and_write import json as cs_json
40-
from dask.distributed import Client, LocalCluster
39+
from dask.distributed import Client
4140
from scipy.sparse import spmatrix
4241
from sklearn.base import BaseEstimator
4342
from sklearn.dummy import DummyClassifier, DummyRegressor
@@ -105,6 +104,7 @@
105104
from autosklearn.pipeline.components.regression import RegressorChoice
106105
from autosklearn.smbo import AutoMLSMBO
107106
from autosklearn.util import RE_PATTERN, pipeline
107+
from autosklearn.util.dask import Dask, LocalDask, UserDask
108108
from autosklearn.util.data import (
109109
DatasetCompressionSpec,
110110
default_dataset_compression_arg,
@@ -120,7 +120,6 @@
120120
warnings_to,
121121
)
122122
from autosklearn.util.parallel import preload_modules
123-
from autosklearn.util.single_thread_client import SingleThreadedClient
124123
from autosklearn.util.smac_wrap import SMACCallback, SmacRunCallback
125124
from autosklearn.util.stopwatch import StopWatch
126125

@@ -299,21 +298,22 @@ def __init__(
299298
self._initial_configurations_via_metalearning = (
300299
initial_configurations_via_metalearning
301300
)
301+
self._n_jobs = n_jobs
302302

303303
self._scoring_functions = scoring_functions or []
304304
self._resampling_strategy_arguments = resampling_strategy_arguments or {}
305+
self._multiprocessing_context = "forkserver"
305306

306307
# Single core, local runs should use fork to prevent the __main__ requirements
307308
# in examples. Nevertheless, multi-process runs have spawn as requirement to
308309
# reduce the possibility of a deadlock
309-
if n_jobs == 1 and dask_client is None:
310-
self._multiprocessing_context = "fork"
311-
self._dask_client = SingleThreadedClient()
312-
self._n_jobs = 1
310+
self._dask: Dask
311+
if dask_client is not None:
312+
self._dask = UserDask(client=dask_client)
313313
else:
314-
self._multiprocessing_context = "forkserver"
315-
self._dask_client = dask_client
316-
self._n_jobs = n_jobs
314+
self._dask = LocalDask(n_jobs=n_jobs)
315+
if n_jobs == 1:
316+
self._multiprocessing_context = "fork"
317317

318318
# Create the backend
319319
self._backend: Backend = create(
@@ -350,38 +350,6 @@ def __init__(
350350
self.num_run = 0
351351
self.fitted = False
352352

353-
def _create_dask_client(self) -> None:
354-
self._is_dask_client_internally_created = True
355-
self._dask_client = Client(
356-
LocalCluster(
357-
n_workers=self._n_jobs,
358-
processes=False,
359-
threads_per_worker=1,
360-
# We use the temporal directory to save the
361-
# dask workers, because deleting workers takes
362-
# more time than deleting backend directories
363-
# This prevent an error saying that the worker
364-
# file was deleted, so the client could not close
365-
# the worker properly
366-
local_directory=tempfile.gettempdir(),
367-
# Memory is handled by the pynisher, not by the dask worker/nanny
368-
memory_limit=0,
369-
),
370-
# Heartbeat every 10s
371-
heartbeat_interval=10000,
372-
)
373-
374-
def _close_dask_client(self, force: bool = False) -> None:
375-
if getattr(self, "_dask_client", None) is not None and (
376-
force or getattr(self, "_is_dask_client_internally_created", False)
377-
):
378-
self._dask_client.shutdown()
379-
self._dask_client.close()
380-
del self._dask_client
381-
self._dask_client = None
382-
self._is_dask_client_internally_created = False
383-
del self._is_dask_client_internally_created
384-
385353
def _get_logger(self, name: str) -> PicklableClientLogger:
386354
logger_name = "AutoML(%d):%s" % (self._seed, name)
387355

@@ -747,17 +715,6 @@ def fit(
747715
"autosklearn.metrics.Scorer."
748716
)
749717

750-
# If no dask client was provided, we create one, so that we can
751-
# start a ensemble process in parallel to smbo optimize
752-
if self._dask_client is None and (
753-
self._ensemble_class is not None
754-
or self._n_jobs is not None
755-
and self._n_jobs > 1
756-
):
757-
self._create_dask_client()
758-
else:
759-
self._is_dask_client_internally_created = False
760-
761718
self._dataset_name = dataset_name
762719
self._stopwatch.start(self._dataset_name)
763720

@@ -902,70 +859,85 @@ def fit(
902859
)
903860

904861
n_meta_configs = self._initial_configurations_via_metalearning
905-
_proc_smac = AutoMLSMBO(
906-
config_space=self.configuration_space,
907-
dataset_name=self._dataset_name,
908-
backend=self._backend,
909-
total_walltime_limit=time_left,
910-
func_eval_time_limit=per_run_time_limit,
911-
memory_limit=self._memory_limit,
912-
data_memory_limit=self._data_memory_limit,
913-
stopwatch=self._stopwatch,
914-
n_jobs=self._n_jobs,
915-
dask_client=self._dask_client,
916-
start_num_run=self.num_run,
917-
num_metalearning_cfgs=n_meta_configs,
918-
config_file=configspace_path,
919-
seed=self._seed,
920-
metadata_directory=self._metadata_directory,
921-
metrics=self._metrics,
922-
resampling_strategy=self._resampling_strategy,
923-
resampling_strategy_args=self._resampling_strategy_arguments,
924-
include=self._include,
925-
exclude=self._exclude,
926-
disable_file_output=self._disable_evaluator_output,
927-
get_smac_object_callback=self._get_smac_object_callback,
928-
smac_scenario_args=self._smac_scenario_args,
929-
scoring_functions=self._scoring_functions,
930-
port=self._logger_port,
931-
pynisher_context=self._multiprocessing_context,
932-
ensemble_callback=proc_ensemble,
933-
trials_callback=self._get_trials_callback,
934-
)
862+
with self._dask as dask_client:
863+
resamp_args = self._resampling_strategy_arguments
864+
_proc_smac = AutoMLSMBO(
865+
config_space=self.configuration_space,
866+
dataset_name=self._dataset_name,
867+
backend=self._backend,
868+
total_walltime_limit=time_left,
869+
func_eval_time_limit=per_run_time_limit,
870+
memory_limit=self._memory_limit,
871+
data_memory_limit=self._data_memory_limit,
872+
stopwatch=self._stopwatch,
873+
n_jobs=self._n_jobs,
874+
dask_client=dask_client,
875+
start_num_run=self.num_run,
876+
num_metalearning_cfgs=n_meta_configs,
877+
config_file=configspace_path,
878+
seed=self._seed,
879+
metadata_directory=self._metadata_directory,
880+
metrics=self._metrics,
881+
resampling_strategy=self._resampling_strategy,
882+
resampling_strategy_args=resamp_args,
883+
include=self._include,
884+
exclude=self._exclude,
885+
disable_file_output=self._disable_evaluator_output,
886+
get_smac_object_callback=self._get_smac_object_callback,
887+
smac_scenario_args=self._smac_scenario_args,
888+
scoring_functions=self._scoring_functions,
889+
port=self._logger_port,
890+
pynisher_context=self._multiprocessing_context,
891+
ensemble_callback=proc_ensemble,
892+
trials_callback=self._get_trials_callback,
893+
)
935894

936-
(
937-
self.runhistory_,
938-
self.trajectory_,
939-
self._budget_type,
940-
) = _proc_smac.run_smbo()
941-
trajectory_filename = os.path.join(
942-
self._backend.get_smac_output_directory_for_run(self._seed),
943-
"trajectory.json",
944-
)
945-
saveable_trajectory = [
946-
list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
947-
for entry in self.trajectory_
948-
]
949-
with open(trajectory_filename, "w") as fh:
950-
json.dump(saveable_trajectory, fh)
951-
952-
self._logger.info("Starting shutdown...")
953-
# Wait until the ensemble process is finished to avoid shutting down
954-
# while the ensemble builder tries to access the data
955-
if proc_ensemble is not None:
956-
self.ensemble_performance_history = list(proc_ensemble.history)
957-
958-
if len(proc_ensemble.futures) > 0:
959-
# Now we wait for the future to return as it cannot be cancelled
960-
# while it is running: https://stackoverflow.com/a/49203129
961-
self._logger.info(
962-
"Ensemble script still running, waiting for it to finish."
963-
)
964-
result = proc_ensemble.futures.pop().result()
965-
if result:
966-
ensemble_history, _ = result
967-
self.ensemble_performance_history.extend(ensemble_history)
968-
self._logger.info("Ensemble script finished, continue shutdown.")
895+
(
896+
self.runhistory_,
897+
self.trajectory_,
898+
self._budget_type,
899+
) = _proc_smac.run_smbo()
900+
901+
trajectory_filename = os.path.join(
902+
self._backend.get_smac_output_directory_for_run(self._seed),
903+
"trajectory.json",
904+
)
905+
saveable_trajectory = [
906+
list(entry[:2])
907+
+ [entry[2].get_dictionary()]
908+
+ list(entry[3:])
909+
for entry in self.trajectory_
910+
]
911+
with open(trajectory_filename, "w") as fh:
912+
json.dump(saveable_trajectory, fh)
913+
914+
self._logger.info("Starting shutdown...")
915+
# Wait until the ensemble process is finished to avoid shutting
916+
# down while the ensemble builder tries to access the data
917+
if proc_ensemble is not None:
918+
self.ensemble_performance_history = list(
919+
proc_ensemble.history
920+
)
921+
922+
if len(proc_ensemble.futures) > 0:
923+
# Now we wait for the future to return as it cannot be
924+
# cancelled while it is running
925+
# * https://stackoverflow.com/a/49203129
926+
self._logger.info(
927+
"Ensemble script still running,"
928+
" waiting for it to finish."
929+
)
930+
result = proc_ensemble.futures.pop().result()
931+
932+
if result:
933+
ensemble_history, _ = result
934+
self.ensemble_performance_history.extend(
935+
ensemble_history
936+
)
937+
938+
self._logger.info(
939+
"Ensemble script finished, continue shutdown."
940+
)
969941

970942
# save the ensemble performance history file
971943
if len(self.ensemble_performance_history) > 0:
@@ -1054,7 +1026,7 @@ def _log_fit_setup(self) -> None:
10541026
self._logger.debug(
10551027
" multiprocessing_context: %s", str(self._multiprocessing_context)
10561028
)
1057-
self._logger.debug(" dask_client: %s", str(self._dask_client))
1029+
self._logger.debug(" dask_client: %s", str(self._dask))
10581030
self._logger.debug(" precision: %s", str(self.precision))
10591031
self._logger.debug(
10601032
" disable_evaluator_output: %s", str(self._disable_evaluator_output)
@@ -1090,7 +1062,6 @@ def __sklearn_is_fitted__(self) -> bool:
10901062

10911063
def _fit_cleanup(self) -> None:
10921064
self._logger.info("Closing the dask infrastructure")
1093-
self._close_dask_client()
10941065
self._logger.info("Finished closing the dask infrastructure")
10951066

10961067
# Clean up the logger
@@ -1555,48 +1526,48 @@ def fit_ensemble(
15551526
# Make sure that input is valid
15561527
y = self.InputValidator.target_validator.transform(y)
15571528

1558-
# Create a client if needed
1559-
if self._dask_client is None:
1560-
self._create_dask_client()
1561-
else:
1562-
self._is_dask_client_internally_created = False
1563-
15641529
metrics = metrics if metrics is not None else self._metrics
15651530
if not isinstance(metrics, Sequence):
15661531
metrics = [metrics]
15671532

15681533
# Use the current thread to start the ensemble builder process
15691534
# The function ensemble_builder_process will internally create a ensemble
15701535
# builder in the provide dask client
1571-
manager = EnsembleBuilderManager(
1572-
start_time=time.time(),
1573-
time_left_for_ensembles=self._time_for_task,
1574-
backend=copy.deepcopy(self._backend),
1575-
dataset_name=dataset_name if dataset_name else self._dataset_name,
1576-
task=task if task else self._task,
1577-
metrics=metrics if metrics is not None else self._metrics,
1578-
ensemble_class=(
1579-
ensemble_class if ensemble_class is not None else self._ensemble_class
1580-
),
1581-
ensemble_kwargs=(
1582-
ensemble_kwargs
1583-
if ensemble_kwargs is not None
1584-
else self._ensemble_kwargs
1585-
),
1586-
ensemble_nbest=ensemble_nbest if ensemble_nbest else self._ensemble_nbest,
1587-
max_models_on_disc=self._max_models_on_disc,
1588-
seed=self._seed,
1589-
precision=precision if precision else self.precision,
1590-
max_iterations=1,
1591-
read_at_most=None,
1592-
memory_limit=self._memory_limit,
1593-
random_state=self._seed,
1594-
logger_port=self._logger_port,
1595-
pynisher_context=self._multiprocessing_context,
1596-
)
1597-
manager.build_ensemble(self._dask_client)
1598-
future = manager.futures.pop()
1599-
result = future.result()
1536+
with self._dask as dask_client:
1537+
manager = EnsembleBuilderManager(
1538+
start_time=time.time(),
1539+
time_left_for_ensembles=self._time_for_task,
1540+
backend=copy.deepcopy(self._backend),
1541+
dataset_name=dataset_name if dataset_name else self._dataset_name,
1542+
task=task if task else self._task,
1543+
metrics=metrics if metrics is not None else self._metrics,
1544+
ensemble_class=(
1545+
ensemble_class
1546+
if ensemble_class is not None
1547+
else self._ensemble_class
1548+
),
1549+
ensemble_kwargs=(
1550+
ensemble_kwargs
1551+
if ensemble_kwargs is not None
1552+
else self._ensemble_kwargs
1553+
),
1554+
ensemble_nbest=ensemble_nbest
1555+
if ensemble_nbest
1556+
else self._ensemble_nbest,
1557+
max_models_on_disc=self._max_models_on_disc,
1558+
seed=self._seed,
1559+
precision=precision if precision else self.precision,
1560+
max_iterations=1,
1561+
read_at_most=None,
1562+
memory_limit=self._memory_limit,
1563+
random_state=self._seed,
1564+
logger_port=self._logger_port,
1565+
pynisher_context=self._multiprocessing_context,
1566+
)
1567+
manager.build_ensemble(dask_client)
1568+
future = manager.futures.pop()
1569+
result = future.result()
1570+
16001571
if result is None:
16011572
raise ValueError(
16021573
"Error building the ensemble - please check the log file and command "
@@ -1606,7 +1577,6 @@ def fit_ensemble(
16061577
self._ensemble_class = ensemble_class
16071578

16081579
self._load_models()
1609-
self._close_dask_client()
16101580
return self
16111581

16121582
def _load_models(self):
@@ -2295,7 +2265,7 @@ def _create_search_space(
22952265

22962266
def __getstate__(self) -> dict[str, Any]:
22972267
# Cannot serialize a client!
2298-
self._dask_client = None
2268+
self._dask = None
22992269
self.logging_server = None
23002270
self.stop_logging_server = None
23012271
return self.__dict__
@@ -2304,8 +2274,6 @@ def __del__(self) -> None:
23042274
# Clean up the logger
23052275
self._clean_logger()
23062276

2307-
self._close_dask_client()
2308-
23092277

23102278
class AutoMLClassifier(AutoML):
23112279

0 commit comments

Comments
 (0)