Fix checkpoint callback & Trainer.test(_) issue for TPUs #6654
Changes from all commits
3f6fe20
c2dc663
6541db6
9f6aa40
d47fa9b
312b84e
6a4ee36
38dc8e2
e18dfe4
80f15c1
4c69b62
@@ -17,7 +17,6 @@
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
-import torch.distributed as torch_distrib
 import torch.multiprocessing as mp

 from pytorch_lightning.core.lightning import LightningModule
@@ -109,13 +108,15 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

         # replace trainer save_checkpoint to use `xm.save`
         trainer.save_checkpoint = self.save_checkpoint
-        self.barrier()
+        self.barrier("pre-run-stage")

         results = trainer.run_stage()

         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        self.barrier("end-process")
+
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
         # when training ends on these platforms dump weights to get out of the main process
         if on_colab_kaggle():
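For context, these named barriers map onto `xm.rendezvous`, which blocks every spawned TPU process at a tag until all of them arrive. A minimal sketch outside of Lightning (assuming `torch_xla` is available; the worker body and core count are placeholders):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _worker(index):
    # all processes meet here before the stage starts ("pre-run-stage")
    xm.rendezvous("pre-run-stage")
    # ... run the training / test stage on this process ...
    # and meet again before teardown ("end-process")
    xm.rendezvous("end-process")

if __name__ == "__main__":
    xmp.spawn(_worker, nprocs=8)  # 8 TPU cores; adjust to the actual device
```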
@@ -126,11 +127,11 @@ def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

     def barrier(self, name: Optional[str] = None) -> None:
-        if torch_distrib.is_initialized():
-            rendezvous(f"pl.Trainer.{name}")
+        rendezvous(name)

     def transfer_distrib_spawn_state_on_fit_end(self, results):
-        best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
+        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
+        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
Comment on lines +133 to +134

What happens if there are multiple checkpoint callbacks attached? Should we save once per path? @awaelchli @carmocca This is going to be amplified if people track multiple versions of "best model path" at the same time, as in an example like this. Should this raise an error due to ambiguity?

I think I'd rather use the first.
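A rough sketch of the "use the first" idea from this thread, filtering the trainer's callbacks explicitly; the helper below is hypothetical and not part of this PR:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

def first_best_model_path(trainer):
    # collect every ModelCheckpoint attached to the trainer
    checkpoint_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)]
    # fall back to None when checkpointing is disabled; otherwise take the first callback's path
    return checkpoint_callbacks[0].best_model_path if checkpoint_callbacks else None
```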

         if self.mp_queue is not None:
             rank_zero_warn("cleaning up ddp environment...")
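To illustrate the user-facing case this change guards against, here is a hedged sketch of calling `Trainer.test` on TPU with checkpointing disabled, so `trainer.checkpoint_callback` is `None`; the tiny module, dataloader, and Trainer flags are illustrative for the Lightning API of this era, not taken from the PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def test_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

if __name__ == "__main__":
    model = TinyModule()
    test_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    # checkpoint_callback=False leaves trainer.checkpoint_callback as None;
    # before this fix, transfer_distrib_spawn_state_on_fit_end hit an
    # AttributeError when reading best_model_path from it.
    trainer = pl.Trainer(tpu_cores=8, checkpoint_callback=False)
    trainer.test(model, test_dataloaders=test_data)
```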