|
120 | 120 | warnings_to,
|
121 | 121 | )
|
122 | 122 | from autosklearn.util.parallel import preload_modules
|
| 123 | +from autosklearn.util.progress_bar import ProgressBar |
123 | 124 | from autosklearn.util.smac_wrap import SMACCallback, SmacRunCallback
|
124 | 125 | from autosklearn.util.stopwatch import StopWatch
|
125 | 126 |
|
@@ -239,6 +240,7 @@ def __init__(
|
239 | 240 | get_trials_callback: SMACCallback | None = None,
|
240 | 241 | dataset_compression: bool | Mapping[str, Any] = True,
|
241 | 242 | allow_string_features: bool = True,
|
| 243 | + disable_progress_bar: bool = False, |
242 | 244 | ):
|
243 | 245 | super().__init__()
|
244 | 246 |
|
@@ -295,6 +297,7 @@ def __init__(
|
295 | 297 | self.logging_config = logging_config
|
296 | 298 | self.precision = precision
|
297 | 299 | self.allow_string_features = allow_string_features
|
| 300 | + self.disable_progress_bar = disable_progress_bar |
298 | 301 | self._initial_configurations_via_metalearning = (
|
299 | 302 | initial_configurations_via_metalearning
|
300 | 303 | )
|
@@ -626,6 +629,12 @@ def fit(
|
626 | 629 | # By default try to use the TCP logging port or get a new port
|
627 | 630 | self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
|
628 | 631 |
|
| 632 | + progress_bar = ProgressBar( |
| 633 | + total=self._time_for_task, |
| 634 | + disable=self.disable_progress_bar, |
| 635 | + desc="Fitting to the training data", |
| 636 | + colour="green", |
| 637 | + ) |
629 | 638 | # Once we start the logging server, it starts in a new process
|
630 | 639 | # If an error occurs then we want to make sure that we exit cleanly
|
631 | 640 | # and shut it down, else it might hang
|
@@ -961,6 +970,7 @@ def fit(
|
961 | 970 | self._logger.exception(e)
|
962 | 971 | raise e
|
963 | 972 | finally:
|
| 973 | + progress_bar.stop() |
964 | 974 | self._fit_cleanup()
|
965 | 975 |
|
966 | 976 | self.fitted = True
|
|
0 commit comments