
Commit fd5cb7f

Authored by tchaton, carmocca, ananthsub, and Borda
Add PyTorch 1.8 Profiler 5/5 (#6618)
* Refactor profilers
* Update PassThrough
* WIP - This is broken and will change
* Update pytorch_lightning/profiler/pytorch.py (Co-authored-by: thomas chaton <[email protected]>)
* resolve tests
* resolve tests
* find output
* try something
* update
* add support for test and predict
* update
* update
* use getattr
* test
* test
* update
* tests
* update
* update
* update
* update
* update
* remove file
* update
* update
* update
* update
* update
* test
* update#
* update
* update tests
* update
* add suport for 1.8
* rename records
* add support for 1.8
* update
* resolve flake8
* resolve test
* Refactor basic profilers
* Fixes
* Unused import
* Introduce setup
* Profile on all ranks. Print to stdout on 0
* Introduce dirpath + filename
* CHANGELOG
* Add tests. Address comments
* add `on_run_stage_setup`
* add on_run_stage_setup function
* update
* add test for RegisterRecordFunction
* update lightnng flow direction
* move variable to private
* remove trace
* Undo code that should be in 3/4
* Multi-stage multi-rank
* 2/5 changes
* Pass stage in __del__
* Remove TODOs
* Describe on_evaluation_end. Add tests
* Typo
* Address comments
* deepcopy tests
* Advanced teardown
* Fix teardown test
* Fix tests
* Minor change
* Update CHANGELOG.md
* Fix test
* Quick fixes
* Fix 6522
* resolve ddp tests
* resolve tests
* resolve some tests
* update tests
* resolve tests
* update
* resolve tests
* resolve some tests
* Missed fixes from 3/5
* Fixes
* resolve some tests
* resolve test for 1.7.1
* Broken refactor
* Missed stage
* Minor changes
* resolve tests
* Update CHANGELOG
* resolve bug
* remove print
* Typo
* Cleanup
* resolve ddp test
* remove barrier
* update profiler
* update
* Smaller model
* update
* resolve tests
* update
* Minor changes. CHANGELOG
* Minimize diff
* update to 1.8.1
* RunIf. Extra code. Check segfault
* resolve tests
* Typo. Bad merge
* Fixing a bad merge
* replace for kineto
* Update pytorch_lightning/profiler/pytorch.py (Co-authored-by: ananthsub <[email protected]>)
* Update pytorch_lightning/profiler/pytorch.py (Co-authored-by: ananthsub <[email protected]>)
* Minor changes
* Bad merge
* Use lists for flexibility
* Use sets
* predict_step
* Ananth's suggestion
* update
* Docs
* Update pl_examples/basic_examples/profiler_example.py (Co-authored-by: Jirka Borovec <[email protected]>)
* update example
* update example

Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 51b10f7 commit fd5cb7f

File tree

12 files changed: +399 -96 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -55,6 +55,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


+- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618))
+
+
 - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))

pl_examples/basic_examples/profiler_example.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will generate 2 traces: one for `training_step` and one for `validation_step`.
The traces can be visualized in 2 ways:
* With Chrome:
    1. Open Chrome and copy/paste this url: `chrome://tracing/`.
    2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces.
* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin)
    1. pip install tensorboard torch-tb-profiler
    2. tensorboard --logdir={FOLDER}
"""

import sys
from argparse import ArgumentParser

import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as T

from pl_examples import cli_lightning_logo
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

DEFAULT_CMD_LINE = (
    "--max_epochs",
    "1",
    "--limit_train_batches",
    "15",
    "--limit_val_batches",
    "15",
    "--profiler",
    "pytorch",
    "--gpus",
    f"{int(torch.cuda.is_available())}",
)


class ModelToProfile(LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)


class CIFAR10DataModule(LightningDataModule):

    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        valset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=0)


def cli_main():

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    cmd_line = None if len(sys.argv) != 1 else DEFAULT_CMD_LINE
    args = parser.parse_args(args=cmd_line)

    model = ModelToProfile(models.resnet50(pretrained=True))
    datamodule = CIFAR10DataModule()
    trainer = Trainer(**vars(args))
    trainer.fit(model, datamodule=datamodule)


if __name__ == '__main__':
    cli_lightning_logo()
    cli_main()
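
A note on usage: the example above selects the profiler via the ``--profiler pytorch`` flag. Configuring it in code should look roughly like the sketch below; the ``dirpath``/``filename`` arguments are the ones this PR series introduces ("Introduce dirpath + filename"), so treat the exact signature as an assumption rather than a guarantee.

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

# Assumed arguments: dirpath/filename come from the "Introduce dirpath + filename"
# change in this PR series. Each rank writes its report to its own file.
profiler = PyTorchProfiler(dirpath="./profiler_logs", filename="perf_logs")
trainer = Trainer(max_epochs=1, limit_train_batches=15, limit_val_batches=15, profiler=profiler)
# trainer.fit(model, datamodule=datamodule)  # as in cli_main() above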

pytorch_lightning/accelerators/accelerator.py

Lines changed: 8 additions & 4 deletions
@@ -448,8 +448,10 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
         .. deprecated::v1.3
             Will be removed in v1.5.0.
         """
-        rank_zero_warn('Accelerator method `connect_training_type_plugin` was deprecated in v1.3.'
-                       ' It will be removed in v1.5.')
+        rank_zero_warn(
+            'Accelerator method `connect_training_type_plugin` was deprecated in v1.3.'
+            ' It will be removed in v1.5.'
+        )
         self.setup_training_type_plugin(plugin, model)

     # todo: remove in v1.5
@@ -459,6 +461,8 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
         .. deprecated::v1.3
             Will be removed in v1.5.0.
         """
-        rank_zero_warn('Accelerator method `connect_precision_plugin` was deprecated in v1.3.'
-                       ' It will be removed in v1.5.')
+        rank_zero_warn(
+            'Accelerator method `connect_precision_plugin` was deprecated in v1.3.'
+            ' It will be removed in v1.5.'
+        )
         self.setup_precision_plugin(plugin)

pytorch_lightning/profiler/__init__.py

Lines changed: 13 additions & 12 deletions
@@ -121,7 +121,8 @@ def custom_processing_step(self, data):
 Autograd includes a profiler that lets you inspect the cost of different operators
 inside your model - both on the CPU and GPU.

-Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html)
+To read more about the PyTorch Profiler and all its options,
+have a look at its `docs <https://pytorch.org/docs/master/profiler.html>`__

 .. code-block:: python

@@ -134,16 +135,16 @@ def custom_processing_step(self, data):


 This profiler works with PyTorch ``DistributedDataParallel``.
-If ``output_filename`` is provided, each rank will save their profiled operation to their own file.
+If ``filename`` is provided, each rank will save their profiled operations to their own file. The profiler
+report can be quite long, so setting a ``filename`` will save the report instead of logging it to the
+output in your terminal. If no filename is given, it will be logged only on rank 0.

+The profiler's results will be printed on the completion of ``{fit,validate,test,predict}``.

-The profiler's results will be printed on the completion of a training `fit()`. This profiler
-report can be quite long, so you can also specify an `output_filename` to save the report instead
-of logging it to the output in your terminal.
-
-This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default.
-The output below shows the profiling for the action `training_step_and_backward`.
-The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions.
+This profiler will record ``training_step_and_backward``, ``training_step``, ``backward``,
+``validation_step``, ``test_step``, and ``predict_step`` by default.
+The output below shows the profiling for the action ``training_step_and_backward``.
+The user can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions.

 .. note:: When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501

@@ -184,13 +185,13 @@ def custom_processing_step(self, data):

 To visualize the profiled operation, you can either:

-* Use::
+Use::

     nvvp trace_name.prof

-* Use::
+Or::

-    python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'
+    python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'

 """

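To make the ``record_functions`` rename above concrete, here is a minimal sketch of extending the profiled scope. The set-based argument comes straight from the updated docstring; the action name ``my_custom_action`` is hypothetical.

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

# Extend the default profiled functions (training_step_and_backward,
# training_step, backward, validation_step, test_step, predict_step)
# with a hypothetical custom action name.
profiler = PyTorchProfiler(record_functions={"my_custom_action"})
trainer = Trainer(profiler=profiler)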

pytorch_lightning/profiler/profilers.py

Lines changed: 2 additions & 2 deletions
@@ -120,14 +120,14 @@ def _rank_zero_info(self, *args, **kwargs) -> None:
         if self._local_rank in (None, 0):
             log.info(*args, **kwargs)

-    def _prepare_filename(self) -> str:
+    def _prepare_filename(self, extension: str = ".txt") -> str:
         filename = ""
         if self._stage is not None:
             filename += f"{self._stage}-"
         filename += str(self.filename)
         if self._local_rank is not None:
             filename += f"-{self._local_rank}"
-        filename += ".txt"
+        filename += extension
         return filename

     def _prepare_streams(self) -> None:
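
For illustration, the naming scheme that `_prepare_filename` produces after this change can be sketched as a standalone function; the stage, filename, and rank values below are hypothetical.

def prepare_filename(stage, filename, local_rank, extension=".txt"):
    # Mirrors the method above: "<stage>-<filename>-<rank><extension>",
    # where the stage and rank parts are skipped when they are None.
    out = ""
    if stage is not None:
        out += f"{stage}-"
    out += str(filename)
    if local_rank is not None:
        out += f"-{local_rank}"
    out += extension
    return out

# e.g. the fit-stage report for rank 1, and a rank-less report with a custom extension:
assert prepare_filename("fit", "perf_logs", 1) == "fit-perf_logs-1.txt"
assert prepare_filename(None, "perf_logs", None, ".json") == "perf_logs.json"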
