Skip to content

Commit 6a97f72

Browse files
Show progress bar while fitting to training data (#1606)
* Show progress bar while fitting to training data * Minor fixes for progress bar * Revert accidental changes to requirements.txt * Document changes * Skip type checks for tqdm * Make progress bar more flexible with kwargs * Fix link checker make command in CONTRIBUTE.md * Update doc link to be sphinx compatible * Switch to pytets-forked from pytest-xdist Co-authored-by: Eddie Bergman <[email protected]>
1 parent 305a3ab commit 6a97f72

File tree

8 files changed

+98
-5
lines changed

8 files changed

+98
-5
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ Lastly, if the feature really is a game changer or you're very proud of it, cons
252252
make doc
253253
```
254254
* If you're unfamiliar with sphinx, it's a documentation generator which can read comments and docstrings from within the code and generate html documentation.
255-
* If you've added documentation, we also have a command `links` for making sure
256-
all the links correctly go to some destination.
255+
* If you've added documentation, we also have a command `links` for making
256+
sure all the links correctly go to some destination.
257257
This helps tests for dead links or accidental typos.
258258
```bash
259259
make links

autosklearn/automl.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
warnings_to,
121121
)
122122
from autosklearn.util.parallel import preload_modules
123+
from autosklearn.util.progress_bar import ProgressBar
123124
from autosklearn.util.smac_wrap import SMACCallback, SmacRunCallback
124125
from autosklearn.util.stopwatch import StopWatch
125126

@@ -239,6 +240,7 @@ def __init__(
239240
get_trials_callback: SMACCallback | None = None,
240241
dataset_compression: bool | Mapping[str, Any] = True,
241242
allow_string_features: bool = True,
243+
disable_progress_bar: bool = False,
242244
):
243245
super().__init__()
244246

@@ -295,6 +297,7 @@ def __init__(
295297
self.logging_config = logging_config
296298
self.precision = precision
297299
self.allow_string_features = allow_string_features
300+
self.disable_progress_bar = disable_progress_bar
298301
self._initial_configurations_via_metalearning = (
299302
initial_configurations_via_metalearning
300303
)
@@ -626,6 +629,12 @@ def fit(
626629
# By default try to use the TCP logging port or get a new port
627630
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
628631

632+
progress_bar = ProgressBar(
633+
total=self._time_for_task,
634+
disable=self.disable_progress_bar,
635+
desc="Fitting to the training data",
636+
colour="green",
637+
)
629638
# Once we start the logging server, it starts in a new process
630639
# If an error occurs then we want to make sure that we exit cleanly
631640
# and shut it down, else it might hang
@@ -961,6 +970,7 @@ def fit(
961970
self._logger.exception(e)
962971
raise e
963972
finally:
973+
progress_bar.stop()
964974
self._fit_cleanup()
965975

966976
self.fitted = True

autosklearn/estimators.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
get_trials_callback: SMACCallback | None = None,
7777
dataset_compression: Union[bool, Mapping[str, Any]] = True,
7878
allow_string_features: bool = True,
79+
disable_progress_bar: bool = False,
7980
):
8081
"""
8182
Parameters
@@ -381,6 +382,10 @@ def __init__(
381382
Whether autosklearn should process string features. By default the
382383
textpreprocessing is enabled.
383384
385+
disable_progress_bar: bool = False
386+
Whether to disable the progress bar that is displayed in the console
387+
while fitting to the training data.
388+
384389
Attributes
385390
----------
386391
cv_results_ : dict of numpy (masked) ndarrays
@@ -475,6 +480,7 @@ def __init__(
475480
self.get_trials_callback = get_trials_callback
476481
self.dataset_compression = dataset_compression
477482
self.allow_string_features = allow_string_features
483+
self.disable_progress_bar = disable_progress_bar
478484

479485
self.automl_ = None # type: Optional[AutoML]
480486

@@ -525,6 +531,7 @@ def build_automl(self):
525531
get_trials_callback=self.get_trials_callback,
526532
dataset_compression=self.dataset_compression,
527533
allow_string_features=self.allow_string_features,
534+
disable_progress_bar=self.disable_progress_bar,
528535
)
529536

530537
return automl

autosklearn/experimental/askl2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def __init__(
166166
load_models: bool = True,
167167
dataset_compression: Union[bool, Mapping[str, Any]] = True,
168168
allow_string_features: bool = True,
169+
disable_progress_bar: bool = False,
169170
):
170171

171172
"""
@@ -284,6 +285,10 @@ def __init__(
284285
load_models : bool, optional (True)
285286
Whether to load the models after fitting Auto-sklearn.
286287
288+
disable_progress_bar: bool = False
289+
Whether to disable the progress bar that is displayed in the console
290+
while fitting to the training data.
291+
287292
Attributes
288293
----------
289294
@@ -337,6 +342,7 @@ def __init__(
337342
scoring_functions=scoring_functions,
338343
load_models=load_models,
339344
allow_string_features=allow_string_features,
345+
disable_progress_bar=disable_progress_bar,
340346
)
341347

342348
def train_selectors(self, selected_metric=None):

autosklearn/util/progress_bar.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any
2+
3+
import datetime
4+
import time
5+
from threading import Thread
6+
7+
from tqdm import trange
8+
9+
10+
class ProgressBar(Thread):
11+
"""A Thread that displays a tqdm progress bar in the console.
12+
13+
It is specialized to display information relevant to fitting to the training data
14+
with auto-sklearn.
15+
16+
Parameters
17+
----------
18+
total : int
19+
The total amount that should be reached by the progress bar once it finishes
20+
update_interval : float
21+
Specifies how frequently the progress bar is updated (in seconds)
22+
disable : bool
23+
Turns on or off the progress bar. If True, this thread won't be started or
24+
initialized.
25+
kwargs : Any
26+
Keyword arguments that are passed into tqdm's constructor. Refer to:
27+
`tqdm <https://tqdm.github.io/docs/tqdm/>`_. Note that postfix can not be
28+
specified in the kwargs since it is already passed into tqdm by this class.
29+
"""
30+
31+
def __init__(
32+
self,
33+
total: int,
34+
update_interval: float = 1.0,
35+
disable: bool = False,
36+
**kwargs: Any,
37+
):
38+
self.disable = disable
39+
if not disable:
40+
super().__init__(name="_progressbar_")
41+
self.total = total
42+
self.update_interval = update_interval
43+
self.terminated: bool = False
44+
self.kwargs = kwargs
45+
# start this thread
46+
self.start()
47+
48+
def run(self) -> None:
49+
"""Display a tqdm progress bar in the console.
50+
51+
Additionally, it shows useful information related to the task. This method
52+
overrides the run method of Thread.
53+
"""
54+
if not self.disable:
55+
for _ in trange(
56+
self.total,
57+
postfix=f"The total time budget for this task is "
58+
f"{datetime.timedelta(seconds=self.total)}",
59+
**self.kwargs,
60+
):
61+
if not self.terminated:
62+
time.sleep(self.update_interval)
63+
64+
def stop(self) -> None:
65+
"""Terminates the thread."""
66+
if not self.disable:
67+
self.terminated = True
68+
super().join()

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ module = [
155155
"setuptools.*",
156156
"pkg_resources.*",
157157
"yaml.*",
158-
"psutil.*"
158+
"psutil.*",
159+
"tqdm.*",
159160
]
160161
ignore_missing_imports = true
161162

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ pyyaml
1414
pandas>=1.0
1515
liac-arff
1616
threadpoolctl
17+
tqdm
1718

1819
ConfigSpace>=0.4.21,<0.5
1920
pynisher>=0.6.3,<0.7
2021
pyrfr>=0.8.1,<0.9
21-
smac>=1.2,<1.3
22+
smac>=1.2,<1.3

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"test": [
3333
"pytest>=4.6",
3434
"pytest-cov",
35-
"pytest-xdist",
35+
"pytest-forked",
3636
"pytest-timeout",
3737
"pytest-cases>=3.6.11",
3838
"mypy",

0 commit comments

Comments
 (0)