Skip to content

Commit 51b10f7

Browse files
carmoccatchaton
andauthored
Refactor PyTorch profiler 4/5 (#6349)
Co-authored-by: thomas chaton <[email protected]>
1 parent 3cf0c31 commit 51b10f7

File tree

11 files changed

+377
-219
lines changed

11 files changed

+377
-219
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5252
- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
5353

5454

55+
- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
56+
57+
5558
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
5659

5760

@@ -72,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7275
- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
7376

7477

78+
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
79+
80+
7581
### Deprecated
7682

7783
- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
@@ -83,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8389
- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
8490

8591

92+
- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
93+
94+
8695
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
8796

8897
[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),

pytorch_lightning/profiler/profilers.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,15 @@ def _prepare_filename(self) -> str:
126126
filename += f"{self._stage}-"
127127
filename += str(self.filename)
128128
if self._local_rank is not None:
129-
filename += f"-{self.local_rank}"
129+
filename += f"-{self._local_rank}"
130130
filename += ".txt"
131131
return filename
132132

133133
def _prepare_streams(self) -> None:
134134
if self._write_stream is not None:
135135
return
136136
if self.filename:
137-
dirpath = self.dirpath or self._log_dir
138-
filepath = os.path.join(dirpath, self._prepare_filename())
137+
filepath = os.path.join(self.dirpath, self._prepare_filename())
139138
fs = get_filesystem(filepath)
140139
file = fs.open(filepath, "a")
141140
self._output_file = file
@@ -175,8 +174,7 @@ def setup(
175174
self._stage = stage
176175
self._local_rank = local_rank
177176
self._log_dir = log_dir
178-
if self.dirpath is None:
179-
self.dirpath = self._log_dir
177+
self.dirpath = self.dirpath or log_dir
180178

181179
def teardown(self, stage: Optional[str] = None) -> None:
182180
"""
@@ -202,8 +200,8 @@ def summary(self) -> str:
202200
raise NotImplementedError
203201

204202
@property
205-
def local_rank(self):
206-
return '0' if self._local_rank is None else self._local_rank
203+
def local_rank(self) -> int:
204+
return 0 if self._local_rank is None else self._local_rank
207205

208206

209207
class PassThroughProfiler(BaseProfiler):

0 commit comments

Comments
 (0)