1313# limitations under the License.
1414import os
1515import sys
16+ from typing import Optional , Union
1617from unittest import mock
1718from unittest .mock import ANY , call , Mock
1819
3637 ([ProgressBar (refresh_rate = 2 )], 1 ),
3738 ]
3839)
39- def test_progress_bar_on (tmpdir , callbacks , refresh_rate ):
40+ def test_progress_bar_on (tmpdir , callbacks : list , refresh_rate : Optional [ int ] ):
4041 """Test different ways the progress bar can be turned on."""
4142
4243 trainer = Trainer (
@@ -60,7 +61,7 @@ def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
6061 ([ModelCheckpoint (dirpath = '../trainer' )], 0 ),
6162 ]
6263)
63- def test_progress_bar_off (tmpdir , callbacks , refresh_rate ):
64+ def test_progress_bar_off (tmpdir , callbacks : list , refresh_rate : Union [ bool , int ] ):
6465 """Test different ways the progress bar can be turned off."""
6566
6667 trainer = Trainer (
@@ -165,7 +166,7 @@ def test_progress_bar_fast_dev_run(tmpdir):
165166
166167
167168@pytest .mark .parametrize ('refresh_rate' , [0 , 1 , 50 ])
168- def test_progress_bar_progress_refresh (tmpdir , refresh_rate ):
169+ def test_progress_bar_progress_refresh (tmpdir , refresh_rate : int ):
169170 """Test that the three progress bars get correctly updated when using different refresh rates."""
170171
171172 model = BoringModel ()
@@ -219,7 +220,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
219220
220221
221222@pytest .mark .parametrize ('limit_val_batches' , (0 , 5 ))
222- def test_num_sanity_val_steps_progress_bar (tmpdir , limit_val_batches ):
223+ def test_num_sanity_val_steps_progress_bar (tmpdir , limit_val_batches : int ):
223224 """
224225 Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
225226 """
@@ -309,7 +310,9 @@ def init_test_tqdm(self):
309310 [5 , 2 , 6 , [6 , 1 ], [2 ]],
310311 ]
311312)
312- def test_main_progress_bar_update_amount (tmpdir , train_batches , val_batches , refresh_rate , train_deltas , val_deltas ):
313+ def test_main_progress_bar_update_amount (
314+ tmpdir , train_batches : int , val_batches : int , refresh_rate : int , train_deltas : list , val_deltas : list
315+ ):
313316 """
314317 Test that the main progress updates with the correct amount together with the val progress. At the end of
315318 the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate.
@@ -336,7 +339,7 @@ def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, ref
336339 [3 , 1 , [1 , 1 , 1 ]],
337340 [5 , 3 , [3 , 2 ]],
338341])
339- def test_test_progress_bar_update_amount (tmpdir , test_batches , refresh_rate , test_deltas ):
342+ def test_test_progress_bar_update_amount (tmpdir , test_batches : int , refresh_rate : int , test_deltas : list ):
340343 """
341344 Test that test progress updates with the correct amount.
342345 """
@@ -379,10 +382,18 @@ def training_step(self, batch, batch_idx):
379382
380383
381384@pytest .mark .parametrize (
382- "input_num, expected" , [[1 , '1' ], [1.0 , '1.000' ], [0.1 , '0.100' ], [1e-3 , '0.001' ], [1e-5 , '1e-5' ], ['1.0' , '1.000' ],
383- ['10000' , '10000' ], ['abc' , 'abc' ]]
385+ "input_num, expected" , [
386+ [1 , '1' ],
387+ [1.0 , '1.000' ],
388+ [0.1 , '0.100' ],
389+ [1e-3 , '0.001' ],
390+ [1e-5 , '1e-5' ],
391+ ['1.0' , '1.000' ],
392+ ['10000' , '10000' ],
393+ ['abc' , 'abc' ],
394+ ]
384395)
385- def test_tqdm_format_num (input_num , expected ):
396+ def test_tqdm_format_num (input_num : Union [ str , int , float ], expected : str ):
386397 """ Check that the specialized tqdm.format_num appends 0 to floats and strings """
387398 assert tqdm .format_num (input_num ) == expected
388399
0 commit comments