[ADD] Post-Hoc ensemble fitting #260

Closed
wants to merge 5 commits into from
203 changes: 166 additions & 37 deletions autoPyTorch/api/base_task.py
@@ -183,6 +183,9 @@ def __init__(
self.trajectory: Optional[List] = None
self.dataset_name: Optional[str] = None
self.cv_models_: Dict = {}
self.precision: Optional[int] = None
self.opt_metric: Optional[str] = None
self.dataset: Optional[BaseDataset] = None

# By default try to use the TCP logging port or get a new port
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
@@ -412,6 +415,25 @@ def _close_dask_client(self) -> None:
self._is_dask_client_internally_created = False
del self._is_dask_client_internally_created

def _cleanup(self) -> None:
"""

Closes the different servers created during api search.

Returns:
None
"""
if self._logger is not None:
self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

# Clean up the logger
self._logger.info("Starting to clean up the logger")
self._clean_logger()
else:
self._close_dask_client()

def _load_models(self) -> bool:

"""
@@ -783,7 +805,7 @@ def _search(
metrics supporting current task will be calculated
for each pipeline and results will be available via cv_results
precision (int), (default=32): Numeric precision used when loading
ensemble data. Can be either '16', '32' or '64'.
ensemble data. Can be either 16, 32 or 64.
disable_file_output (Union[bool, List]):
load_models (bool), (default=True): Whether to load the
models after fitting AutoPyTorch.
@@ -910,6 +932,8 @@ def _search(
self._stopwatch.stop_task(traditional_task_name)

# ============> Starting ensemble
self.precision = precision
self.opt_metric = optimize_metric
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time)
proc_ensemble = None
@@ -926,28 +950,12 @@
self._logger.info("Starting ensemble")
ensemble_task_name = 'ensemble'
self._stopwatch.start_task(ensemble_task_name)
proc_ensemble = EnsembleBuilderManager(
start_time=time.time(),
time_left_for_ensembles=time_left_for_ensembles,
backend=copy.deepcopy(self._backend),
dataset_name=str(dataset.dataset_name),
output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
task_type=STRING_TO_TASK_TYPES[self.task_type],
metrics=[self._metric],
opt_metric=optimize_metric,
ensemble_size=self.ensemble_size,
ensemble_nbest=self.ensemble_nbest,
max_models_on_disc=self.max_models_on_disc,
seed=self.seed,
max_iterations=None,
read_at_most=sys.maxsize,
ensemble_memory_limit=self._memory_limit,
random_state=self.seed,
precision=precision,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
)
self._stopwatch.stop_task(ensemble_task_name)
proc_ensemble = self._init_ensemble_builder(time_left_for_ensembles=time_left_for_ensembles,
ensemble_size=self.ensemble_size,
ensemble_nbest=self.ensemble_nbest,
precision=precision,
optimize_metric=self.opt_metric
)

# ==> Run SMAC
smac_task_name: str = 'runSMAC'
@@ -1028,18 +1036,12 @@ def _search(
pd.DataFrame(self.ensemble_performance_history).to_json(
os.path.join(self._backend.internals_directory, 'ensemble_history.json'))

self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

if load_models:
self._logger.info("Loading models...")
self._load_models()
self._logger.info("Finished loading models...")

# Clean up the logger
self._logger.info("Starting to clean up the logger")
self._clean_logger()
self._cleanup()

return self

@@ -1114,7 +1116,7 @@ def refit(
# the ordering of the data.
fit_and_suppress_warnings(self._logger, model, X, y=None)

self._clean_logger()
self._cleanup()

return self

@@ -1179,9 +1181,139 @@ def fit(self,

fit_and_suppress_warnings(self._logger, pipeline, X, y=None)

self._clean_logger()
self._cleanup()
return pipeline

def fit_ensemble(
self,
ensemble_nbest: int = 50,
ensemble_size: int = 50,
precision: int = 32,
load_models: bool = True
) -> 'BaseTask':
"""
Enables post-hoc fitting of the ensemble after the `search()`
method has finished. This method creates an ensemble using all
the models stored on disk during the SMBO run.

Args:
ensemble_nbest (int), (default=50):
Only consider the ensemble_nbest models to build the ensemble.
ensemble_size (int) (default=50):
Number of models added to the ensemble built by
Ensemble selection from libraries of models.
Models are drawn with replacement.
precision (int), (default=32): Numeric precision used when loading
ensemble data. Can be either 16, 32 or 64.

Returns:
self
"""
# Make sure that input is valid
if self.dataset is None or self.opt_metric is None:
raise ValueError("fit_ensemble() can only be called after `search()`. "
"Please call the `search()` method of {} prior to "
"fit_ensemble().".format(self.__class__.__name__))

if self._logger is None:
self._logger = self._get_logger(self.dataset.dataset_name)

# Create a client if needed
if self._dask_client is None:
self._create_dask_client()
else:
self._is_dask_client_internally_created = False

manager = self._init_ensemble_builder(
time_left_for_ensembles=self._time_for_task,
optimize_metric=self.opt_metric,
precision=precision,
ensemble_size=ensemble_size,
ensemble_nbest=ensemble_nbest,
)

manager.build_ensemble(self._dask_client)
future = manager.futures.pop()
result = future.result()
if result is None:
raise ValueError("Errors occurred while building the ensemble - please"
" check the log file and command line output for error messages.")
self.ensemble_performance_history, _, _, _ = result

if load_models:
self._load_models()
self._cleanup()
return self
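
For context, a minimal usage sketch of the new post-hoc workflow. The tabular task class, the data loading, and using `ensemble_size=0` to skip ensembling during `search()` are assumptions about the surrounding autoPyTorch API, not something this diff establishes:

```python
# Hedged sketch: build the ensemble after search() instead of during it.
from autoPyTorch.api.tabular_classification import TabularClassificationTask
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

# ensemble_size=0 is assumed here to disable ensembling during search()
api = TabularClassificationTask(ensemble_size=0, seed=1)
api.search(X_train=X_train, y_train=y_train,
           optimize_metric='accuracy', total_walltime_limit=300)

# Post hoc: assemble an ensemble from the models SMBO left on disk
api.fit_ensemble(ensemble_size=20, ensemble_nbest=10, precision=32)
y_pred = api.predict(X_test)
```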

def _init_ensemble_builder(
self,
time_left_for_ensembles: float,
optimize_metric: str,
ensemble_nbest: int,
ensemble_size: int,
precision: int = 32,
) -> EnsembleBuilderManager:
"""
Initializes an `EnsembleBuilderManager`.

Args:
time_left_for_ensembles (float):
Time (in seconds) allocated to building the ensemble
optimize_metric (str):
Name of the metric to optimize the ensemble.
ensemble_nbest (int):
Only consider the ensemble_nbest models to build the ensemble.
ensemble_size (int):
Number of models added to the ensemble built by
Ensemble selection from libraries of models.
Models are drawn with replacement.
precision (int), (default=32): Numeric precision used when loading
ensemble data. Can be either 16, 32 or 64.

Returns:
EnsembleBuilderManager

"""
if self._logger is None:
raise ValueError("logger should be initialized to fit ensemble")
if self.dataset is None:
raise ValueError("ensemble can only be initialised after or during `search()`. "
"Please call the `search()` method of {}.".format(self.__class__.__name__))

self._logger.info("Starting ensemble")
ensemble_task_name = 'ensemble'
self._stopwatch.start_task(ensemble_task_name)

# Use the current thread to start the ensemble builder process
# The function ensemble_builder_process will internally create an
# ensemble builder in the provided dask client
required_dataset_properties = {'task_type': self.task_type,
'output_type': self.dataset.output_type}
proc_ensemble = EnsembleBuilderManager(
start_time=time.time(),
time_left_for_ensembles=time_left_for_ensembles,
backend=copy.deepcopy(self._backend),
dataset_name=str(self.dataset.dataset_name),
output_type=STRING_TO_OUTPUT_TYPES[self.dataset.output_type],
task_type=STRING_TO_TASK_TYPES[self.task_type],
metrics=[self._metric] if self._metric is not None else get_metrics(
dataset_properties=required_dataset_properties, names=[optimize_metric]),
opt_metric=optimize_metric,
ensemble_size=ensemble_size,
ensemble_nbest=ensemble_nbest,
max_models_on_disc=self.max_models_on_disc,
seed=self.seed,
max_iterations=None,
read_at_most=sys.maxsize,
ensemble_memory_limit=self._memory_limit,
random_state=self.seed,
precision=precision,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
)
self._stopwatch.stop_task(ensemble_task_name)
return proc_ensemble

def predict(
self,
X_test: np.ndarray,
@@ -1230,7 +1362,7 @@ def predict(

predictions = self.ensemble_.predict(all_predictions)

self._clean_logger()
self._cleanup()

return predictions

@@ -1267,10 +1399,7 @@ def __getstate__(self) -> Dict[str, Any]:
return self.__dict__

def __del__(self) -> None:
# Clean up the logger
self._clean_logger()

self._close_dask_client()
self._cleanup()

# When a multiprocessing work is done, the
# objects are deleted. We don't want to delete run areas
5 changes: 1 addition & 4 deletions autoPyTorch/datasets/base_dataset.py
@@ -109,10 +109,7 @@ def __init__(
val_transforms (Optional[torchvision.transforms.Compose]):
Additional Transforms to be applied to the validation/test data
"""
self.dataset_name = dataset_name

if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
self.dataset_name: str = dataset_name if dataset_name is not None else str(uuid.uuid1(clock_seq=os.getpid()))

if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
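
For context, a tiny standalone sketch of the fallback behavior consolidated into the one-liner above; the printed value is illustrative:

```python
import os
import uuid

# Fall back to a per-process unique name when no dataset_name is given,
# mirroring the consolidated default in BaseDataset.__init__.
dataset_name = None
name = dataset_name if dataset_name is not None else str(
    uuid.uuid1(clock_seq=os.getpid()))
print(name)  # e.g. '5c2b1f2e-9d3a-11eb-...' (unique per process and time)
```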