
Commit db77223

shuyingsunshine21, carmocca, and ananthsub authored and committed
[sharded plugin] Fix check for fp16 precision (#7825)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: ananthsub <[email protected]>
(cherry picked from commit ca89a7f)
1 parent c585913 commit db77223

File tree

2 files changed: +5 -1 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -208,6 +208,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))


+- Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))
+
+
 ## [1.3.2] - 2021-05-18

 ### Changed
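The changelog entry added above reflects the underlying issue: when 16-bit precision is requested together with native automatic mixed precision, the trainer reports its precision as the string "mixed" rather than the integer 16, so a strict equality check against 16 silently fails. A minimal sketch of that failure mode, using a hypothetical resolve_precision helper (not part of the Lightning API):

    # Hypothetical helper that mimics how a requested precision of 16 can be
    # reported back as the string "mixed" when native AMP is in use.
    def resolve_precision(requested, use_native_amp):
        return "mixed" if (requested == 16 and use_native_amp) else requested

    reported = resolve_precision(16, use_native_amp=True)

    # The old strict equality check misses the AMP case ...
    assert (reported == 16) is False
    # ... while a membership test accepts both representations of 16-bit training.
    assert reported in ("mixed", 16)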

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,8 @@ def _reinit_optimizers_with_oss(self):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                is_fp16 = self.lightning_module.trainer.precision == 16
+                precision = self.lightning_module.trainer.precision
+                is_fp16 = precision in ("mixed", 16)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
                 # the model performance.
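Taken out of the plugin, the intent of the new membership test can be shown with a small standalone helper. This is a sketch under the assumption that the only decision being made is whether sharded state should be compressed to fp16 before broadcasting; should_broadcast_fp16 is a hypothetical name, not Lightning or FairScale code:

    # Hypothetical standalone version of the fixed check: decide whether the
    # sharded optimizer should broadcast model shards compressed to fp16.
    def should_broadcast_fp16(precision, fp16_broadcast_available):
        # Both the integer 16 and the string "mixed" (native AMP) mean 16-bit training.
        return fp16_broadcast_available and precision in ("mixed", 16)

    # With PyTorch AMP the trainer reports "mixed"; the previous equality check
    # (precision == 16) returned False here and skipped the fp16 compression.
    assert should_broadcast_fp16("mixed", True)
    assert should_broadcast_fp16(16, True)
    assert not should_broadcast_fp16(32, True)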

Comments (0)