28 changes: 18 additions & 10 deletions ignite/handlers/checkpoint.py
@@ -5,6 +5,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict, namedtuple
+from tempfile import _TemporaryFileWrapper  # type: ignore
 from typing import Callable, Mapping, Optional, Union
 
 import torch
@@ -235,7 +236,7 @@ def score_function(engine):
 
     def __init__(
         self,
-        to_save: Mapping,
+        to_save: Optional[Mapping],
         save_handler: Union[Callable, BaseSaveHandler],
        filename_prefix: str = "",
        score_function: Optional[Callable] = None,
@@ -287,7 +288,7 @@ def __init__(
         self.ext = "pt"
         self.global_step_transform = global_step_transform
         self.filename_pattern = filename_pattern
-        self._saved = []
+        self._saved = []  # type: list
         self.include_self = include_self
 
     @property
@@ -378,10 +379,11 @@ def __call__(self, engine: Engine) -> None:
 
     def _setup_checkpoint(self) -> dict:
         checkpoint = {}
-        for k, obj in self.to_save.items():
-            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-                obj = obj.module
-            checkpoint[k] = obj.state_dict()
+        if self.to_save is not None:
+            for k, obj in self.to_save.items():
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                checkpoint[k] = obj.state_dict()
         return checkpoint
 
     @staticmethod
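
For context, a standalone sketch (not part of this diff) of why _setup_checkpoint unwraps DataParallel / DistributedDataParallel before calling state_dict(): the wrapper registers the real model under a submodule named "module", so its state_dict keys carry a "module." prefix that a bare model cannot load directly.

# Sketch only; runs on CPU because DataParallel falls back to wrapping the
# module directly when no CUDA devices are available.
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(list(model.state_dict())[:1])           # ['weight']
print(list(wrapped.state_dict())[:1])         # ['module.weight']
print(list(wrapped.module.state_dict())[:1])  # ['weight'] -- what gets saved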
@@ -572,7 +574,7 @@ def _save_native(self, checkpoint: Mapping, path: str):
         self._save_func(checkpoint, path, torch.save)
 
     def _save_xla(self, checkpoint: Mapping, path: str):
-        import torch_xla.core.xla_model as xm
+        import torch_xla.core.xla_model as xm  # type: ignore
 
         # all tpu procs should enter here as internally performs sync across device
         self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
@@ -582,8 +584,8 @@ def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int =
             func(checkpoint, path, **self.kwargs)
         else:
             tmp_file = None
-            tmp_name = None
-            tmp = None
+            tmp_name = ""
+            tmp = None  # type: _TemporaryFileWrapper
             if rank == 0:
                 tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
                 tmp_file = tmp.file
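
For context, a standalone sketch (not part of this diff) of the atomic-write pattern _save_func implements for a single process: serialize into a NamedTemporaryFile created in the destination directory, then rename it over the final path so a crash cannot leave a partially written checkpoint behind. atomic_torch_save is a hypothetical helper name, not an ignite API.

import os
import tempfile

import torch


def atomic_torch_save(checkpoint: dict, path: str) -> None:
    # Write to a temp file next to the target so the rename stays on one filesystem.
    tmp = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(path))
    try:
        torch.save(checkpoint, tmp.file)
        tmp.close()
        os.replace(tmp.name, path)  # atomic rename over the final path
    except BaseException:
        tmp.close()
        os.remove(tmp.name)  # never leave a half-written file behind
        raise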
@@ -728,9 +730,15 @@ def __init__(
     def last_checkpoint(self) -> Union[str, None]:
         if len(self._saved) < 1:
             return None
+
+        if not isinstance(self.save_handler, DiskSaver):
+            raise RuntimeError(
+                "Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler))
+            )
+
         return os.path.join(self.save_handler.dirname, self._saved[-1].filename)
 
-    def __call__(self, engine: Engine, to_save: Mapping) -> None:
+    def __call__(self, engine: Engine, to_save: Mapping) -> None:  # type: ignore
 
         if len(to_save) == 0:
             raise RuntimeError("No objects to checkpoint found.")
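For context, a usage sketch (not part of this diff) of the new last_checkpoint guard; the directory, prefix, and process function below are placeholders, and the behavior mirrors the test added in tests/ignite/handlers/test_checkpoint.py further down.

import torch.nn as nn

from ignite.engine import Engine
from ignite.handlers import ModelCheckpoint

handler = ModelCheckpoint("/tmp/checkpoints", "prefix", require_empty=False)

# Replacing the default DiskSaver with a plain callable still lets saving run...
handler.save_handler = lambda checkpoint, filename: None
handler(Engine(lambda engine, batch: None), {"model": nn.Linear(2, 2)})

# ...but last_checkpoint now raises instead of returning a bogus path, since it
# needs DiskSaver.dirname to build the path of the last saved file.
handler.last_checkpoint  # RuntimeError: ... save_handler should be DiskSaver ...
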
4 changes: 0 additions & 4 deletions mypy.ini
@@ -7,10 +7,6 @@ show_error_codes = True
 
 ignore_errors = True
 
-[mypy-ignite.handlers.*]
-
-ignore_errors = True
-
 [mypy-ignite.engine.*]
 
 ignore_errors = True
14 changes: 14 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
@@ -563,6 +563,20 @@ def _test(ext, require_empty):
     _test(".pt", require_empty=False)
 
 
+def test_model_checkpoint_invalid_save_handler(dirname):
+    h = ModelCheckpoint(dirname, _PREFIX)
+    to_save = {"model": DummyModel()}
+    # Redefine save_handler
+    h.save_handler = lambda x, y: None
+    h(Engine(lambda x, y: None), to_save)
+
+    with pytest.raises(
+        RuntimeError,
+        match=r"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(h.save_handler)),
+    ):
+        h.last_checkpoint
+
+
 def test_disk_saver_atomic(dirname):
 
     model = DummyModel()