Skip to content

Commit 55dd3a4

Browse files
authored
Typing for tests 1/n (#6313)
* typing * yapf * typing
1 parent fc6d402 commit 55dd3a4

16 files changed

+113
-119
lines changed

tests/accelerators/test_accelerator_connector.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License
1414

1515
import os
16+
from typing import Optional
1617
from unittest import mock
1718

1819
import pytest
@@ -30,6 +31,7 @@
3031
DDPSpawnPlugin,
3132
DDPSpawnShardedPlugin,
3233
DeepSpeedPlugin,
34+
ParallelPlugin,
3335
PrecisionPlugin,
3436
SingleDevicePlugin,
3537
)
@@ -408,10 +410,8 @@ def test_ipython_incompatible_backend_error(*_):
408410
["accelerator", "plugin"],
409411
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
410412
)
411-
def test_plugin_accelerator_choice(accelerator, plugin):
412-
"""
413-
Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
414-
"""
413+
def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str):
414+
"""Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent."""
415415
trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
416416
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
417417

@@ -428,7 +428,9 @@ def test_plugin_accelerator_choice(accelerator, plugin):
428428
])
429429
@mock.patch('torch.cuda.is_available', return_value=True)
430430
@mock.patch('torch.cuda.device_count', return_value=2)
431-
def test_accelerator_choice_multi_node_gpu(mock_is_available, mock_device_count, accelerator, plugin, tmpdir):
431+
def test_accelerator_choice_multi_node_gpu(
432+
mock_is_available, mock_device_count, tmpdir, accelerator: str, plugin: ParallelPlugin
433+
):
432434
trainer = Trainer(
433435
accelerator=accelerator,
434436
default_root_dir=tmpdir,

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
@pytest.mark.parametrize("single_cb", [False, True])
21-
def test_train_step_no_return(tmpdir, single_cb):
21+
def test_train_step_no_return(tmpdir, single_cb: bool):
2222
"""
2323
Tests that only training_step can be used
2424
"""

tests/callbacks/test_early_stopping.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import os
1616
import pickle
17+
from typing import List, Optional
1718
from unittest import mock
1819

1920
import cloudpickle
@@ -119,7 +120,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
119120
([6, 5, 6, 5, 5, 5], 3, 4),
120121
],
121122
)
122-
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
123+
def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expected_stop_epoch: int):
123124
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
124125

125126
class ModelOverrideValidationReturn(BoringModel):
@@ -142,7 +143,7 @@ def validation_epoch_end(self, outputs):
142143
assert trainer.current_epoch == expected_stop_epoch
143144

144145

145-
@pytest.mark.parametrize('validation_step', ['base', None])
146+
@pytest.mark.parametrize('validation_step_none', [True, False])
146147
@pytest.mark.parametrize(
147148
"loss_values, patience, expected_stop_epoch",
148149
[
@@ -151,7 +152,9 @@ def validation_epoch_end(self, outputs):
151152
([6, 5, 6, 5, 5, 5], 3, 4),
152153
],
153154
)
154-
def test_early_stopping_patience_train(tmpdir, validation_step, loss_values, patience, expected_stop_epoch):
155+
def test_early_stopping_patience_train(
156+
tmpdir, validation_step_none: bool, loss_values: list, patience: int, expected_stop_epoch: int
157+
):
155158
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
156159

157160
class ModelOverrideTrainReturn(BoringModel):
@@ -163,7 +166,7 @@ def training_epoch_end(self, outputs):
163166

164167
model = ModelOverrideTrainReturn()
165168

166-
if validation_step is None:
169+
if validation_step_none:
167170
model.validation_step = None
168171

169172
early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
@@ -254,7 +257,7 @@ def validation_epoch_end(self, outputs):
254257

255258

256259
@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
257-
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs):
260+
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
258261
"""Excepted Behaviour:
259262
IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered,
260263
THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop.
@@ -386,10 +389,10 @@ def on_train_end(self) -> None:
386389
marks=RunIf(skip_windows=True)),
387390
],
388391
)
389-
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
390-
"""
391-
Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
392-
"""
392+
def test_multiple_early_stopping_callbacks(
393+
tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
394+
):
395+
"""Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
393396

394397
model = EarlyStoppingModel(expected_stop_epoch)
395398

tests/callbacks/test_lr_monitor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@ def test_lr_monitor_single_lr(tmpdir):
5151

5252

5353
@pytest.mark.parametrize('opt', ['SGD', 'Adam'])
54-
def test_lr_monitor_single_lr_with_momentum(tmpdir, opt):
55-
"""
56-
Test that learning rates and momentum are extracted and logged for single lr scheduler.
57-
"""
54+
def test_lr_monitor_single_lr_with_momentum(tmpdir, opt: str):
55+
"""Test that learning rates and momentum are extracted and logged for single lr scheduler."""
5856

5957
class LogMomentumModel(BoringModel):
6058

@@ -170,7 +168,7 @@ def test_lr_monitor_no_logger(tmpdir):
170168

171169

172170
@pytest.mark.parametrize("logging_interval", ['step', 'epoch'])
173-
def test_lr_monitor_multi_lrs(tmpdir, logging_interval):
171+
def test_lr_monitor_multi_lrs(tmpdir, logging_interval: str):
174172
""" Test that learning rates are extracted and logged for multi lr schedulers. """
175173
tutils.reset_seed()
176174

tests/callbacks/test_progress_bar.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import os
1515
import sys
16+
from typing import Optional, Union
1617
from unittest import mock
1718
from unittest.mock import ANY, call, Mock
1819

@@ -36,7 +37,7 @@
3637
([ProgressBar(refresh_rate=2)], 1),
3738
]
3839
)
39-
def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
40+
def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
4041
"""Test different ways the progress bar can be turned on."""
4142

4243
trainer = Trainer(
@@ -60,7 +61,7 @@ def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
6061
([ModelCheckpoint(dirpath='../trainer')], 0),
6162
]
6263
)
63-
def test_progress_bar_off(tmpdir, callbacks, refresh_rate):
64+
def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int]):
6465
"""Test different ways the progress bar can be turned off."""
6566

6667
trainer = Trainer(
@@ -165,7 +166,7 @@ def test_progress_bar_fast_dev_run(tmpdir):
165166

166167

167168
@pytest.mark.parametrize('refresh_rate', [0, 1, 50])
168-
def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
169+
def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
169170
"""Test that the three progress bars get correctly updated when using different refresh rates."""
170171

171172
model = BoringModel()
@@ -219,7 +220,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
219220

220221

221222
@pytest.mark.parametrize('limit_val_batches', (0, 5))
222-
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches):
223+
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
223224
"""
224225
Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
225226
"""
@@ -309,7 +310,9 @@ def init_test_tqdm(self):
309310
[5, 2, 6, [6, 1], [2]],
310311
]
311312
)
312-
def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, refresh_rate, train_deltas, val_deltas):
313+
def test_main_progress_bar_update_amount(
314+
tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_deltas: list, val_deltas: list
315+
):
313316
"""
314317
Test that the main progress updates with the correct amount together with the val progress. At the end of
315318
the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate.
@@ -336,7 +339,7 @@ def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, ref
336339
[3, 1, [1, 1, 1]],
337340
[5, 3, [3, 2]],
338341
])
339-
def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
342+
def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, test_deltas: list):
340343
"""
341344
Test that test progress updates with the correct amount.
342345
"""
@@ -379,10 +382,18 @@ def training_step(self, batch, batch_idx):
379382

380383

381384
@pytest.mark.parametrize(
382-
"input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'],
383-
['10000', '10000'], ['abc', 'abc']]
385+
"input_num, expected", [
386+
[1, '1'],
387+
[1.0, '1.000'],
388+
[0.1, '0.100'],
389+
[1e-3, '0.001'],
390+
[1e-5, '1e-5'],
391+
['1.0', '1.000'],
392+
['10000', '10000'],
393+
['abc', 'abc'],
394+
]
384395
)
385-
def test_tqdm_format_num(input_num, expected):
396+
def test_tqdm_format_num(input_num: Union[str, int, float], expected: str):
386397
""" Check that the specialized tqdm.format_num appends 0 to floats and strings """
387398
assert tqdm.format_num(input_num) == expected
388399

tests/callbacks/test_pruning.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from collections import OrderedDict
1515
from logging import INFO
16+
from typing import Union
1617

1718
import pytest
1819
import torch
@@ -144,7 +145,8 @@ def test_pruning_misconfiguration():
144145
)
145146
@pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True])
146147
def test_pruning_callback(
147-
tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn, use_lottery_ticket_hypothesis
148+
tmpdir, use_global_unstructured: bool, parameters_to_prune: bool,
149+
pruning_fn: Union[str, pytorch_prune.BasePruningMethod], use_lottery_ticket_hypothesis: bool
148150
):
149151
train_with_pruning_callback(
150152
tmpdir,
@@ -158,7 +160,7 @@ def test_pruning_callback(
158160
@RunIf(special=True)
159161
@pytest.mark.parametrize("parameters_to_prune", [False, True])
160162
@pytest.mark.parametrize("use_global_unstructured", [False, True])
161-
def test_pruning_callback_ddp(tmpdir, use_global_unstructured, parameters_to_prune):
163+
def test_pruning_callback_ddp(tmpdir, use_global_unstructured: bool, parameters_to_prune: bool):
162164
train_with_pruning_callback(
163165
tmpdir,
164166
parameters_to_prune=parameters_to_prune,
@@ -179,7 +181,7 @@ def test_pruning_callback_ddp_cpu(tmpdir):
179181

180182

181183
@pytest.mark.parametrize("resample_parameters", (False, True))
182-
def test_pruning_lth_callable(tmpdir, resample_parameters):
184+
def test_pruning_lth_callable(tmpdir, resample_parameters: bool):
183185
model = TestModel()
184186

185187
class ModelPruningTestCallback(ModelPruning):
@@ -218,7 +220,7 @@ def apply_lottery_ticket_hypothesis(self):
218220

219221

220222
@pytest.mark.parametrize("make_pruning_permanent", (False, True))
221-
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
223+
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
222224
seed_everything(0)
223225
model = TestModel()
224226
pruning_kwargs = {
@@ -228,6 +230,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
228230
}
229231
p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **pruning_kwargs)
230232
p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **pruning_kwargs)
233+
231234
trainer = Trainer(
232235
default_root_dir=tmpdir,
233236
progress_bar_refresh_rate=0,

tests/callbacks/test_quantization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import copy
15+
from typing import Callable, Union
1516

1617
import pytest
1718
import torch
@@ -28,7 +29,7 @@
2829
@pytest.mark.parametrize("observe", ['average', pytest.param('histogram', marks=RunIf(min_torch="1.5"))])
2930
@pytest.mark.parametrize("fuse", [True, False])
3031
@RunIf(quantization=True)
31-
def test_quantization(tmpdir, observe, fuse):
32+
def test_quantization(tmpdir, observe: str, fuse: bool):
3233
"""Parity test for quant model"""
3334
seed_everything(42)
3435
dm = RegressDataModule()
@@ -122,7 +123,7 @@ def custom_trigger_last(trainer):
122123
]
123124
)
124125
@RunIf(quantization=True)
125-
def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
126+
def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], expected_count: int):
126127
"""Test how many times the quant is called"""
127128
dm = RegressDataModule()
128129
qmodel = RegressionModel()

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_swa_callback_1_gpu(tmpdir):
136136

137137
@RunIf(min_torch="1.6.0")
138138
@pytest.mark.parametrize("batchnorm", (True, False))
139-
def test_swa_callback(tmpdir, batchnorm):
139+
def test_swa_callback(tmpdir, batchnorm: bool):
140140
train_with_swa(tmpdir, batchnorm=batchnorm)
141141

142142

@@ -155,7 +155,7 @@ def test_swa_raises():
155155
@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
156156
@pytest.mark.parametrize('use_callbacks', [False, True])
157157
@RunIf(min_torch="1.6.0")
158-
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
158+
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool):
159159
"""Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""
160160

161161
class TestModel(BoringModel):

tests/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_mc_called(tmpdir):
5151
['epochs', 'val_check_interval', 'expected'],
5252
[(1, 1.0, 1), (2, 1.0, 2), (1, 0.25, 4), (2, 0.3, 7)],
5353
)
54-
def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval, expected):
54+
def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_interval: float, expected: int):
5555

5656
model = BoringModel()
5757
trainer = Trainer(
@@ -68,9 +68,13 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs, val_check_interval,
6868

6969

7070
@mock.patch('torch.save')
71-
@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 1.0, 2),
72-
(2, 1, 0.25, 4), (2, 2, 0.3, 7)])
73-
def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected):
71+
@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [
72+
(1, 1, 1.0, 1),
73+
(2, 2, 1.0, 2),
74+
(2, 1, 0.25, 4),
75+
(2, 2, 0.3, 7),
76+
])
77+
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int):
7478

7579
class TestModel(BoringModel):
7680

tests/checkpointing/test_legacy_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"1.2.2",
5858
]
5959
)
60-
def test_resume_legacy_checkpoints(tmpdir, pl_version):
60+
def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
6161
path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
6262

6363
# todo: make this as mock, so it is cleaner...

0 commit comments

Comments
 (0)