
Commit 9d31aff

kaushikb11 authored and lexierule committed
Update Gradient Clipping for TPU Accelerator (#6576)
(cherry picked from commit 87c03b1)
1 parent f4a2dff commit 9d31aff

File tree

4 files changed, +46 -1 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
+
 
 ### Fixed
pytorch_lightning/accelerators/tpu.py

Lines changed: 16 additions & 0 deletions
@@ -12,6 +12,9 @@
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    from torch_xla._patched_functions import clip_grad_norm_
+
+    xla_clip_grad_norm_ = clip_grad_norm_
 
 
 class TPUAccelerator(Accelerator):
@@ -44,3 +47,16 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
             return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
+
+    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
+
+        model = self.lightning_module
+        parameters = model.parameters()
+
+        grad_clip_val = float(clip_val)
+        if grad_clip_val <= 0:
+            return
+
+        max_norm = grad_clip_val
+
+        xla_clip_grad_norm_(parameters, max_norm, norm_type)
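
With this override, gradient clipping on TPU is driven entirely by the Trainer's gradient_clip_val argument and routed to torch_xla's patched clip_grad_norm_ rather than the precision plugin. A minimal usage sketch, assuming a TPU runtime with torch_xla installed and pytorch-lightning at this commit; LitModel and RandomDataset below are hypothetical placeholders, not part of this change:

    import torch
    from torch.utils.data import DataLoader, Dataset
    import pytorch_lightning as pl

    class RandomDataset(Dataset):
        # Tiny synthetic dataset so the sketch is self-contained.
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.randn(32)

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    trainer = pl.Trainer(
        tpu_cores=1,
        precision=16,
        max_epochs=1,
        gradient_clip_val=0.5,  # values <= 0 skip clipping in TPUAccelerator.clip_gradients
    )
    trainer.fit(LitModel(), DataLoader(RandomDataset(), batch_size=8))

On other accelerators the same flag is still handled by the precision plugin's clip_gradients, which is why the TPU-specific TODO below can be removed.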

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
 
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
         """Clips the gradients to a specific value"""
-        # TODO: separate TPU case from here
         if clip_val is None:
             return
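
For contrast, the non-TPU path keeps going through the precision plugin. A simplified sketch of norm-based clipping on CPU/GPU using PyTorch's built-in utility; this is an illustration, not the plugin's actual implementation:

    from torch.nn.utils import clip_grad_norm_

    def clip_gradients_generic(parameters, clip_val, norm_type: float = 2.0):
        # Mirror the guards above: clipping is skipped when the value
        # is unset or non-positive.
        if clip_val is None or float(clip_val) <= 0:
            return
        clip_grad_norm_(parameters, max_norm=float(clip_val), norm_type=norm_type)

The TPU override replaces this with the clip_grad_norm_ that torch_xla provides for XLA tensors.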

tests/models/test_tpu.py

Lines changed: 28 additions & 0 deletions
@@ -347,3 +347,31 @@ def test_reduce(rank):
         assert result.item() == 8
 
     xmp.spawn(test_reduce, nprocs=8, start_method='fork')
+
+
+@pytest.mark.parametrize("clip_val", [0, 10])
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
+def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=1,
+        precision=16,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        gradient_clip_val=clip_val,
+    )
+    model = BoringModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+    if clip_val > 0:
+        mock_clip_grad_norm.assert_called()
+    else:
+        mock_clip_grad_norm.assert_not_called()
