
Commit ce93d8b

Authored by vatch123, awaelchli, and carmocca
Handle errors due to uninitialized parameters (#7642)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent cca0e75 commit ce93d8b

File tree

3 files changed: +59 −3 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895))
 
 
+- Added support for `torch.nn.UninitializedParameter` in `ModelSummary` ([#7642](https://github.com/PyTorchLightning/pytorch-lightning/pull/7642))
+
+
 ### Changed
 
 
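
For context, a minimal sketch (not part of this commit) of what an uninitialized parameter is and why it needs special handling when counting parameters; the `nn.LazyLinear` layer and the shapes below are arbitrary examples and assume torch >= 1.8, where lazy modules were introduced:

# Sketch only: lazy layers hold UninitializedParameter until the first forward pass.
import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter

layer = nn.LazyLinear(5)                 # out_features known; in_features inferred later
assert isinstance(layer.weight, UninitializedParameter)
# An UninitializedParameter has no shape yet, so accessing .shape or .numel()
# raises an error -- which is what broke the naive parameter counting in ModelSummary.

layer(torch.randn(2, 3))                 # the first forward pass materializes the weight
assert layer.weight.shape == (5, 3)      # now a regular Parameter with a concrete shape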

pytorch_lightning/core/memory.py

Lines changed: 22 additions & 3 deletions
@@ -21,9 +21,14 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch import Tensor
 from torch.utils.hooks import RemovableHandle
 
 from pytorch_lightning.utilities import AMPType, DeviceType
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
 
 PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
 UNKNOWN_SIZE = "?"
@@ -118,7 +123,7 @@ def layer_type(self) -> str:
     @property
     def num_parameters(self) -> int:
         """ Returns the number of parameters in this module. """
-        return sum(np.prod(p.shape) for p in self._module.parameters())
+        return sum(np.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
 
 
 class ModelSummary(object):
@@ -225,11 +230,13 @@ def param_nums(self) -> List[int]:
 
     @property
     def total_parameters(self) -> int:
-        return sum(p.numel() for p in self._model.parameters())
+        return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
 
     @property
     def trainable_parameters(self) -> int:
-        return sum(p.numel() for p in self._model.parameters() if p.requires_grad)
+        return sum(
+            p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
+        )
 
     @property
     def model_size(self) -> float:
@@ -438,3 +445,15 @@ def get_human_readable_count(number: int) -> str:
         return f"{int(number):,d} {labels[index]}"
 
     return f"{number:,.1f} {labels[index]}"
+
+
+def _is_lazy_weight_tensor(p: Tensor) -> bool:
+    if _TORCH_GREATER_EQUAL_1_8:
+        from torch.nn.parameter import UninitializedParameter
+        if isinstance(p, UninitializedParameter):
+            warning_cache.warn(
+                "A layer with UninitializedParameter was found. "
+                "Thus, the total number of parameters detected may be inaccurate."
+            )
+            return True
+    return False
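
A brief usage sketch (not from the commit) of how the guarded counting above behaves; `MixedModel` and its layer sizes are made up for illustration:

import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary

class MixedModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(3, 5)   # 3*5 + 5 = 20 parameters, counted normally
        self.lazy = nn.LazyLinear(2)   # holds UninitializedParameter until the first forward

    def forward(self, x):
        return self.lazy(self.dense(x))

summary = ModelSummary(MixedModel())
# _is_lazy_weight_tensor() makes the lazy weights count as 0 and emits the UserWarning
# once via warning_cache, so the totals only reflect initialized parameters
# (20 here on torch >= 1.9; on 1.8 the LazyLinear bias is already initialized,
# see the test below, which adds 2 more).
print(summary.total_parameters)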

tests/core/test_memory.py

Lines changed: 34 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_9
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import ParityModuleRNN
@@ -101,6 +102,18 @@ def forward(self, x):
         return self.layer2(self.layer1(x))
 
 
+class LazyModel(LightningModule):
+    """ A model which contains lazy layers with uninitialized parameters. """
+
+    def __init__(self):
+        super().__init__()
+        self.layer1 = nn.LazyLinear(5)
+        self.layer2 = nn.LazyLinear(2)
+
+    def forward(self, inp):
+        return self.layer2(self.layer1(inp))
+
+
 def test_invalid_weights_summmary():
     """ Test that invalid value for weights_summary raises an error. """
     with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -302,3 +315,24 @@ def test_model_size_precision(tmpdir):
     trainer.fit(model)
     summary = model.summarize()
     assert model.pre_calculated_model_size == summary.model_size
+
+
+@RunIf(min_torch="1.8")
+def test_lazy_model_summary():
+    """ Test that the model summary can work with lazy layers. """
+    lazy_model = LazyModel()
+    summary = ModelSummary(lazy_model)
+
+    with pytest.warns(
+        UserWarning,
+        match=r"A layer with UninitializedParameter was found. "
+        r"Thus, the total number of parameters detected may be inaccurate."
+    ):
+        if _TORCH_GREATER_EQUAL_1_9:
+            assert summary.total_parameters == 0
+            assert summary.trainable_parameters == 0
+        else:
+            # bug in 1.8: the bias of a LazyLinear layer is initialized!
+            # https://github.com/pytorch/pytorch/issues/58350
+            assert summary.total_parameters == 7
+            assert summary.trainable_parameters == 7
