diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index 6870e1b0ff6f..6f4321b7812f 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -5,6 +5,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict, namedtuple
+from tempfile import _TemporaryFileWrapper  # type: ignore
 from typing import Callable, Mapping, Optional, Union
 
 import torch
@@ -235,7 +236,7 @@ def score_function(engine):
 
     def __init__(
         self,
-        to_save: Mapping,
+        to_save: Optional[Mapping],
         save_handler: Union[Callable, BaseSaveHandler],
         filename_prefix: str = "",
         score_function: Optional[Callable] = None,
@@ -287,7 +288,7 @@ def __init__(
         self.ext = "pt"
         self.global_step_transform = global_step_transform
         self.filename_pattern = filename_pattern
-        self._saved = []
+        self._saved = []  # type: list
         self.include_self = include_self
 
     @property
@@ -378,10 +379,11 @@ def __call__(self, engine: Engine) -> None:
 
     def _setup_checkpoint(self) -> dict:
         checkpoint = {}
-        for k, obj in self.to_save.items():
-            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-                obj = obj.module
-            checkpoint[k] = obj.state_dict()
+        if self.to_save is not None:
+            for k, obj in self.to_save.items():
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                checkpoint[k] = obj.state_dict()
         return checkpoint
 
     @staticmethod
@@ -572,7 +574,7 @@ def _save_native(self, checkpoint: Mapping, path: str):
         self._save_func(checkpoint, path, torch.save)
 
     def _save_xla(self, checkpoint: Mapping, path: str):
-        import torch_xla.core.xla_model as xm
+        import torch_xla.core.xla_model as xm  # type: ignore
 
         # all tpu procs should enter here as internally performs sync across device
         self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
@@ -582,8 +584,8 @@ def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int =
             func(checkpoint, path, **self.kwargs)
         else:
             tmp_file = None
-            tmp_name = None
-            tmp = None
+            tmp_name = ""
+            tmp = None  # type: _TemporaryFileWrapper
             if rank == 0:
                 tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
                 tmp_file = tmp.file
@@ -728,9 +730,15 @@ def __init__(
     def last_checkpoint(self) -> Union[str, None]:
         if len(self._saved) < 1:
             return None
+
+        if not isinstance(self.save_handler, DiskSaver):
+            raise RuntimeError(
+                "Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler))
+            )
+
         return os.path.join(self.save_handler.dirname, self._saved[-1].filename)
 
-    def __call__(self, engine: Engine, to_save: Mapping) -> None:
+    def __call__(self, engine: Engine, to_save: Mapping) -> None:  # type: ignore
         if len(to_save) == 0:
             raise RuntimeError("No objects to checkpoint found.")
 
diff --git a/mypy.ini b/mypy.ini
index 24c1a726254d..2778b2e7490e 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -7,10 +7,6 @@ show_error_codes = True
 
 ignore_errors = True
 
-[mypy-ignite.handlers.*]
-
-ignore_errors = True
-
 [mypy-ignite.engine.*]
 
 ignore_errors = True
diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py
index 6478ff491040..11d044e2aa37 100644
--- a/tests/ignite/handlers/test_checkpoint.py
+++ b/tests/ignite/handlers/test_checkpoint.py
@@ -563,6 +563,20 @@ def _test(ext, require_empty):
     _test(".pt", require_empty=False)
 
 
+def test_model_checkpoint_invalid_save_handler(dirname):
+    h = ModelCheckpoint(dirname, _PREFIX)
+    to_save = {"model": DummyModel()}
+    # Redefine save_handler
+    h.save_handler = lambda x, y: None
+    h(Engine(lambda x, y: None), to_save)
+
+    with pytest.raises(
+        RuntimeError,
+        match=r"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(h.save_handler)),
+    ):
+        h.last_checkpoint
+
+
 def test_disk_saver_atomic(dirname):
     model = DummyModel()