
Commit 09baf29

prune deprecated profiler as bool (#6164)
* prune profiler
* chlog
1 parent 45158aa commit 09baf29

File tree

12 files changed (+28 −89 lines)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
+- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
+
 
 ### Fixed
 

docs/source/starter/new-project.rst

Lines changed: 1 addition & 1 deletion
@@ -737,7 +737,7 @@ Lightning has many tools for debugging. Here is an example of just a few of them
 .. testcode::
 
     # Profile your code to find speed/memory bottlenecks
-    Trainer(profiler=True)
+    Trainer(profiler="simple")
 
 ---------------
 
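After this change the profiler is selected either by name or by passing a profiler instance. A minimal usage sketch of both forms (imports assume the `pytorch_lightning.profiler` module of this release; constructor arguments are left at their defaults):

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler

    # Select a built-in profiler by name ("simple", "advanced", or "pytorch") ...
    trainer = Trainer(profiler="simple")

    # ... or pass a pre-built BaseProfiler subclass instance.
    trainer = Trainer(profiler=AdvancedProfiler())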

pytorch_lightning/trainer/connectors/profiler_connector.py

Lines changed: 9 additions & 15 deletions
@@ -21,35 +21,29 @@
     PyTorchProfiler,
     SimpleProfiler,
 )
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler}
+PROFILERS = {
+    "simple": SimpleProfiler,
+    "advanced": AdvancedProfiler,
+    "pytorch": PyTorchProfiler,
+}
 
 
 class ProfilerConnector:
 
     def __init__(self, trainer):
         self.trainer = trainer
 
-    def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
+    def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
 
-        if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
-            # TODO: Update exception on removal of bool
+        if profiler and not isinstance(profiler, (str, BaseProfiler)):
             raise MisconfigurationException(
-                "Only None, bool, str and subclasses of `BaseProfiler`"
+                "Only None, str and subclasses of `BaseProfiler`"
                 " are valid values for `Trainer`'s `profiler` parameter."
                 f" Received {profiler} which is of type {type(profiler)}."
             )
-
-        if isinstance(profiler, bool):
-            rank_zero_warn(
-                "Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
-                " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
-            )
-            if profiler:
-                profiler = SimpleProfiler()
-        elif isinstance(profiler, str):
+        if isinstance(profiler, str):
             if profiler.lower() in PROFILERS:
                 profiler_class = PROFILERS[profiler.lower()]
                 profiler = profiler_class()
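For illustration only, the string lookup that `on_trainer_init` performs amounts to the standalone sketch below; the helper name `resolve_profiler` is hypothetical and the error message is simplified relative to the MisconfigurationException raised above:

    from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler

    # Mirrors the PROFILERS mapping defined in profiler_connector.py.
    PROFILERS = {
        "simple": SimpleProfiler,
        "advanced": AdvancedProfiler,
        "pytorch": PyTorchProfiler,
    }

    def resolve_profiler(name: str):
        """Instantiate a built-in profiler from its (case-insensitive) name."""
        key = name.lower()
        if key not in PROFILERS:
            raise ValueError(f"Unknown profiler {name!r}; expected one of {sorted(PROFILERS)}")
        return PROFILERS[key]()

    assert isinstance(resolve_profiler("Simple"), SimpleProfiler)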

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 6 deletions
@@ -122,7 +122,7 @@ def __init__(
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
-        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
+        profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: bool = False,
         deterministic: bool = False,
         reload_dataloaders_every_epoch: bool = False,
@@ -177,7 +177,7 @@ def __init__(
 
             checkpoint_callback: If ``True``, enable checkpointing.
                 It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
-                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
+                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
 
                 .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
                     v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
@@ -226,10 +226,9 @@ def __init__(
                 Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
                 a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).
 
-            profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
-                value is deprecated in v1.1 and will be removed in v1.3.
+            profiler: To profile individual steps during training and assist in identifying bottlenecks.
 
-            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
+            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
 
             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
 
@@ -250,7 +249,7 @@ def __init__(
             num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"
 
             num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
-                Set it to `-1` to run all batches in all validation dataloaders. Default: 2
+                Set it to `-1` to run all batches in all validation dataloaders.
 
             reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.
 
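Putting the updated docstring entries together, a sketch of a Trainer call using the parameters documented in this hunk (argument values chosen purely for illustration):

    from pytorch_lightning import Trainer

    trainer = Trainer(
        profiler="pytorch",       # str name or a BaseProfiler instance; bool is no longer accepted
        overfit_batches=0.0,      # float fraction of training data, or int number of batches
        num_sanity_val_steps=-1,  # -1 runs all batches of every validation dataloader in the sanity check
    )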

tests/accelerators/test_cpu.py

Lines changed: 1 addition & 2 deletions
@@ -14,8 +14,7 @@ def test_unsupported_precision_plugins():
     trainer = Mock()
     model = Mock()
     accelerator = CPUAccelerator(
-        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
-        precision_plugin=MixedPrecisionPlugin()
+        training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
     )
     with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
         accelerator.setup(trainer=trainer, model=model)

tests/base/model_optimizers.py

Lines changed: 2 additions & 8 deletions
@@ -45,14 +45,8 @@ def configure_optimizers__multiple_optimizers_frequency(self):
         optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
         optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate)
         return [
-            {
-                'optimizer': optimizer1,
-                'frequency': 1
-            },
-            {
-                'optimizer': optimizer2,
-                'frequency': 5
-            },
+            dict(optimizer=optimizer1, frequency=1),
+            dict(optimizer=optimizer2, frequency=5),
         ]
 
     def configure_optimizers__single_scheduler(self):

tests/base/model_test_steps.py

Lines changed: 2 additions & 6 deletions
@@ -54,9 +54,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
         output = OrderedDict({
             'test_loss': loss_test,
             'test_acc': test_acc,
-            'test_dic': {
-                'test_loss_a': loss_test
-            },
+            'test_dic': dict(test_loss_a=loss_test),
         })
         return output
 
@@ -90,9 +88,7 @@ def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kw
         output = OrderedDict({
             'test_loss': loss_test,
             'test_acc': test_acc,
-            'test_dic': {
-                'test_loss_a': loss_test
-            },
+            'test_dic': dict(test_loss_a=loss_test),
         })
         return output
         if batch_idx % 5 == 0:

tests/base/model_train_steps.py

Lines changed: 2 additions & 6 deletions
@@ -44,12 +44,8 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
 
         output = OrderedDict({
             'loss': loss_train,
-            'progress_bar': {
-                'some_val': log_train * log_train
-            },
-            'log': {
-                'train_some_val': log_train * log_train
-            },
+            'progress_bar': dict(some_val=log_train * log_train),
+            'log': dict(train_some_val=log_train * log_train),
         })
         return output
 

tests/base/model_valid_steps.py

Lines changed: 1 addition & 3 deletions
@@ -43,9 +43,7 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
         output = OrderedDict({
             'val_loss': loss_val,
             'val_acc': val_acc,
-            'test_dic': {
-                'val_loss_a': loss_val
-            },
+            'test_dic': dict(val_loss_a=loss_val),
         })
         return output
 

tests/deprecated_api/test_remove_1-3.py

Lines changed: 0 additions & 35 deletions
@@ -12,15 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Test deprecated functionality which will be removed in vX.Y.Z"""
-from argparse import ArgumentParser
-from unittest import mock
 
 import pytest
 import torch
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
 
 
 def test_v1_3_0_deprecated_arguments(tmpdir):
@@ -111,38 +108,6 @@ def test_v1_3_0_deprecated_metrics():
     )
 
 
-# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
-@pytest.mark.parametrize(['profiler', 'expected'], [
-    (True, SimpleProfiler),
-    (False, PassThroughProfiler),
-])
-def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
-    # remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
-    with pytest.deprecated_call(match='will be removed in v1.3'):
-        trainer = Trainer(profiler=profiler)
-    assert isinstance(trainer.profiler, expected)
-
-
-@pytest.mark.parametrize(
-    ['cli_args', 'expected_parsed_arg', 'expected_profiler'],
-    [
-        ('--profiler', True, SimpleProfiler),
-        ('--profiler True', True, SimpleProfiler),
-        ('--profiler False', False, PassThroughProfiler),
-    ],
-)
-def test_v1_3_0_trainer_cli_profiler(cli_args, expected_parsed_arg, expected_profiler):
-    cli_args = cli_args.split(' ')
-    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
-        parser = ArgumentParser(add_help=False)
-        parser = Trainer.add_argparse_args(parent_parser=parser)
-        args = Trainer.parse_argparser(parser)
-
-    assert getattr(args, "profiler") == expected_parsed_arg
-    trainer = Trainer.from_argparse_args(args)
-    assert isinstance(trainer.profiler, expected_profiler)
-
-
 def test_trainer_enable_pl_optimizer(tmpdir):
     with pytest.deprecated_call(match='will be removed in v1.3'):
         Trainer(enable_pl_optimizer=True)
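With the deprecation path gone, a bool now falls through to the type check in `ProfilerConnector.on_trainer_init` shown earlier. The expected behavior, inferred from that diff rather than taken from a test in this commit, would look roughly like:

    import pytest
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # Passing the previously-deprecated bool value is now rejected outright.
    with pytest.raises(MisconfigurationException, match="valid values for `Trainer`'s `profiler` parameter"):
        Trainer(profiler=True)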

tests/trainer/logging_/test_train_loop_logging_1_0.py

Lines changed: 1 addition & 3 deletions
@@ -178,9 +178,7 @@ def backward(self, loss, optimizer, optimizer_idx):
         assert logged_metrics == expected_logged_metrics
 
         pbar_metrics = set(trainer.progress_bar_metrics.keys())
-        expected_pbar_metrics = {
-            'b',
-        }
+        expected_pbar_metrics = {'b'}
         assert pbar_metrics == expected_pbar_metrics
 
         callback_metrics = set(trainer.callback_metrics.keys())

tests/trainer/test_trainer.py

Lines changed: 2 additions & 4 deletions
@@ -1475,16 +1475,14 @@ def test_trainer_profiler_incorrect_str_arg():
 @pytest.mark.parametrize('profiler', (
     42,
     [42],
-    {
-        "a": 42
-    },
+    dict(a=42),
     torch.tensor(42),
     Trainer(),
 ))
 def test_trainer_profiler_incorrect_arg_type(profiler):
     with pytest.raises(
         MisconfigurationException,
-        match=r"Only None, bool, str and subclasses of `BaseProfiler`"
+        match="Only None, str and subclasses of `BaseProfiler`"
         r" are valid values for `Trainer`'s `profiler` parameter. *"
     ):
         Trainer(profiler=profiler)
