Commit 28fc8d2

ananthsub, awaelchli, kaushikb11, and pre-commit-ci[bot] authored
Add enable_model_summary flag and deprecate weights_summary (#9699)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaushik B <[email protected]>
1 parent b1e215d commit 28fc8d2

Some content is hidden: large commits have some content hidden by default, so only a subset of the changed files is shown below.

46 files changed: +248 −134 lines

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -181,6 +181,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))


+- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
+
+
 ### Changed

 - Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867))

@@ -344,6 +347,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `pytorch_lightning.core.decorators.parameter_validation` in favor of `pytorch_lightning.utilities.parameter_tying.set_shared_parameters` ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


+- Deprecated passing `weights_summary` to the `Trainer` constructor in favor of adding the `ModelSummary` callback with `max_depth` directly to the list of callbacks ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
+
+
 ### Removed

 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
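Taken together, the two entries above describe a single migration path for user code. A minimal before/after sketch (illustrative only, not part of the diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # Before (deprecated in v1.5, scheduled for removal in v1.7):
    # trainer = Trainer(weights_summary="full")  # nested summary
    # trainer = Trainer(weights_summary=None)    # no summary

    # After:
    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])  # nested summary
    trainer = Trainer(enable_model_summary=False)              # no summary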

benchmarks/test_basic_parity.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
         # as the first run is skipped, no need to run it long
         max_epochs=num_epochs if idx > 0 else 1,
         enable_progress_bar=False,
-        weights_summary=None,
+        enable_model_summary=False,
         gpus=1 if device_type == "cuda" else 0,
         checkpoint_callback=False,
         logger=False,

docs/source/common/debugging.rst

Lines changed: 8 additions & 4 deletions
@@ -95,11 +95,14 @@ Print a summary of your LightningModule
 ---------------------------------------
 Whenever the ``.fit()`` function gets called, the Trainer will print the weights summary for the LightningModule.
 By default it only prints the top-level modules. If you want to show all submodules in your network, use the
-`'full'` option:
+``max_depth`` option:

 .. testcode::

-    trainer = Trainer(weights_summary="full")
+    from pytorch_lightning.callbacks import ModelSummary
+
+    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])
+

 You can also display the intermediate input- and output sizes of all your layers by setting the
 ``example_input_array`` attribute in your LightningModule. It will print a table like this

@@ -115,8 +118,9 @@ You can also display the intermediate input- and output sizes of all your layers
 when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.

 See Also:
-    - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_summary` Trainer argument
-    - :class:`~pytorch_lightning.core.memory.ModelSummary`
+    - :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
+    - :func:`~pytorch_lightning.utilities.model_summary.summarize`
+    - :class:`~pytorch_lightning.utilities.model_summary.ModelSummary`

 ----------------
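To make the intermediate-size table mentioned in those docs concrete, here is a minimal sketch of a LightningModule with ``example_input_array`` set; the module name and layer sizes are invented for illustration:

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
            # With this set, the summary printed at .fit() also reports
            # per-layer input and output sizes.
            self.example_input_array = torch.zeros(1, 28 * 28)

        def forward(self, x):
            return self.net(x)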

docs/source/common/trainer.rst

Lines changed: 24 additions & 0 deletions
@@ -1589,6 +1589,11 @@ Example::
 weights_summary
 ^^^^^^^^^^^^^^^

+.. warning:: `weights_summary` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
+    directly to the Trainer's ``callbacks`` argument instead. To disable the model summary,
+    pass ``enable_model_summary = False`` to the Trainer.
+
+
 .. raw:: html

     <video width="50%" max-width="400px" controls

@@ -1611,6 +1616,25 @@ Options: 'full', 'top', None.
     # don't print a summary
     trainer = Trainer(weights_summary=None)

+
+enable_model_summary
+^^^^^^^^^^^^^^^^^^^^
+
+Whether to enable or disable the model summarization. Defaults to True.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(enable_model_summary=True)
+
+    # disable summarization
+    trainer = Trainer(enable_model_summary=False)
+
+    # enable custom summarization
+    from pytorch_lightning.callbacks import ModelSummary
+
+    trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])
+
 -----

 Trainer class API
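Besides the Trainer flag documented above, the ``summarize`` utility added to the See Also list in the debugging docs can produce a summary outside of ``.fit()``. A small sketch, assuming the v1.5 signature ``summarize(module, max_depth=...)`` and an invented module:

    from torch import nn
    import pytorch_lightning as pl
    from pytorch_lightning.utilities.model_summary import summarize


    class TinyModel(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)


    # Print a full, nested summary without starting a training run.
    print(summarize(TinyModel(), max_depth=-1))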

pl_examples/bug_report_model.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def run():
         limit_val_batches=1,
         num_sanity_val_steps=0,
         max_epochs=1,
-        weights_summary=None,
+        enable_model_summary=False,
     )
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
     trainer.test(model, dataloaders=test_data)

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ def add_arguments_to_parser(self, parser):
         parser.set_defaults(
             {
                 "trainer.max_epochs": 15,
-                "trainer.weights_summary": None,
+                "trainer.enable_model_summary": False,
                 "trainer.num_sanity_val_steps": 0,
             }
         )
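The same per-project default override works in any LightningCLI subclass; a minimal sketch (the class name is invented, and the ``utilities.cli`` import path is assumed for this release line):

    from pytorch_lightning.utilities.cli import LightningCLI


    class MyCLI(LightningCLI):  # hypothetical subclass
        def add_arguments_to_parser(self, parser):
            # Override Trainer defaults for this project, as the example above does.
            parser.set_defaults({"trainer.enable_model_summary": False})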

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 39 additions & 12 deletions
@@ -45,6 +45,7 @@ def on_trainer_init(
         process_position: int,
         default_root_dir: Optional[str],
         weights_save_path: Optional[str],
+        enable_model_summary: bool,
         weights_summary: Optional[str],
         stochastic_weight_avg: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,

@@ -101,7 +102,7 @@ def on_trainer_init(
         self.trainer._progress_bar_callback = None

         # configure the ModelSummary callback
-        self._configure_model_summary_callback(weights_summary)
+        self._configure_model_summary_callback(enable_model_summary, weights_summary)

         # accumulated grads
         self._configure_accumulated_gradients(accumulate_grad_batches)

@@ -159,24 +160,50 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e
         if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
             self.trainer.callbacks.append(ModelCheckpoint())

-    def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
-        if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
+    def _configure_model_summary_callback(
+        self, enable_model_summary: bool, weights_summary: Optional[str] = None
+    ) -> None:
+        if weights_summary is None:
+            rank_zero_deprecation(
+                "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
+                " in v1.7. Please set `Trainer(enable_model_summary=False)` instead."
+            )
+            return
+        if not enable_model_summary:
+            return
+
+        model_summary_cbs = [type(cb) for cb in self.trainer.callbacks if isinstance(cb, ModelSummary)]
+        if model_summary_cbs:
+            rank_zero_info(
+                f"Trainer already configured with model summary callbacks: {model_summary_cbs}."
+                " Skipping setting a default `ModelSummary` callback."
+            )
             return
-        if weights_summary is not None:
+
+        if weights_summary == "top":
+            # special case the default value for weights_summary to preserve backward compatibility
+            max_depth = 1
+        else:
+            rank_zero_deprecation(
+                f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed"
+                " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with"
+                " `max_depth` directly to the Trainer's `callbacks` argument instead."
+            )
             if weights_summary not in ModelSummaryMode.supported_types():
                 raise MisconfigurationException(
                     f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
                     f" but got {weights_summary}",
                 )
             max_depth = ModelSummaryMode.get_max_depth(weights_summary)
-        if self.trainer._progress_bar_callback is not None and isinstance(
-            self.trainer._progress_bar_callback, RichProgressBar
-        ):
-            model_summary = RichModelSummary(max_depth=max_depth)
-        else:
-            model_summary = ModelSummary(max_depth=max_depth)
-        self.trainer.callbacks.append(model_summary)
-        self.trainer.weights_summary = weights_summary
+
+        is_progress_bar_rich = isinstance(self.trainer._progress_bar_callback, RichProgressBar)
+
+        if self.trainer._progress_bar_callback is not None and is_progress_bar_rich:
+            model_summary = RichModelSummary(max_depth=max_depth)
+        else:
+            model_summary = ModelSummary(max_depth=max_depth)
+        self.trainer.callbacks.append(model_summary)
+        self.trainer._weights_summary = weights_summary

     def _configure_swa_callbacks(self):
         if not self.trainer._stochastic_weight_avg:
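The new branches surface to users roughly as follows. A hedged sketch of the observable behavior, assuming `rank_zero_deprecation` emits a warning that `pytest.deprecated_call` can catch, as Lightning's own deprecation tests do:

    import pytest
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # Default path: weights_summary="top" and enable_model_summary=True,
    # so a ModelSummary(max_depth=1) callback is appended automatically.
    trainer = Trainer()
    assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

    # Disabling the flag means no default summary callback is added.
    trainer = Trainer(enable_model_summary=False)

    # The old spelling still works in v1.5 but emits a deprecation warning.
    with pytest.deprecated_call(match="weights_summary"):
        Trainer(weights_summary=None)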

pytorch_lightning/trainer/trainer.py

Lines changed: 23 additions & 0 deletions
@@ -157,6 +157,7 @@ def __init__(
         accelerator: Optional[Union[str, Accelerator]] = None,
         sync_batchnorm: bool = False,
         precision: Union[int, str] = 32,
+        enable_model_summary: bool = True,
         weights_summary: Optional[str] = "top",
         weights_save_path: Optional[str] = None,
         num_sanity_val_steps: int = 2,

@@ -373,8 +374,16 @@ def __init__(
             val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                 use int to check every n steps (batches).

+            enable_model_summary: Whether to enable model summarization by default.
+
             weights_summary: Prints a summary of the weights when training begins.

+                .. deprecated:: v1.5
+                    ``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
+                    To disable the summary, pass ``enable_model_summary = False`` to the Trainer.
+                    To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
+                    directly to the Trainer's ``callbacks`` argument.
+
             weights_save_path: Where to save weights if specified. Will override default_root_dir
                 for checkpoints only. Use this if for whatever reason you need the checkpoints
                 stored in a different place than the logs written in `default_root_dir`.

@@ -467,6 +476,9 @@ def __init__(
         self.tested_ckpt_path: Optional[str] = None
         self.predicted_ckpt_path: Optional[str] = None

+        # todo: remove in v1.7
+        self._weights_summary: Optional[str] = None
+
         # init callbacks
         # Declare attributes to be set in callback_connector on_trainer_init
         self.callback_connector.on_trainer_init(

@@ -478,6 +490,7 @@ def __init__(
             process_position,
             default_root_dir,
             weights_save_path,
+            enable_model_summary,
             weights_summary,
             stochastic_weight_avg,
             max_time,

@@ -2032,6 +2045,16 @@ def _exit_gracefully_on_signal(self) -> None:
         class_name = caller[0].f_locals["self"].__class__.__name__
         raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")

+    @property
+    def weights_summary(self) -> Optional[str]:
+        rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
+        return self._weights_summary
+
+    @weights_summary.setter
+    def weights_summary(self, val: Optional[str]) -> None:
+        rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
+        self._weights_summary = val
+
     """
     Other
     """

tests/accelerators/test_multi_nodes_gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def validation_step(self, batch, batch_idx):
5454
limit_train_batches=1,
5555
limit_val_batches=1,
5656
max_epochs=2,
57-
weights_summary=None,
57+
enable_model_summary=False,
5858
accelerator="ddp",
5959
gpus=1,
6060
num_nodes=2,
@@ -101,7 +101,7 @@ def backward(self, loss, optimizer, optimizer_idx):
101101
limit_val_batches=2,
102102
max_epochs=2,
103103
log_every_n_steps=1,
104-
weights_summary=None,
104+
enable_model_summary=False,
105105
accelerator="ddp",
106106
gpus=1,
107107
num_nodes=2,

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def training_epoch_end(self, outputs) -> None:
5353
limit_val_batches=2,
5454
max_epochs=1,
5555
log_every_n_steps=1,
56-
weights_summary=None,
56+
enable_model_summary=False,
5757
)
5858

5959
assert any(isinstance(c, CB) for c in trainer.callbacks)
@@ -74,7 +74,7 @@ def on_epoch_end(self, trainer, pl_module):
7474
limit_train_batches=2,
7575
limit_val_batches=2,
7676
max_epochs=1,
77-
weights_summary=None,
77+
enable_model_summary=False,
7878
)
7979

8080
trainer.fit(model)

0 commit comments