
Commit a1829bc

kaushikb11 authored and Borda committed
Fix checkpoint callback & Trainer.test(_) issue for TPUs (#6654)
* Fix checkpoint callback issue for TPUs
* update changelog
* add barrier
* apply code suggestions
* update trainer test
* remove spaces
* fix tpu tests
* Apply suggestions from code review
* add comment

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 014a6b7 commit a1829bc

4 files changed: +27 / -10 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
+- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
+- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))


 ## [1.2.5] - 2021-03-23

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,6 @@
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
-import torch.distributed as torch_distrib
 import torch.multiprocessing as mp

 from pytorch_lightning.core.lightning import LightningModule
@@ -96,13 +95,15 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

         # replace trainer save_checkpoint to use `xm.save`
         trainer.save_checkpoint = self.save_checkpoint
-        self.barrier()
+        self.barrier("pre-run-stage")

         results = trainer.train_or_test_or_predict()

         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        self.barrier("end-process")
+
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
         # when training ends on these platforms dump weights to get out of the main process
         if on_colab_kaggle():
@@ -113,12 +114,11 @@ def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

     def barrier(self, name: Optional[str] = None) -> None:
-        if torch_distrib.is_initialized():
-            rendezvous(f"pl.Trainer.{name}")
+        rendezvous(name)

     def transfer_distrib_spawn_state_on_fit_end(self, results):
-        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
-        best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
+        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
+        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

         if self.mp_queue is not None:
             rank_zero_warn("cleaning up ddp environment...")

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 2 deletions
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.enums import LightningEnum
@@ -942,7 +942,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             )
             return {}

-        self.training_type_plugin.barrier()
+        # only one process running at this point for TPUs, as spawn isn't triggered yet
+        if not self._device_type == DeviceType.TPU:
+            self.training_type_plugin.barrier()

         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
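
The `DeviceType.TPU` guard is the `trainer.test` freeze fix: on TPUs, `__test_using_best_weights` runs in the single parent process before `xmp.spawn` has created the workers, so a collective barrier at that point waits for peers that never arrive. From the user's side, the previously hanging call looks like the hedged sketch below (`BoringModel` and the paths are placeholders; a TPU runtime is assumed):

    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel  # import path is an assumption

    model = BoringModel()
    trainer = Trainer(tpu_cores=8, max_epochs=1, default_root_dir="/tmp/tpu-run")
    trainer.fit(model)
    # Loads the best checkpoint before testing; on TPU this path previously hit
    # training_type_plugin.barrier() in the lone pre-spawn process and froze.
    trainer.test()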

tests/models/test_tpu.py

Lines changed: 15 additions & 2 deletions
@@ -349,13 +349,14 @@ def test_reduce(rank):
     xmp.spawn(test_reduce, nprocs=8, start_method='fork')


-@pytest.mark.parametrize("clip_val", [0, 10])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
+@pytest.mark.parametrize("clip_val", [10])
 @mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
 def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
     """
     Ensure that clip gradients is only called if the value is greater than 0.
+    TODO: Fix (test fails with parametrize)
     """
     tutils.reset_seed()
     trainer_options = dict(
@@ -375,3 +376,15 @@ def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
         mock_clip_grad_norm.assert_called()
     else:
         mock_clip_grad_norm.assert_not_called()
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_test_works_with_checkpoint_false(tmpdir):
+    """Ensure that model trains properly when `checkpoint_callback` is set to False."""
+
+    # Train a model on TPU
+    model = BoringModel()
+    trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
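
Two notes on the test changes. `@RunIf(tpu=True)` replaces the explicit skipif marker and, as used here, expresses the same condition as the removed decorator; a rough stand-alone equivalent is sketched below (the `_TPU_AVAILABLE` import path is an assumption). The new `test_if_test_works_with_checkpoint_false` exercises exactly the `checkpoint_callback=False` scenario fixed in `tpu_spawn.py`.

    import pytest
    from pytorch_lightning.utilities import _TPU_AVAILABLE  # import path is an assumption

    # Roughly what @RunIf(tpu=True) stands in for in these tests: skip the test
    # unless a TPU runtime is detected, mirroring the marker removed in this diff.
    @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
    def test_requires_tpu():
        assert _TPU_AVAILABLE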
