
Commit a635897

gruebel, sdesrozis and vfdev-5 authored
Activate mypy in ignite.engine (#1379)
* Activate mypy in ignite.engine
* Fix missing import
* Fix typing issues with nightly build
* Fix PR findings

Co-authored-by: Sylvain Desroziers <[email protected]>
Co-authored-by: vfdev <[email protected]>
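Most of the hunks below follow one of two patterns: add an explicit annotation, or silence a single mypy error with an error-code-scoped "# type: ignore[...]" comment where a clean annotation is impractical. A minimal, hypothetical sketch of both patterns (not taken from this commit):

from typing import Optional


def first_item(items: Optional[list]) -> int:
    if items is None:
        # Explicit narrowing: after this check mypy knows items is a list.
        raise RuntimeError("items must not be None")
    return items[0]


def last_item(items: Optional[list]) -> int:
    # Alternative: suppress only the specific error code on this line.
    return items[-1]  # type: ignore[index]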
1 parent 260017d commit a635897

File tree: 7 files changed (+127, -94 lines)


ignite/engine/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import collections
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
@@ -27,7 +28,7 @@
 
 def _prepare_batch(
     batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
-):
+) -> Tuple[Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], ...]:
     """Prepare batch for training: pass to a device with options.
 
     """

ignite/engine/deterministic.py

Lines changed: 32 additions & 22 deletions
@@ -2,7 +2,7 @@
 import warnings
 from collections import OrderedDict
 from functools import wraps
-from typing import Callable, Generator, Iterator, Optional
+from typing import Any, Callable, Generator, Iterator, List, Optional, cast
 
 import torch
 from torch.utils.data import DataLoader
@@ -61,7 +61,7 @@ def __init__(self, batch_sampler: BatchSampler, start_iteration: Optional[int] =
         if not isinstance(batch_sampler, BatchSampler):
             raise TypeError("Argument batch_sampler should be torch.utils.data.sampler.BatchSampler")
 
-        self.batch_indices = None
+        self.batch_indices = []  # type: List
         self.batch_sampler = batch_sampler
         self.start_iteration = start_iteration
         self.sampler = self.batch_sampler.sampler
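Here batch_indices goes from None to an empty list annotated with a type comment. "# type:" comments are the pre-PEP 526 way to declare a variable's type; mypy reads them exactly like inline annotations, and they keep the code importable on older Python versions. A minimal sketch of the pattern (hypothetical class, not from ignite):

from typing import List, Optional


class BatchTracker:
    def __init__(self) -> None:
        # mypy picks up these attribute types from the type comments.
        self.batch_indices = []  # type: List[int]
        self.last_batch = None  # type: Optional[List[int]]

    def record(self, indices: List[int]) -> None:
        self.batch_indices.extend(indices)
        self.last_batch = indices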
@@ -84,7 +84,7 @@ def __len__(self) -> int:
         return len(self.batch_sampler)
 
 
-def _get_rng_states():
+def _get_rng_states() -> List[Any]:
     output = [random.getstate(), torch.get_rng_state()]
     try:
         import numpy as np
@@ -96,7 +96,7 @@ def _get_rng_states():
     return output
 
 
-def _set_rng_states(rng_states):
+def _set_rng_states(rng_states: List[Any]) -> None:
     random.setstate(rng_states[0])
     torch.set_rng_state(rng_states[1])
     try:
@@ -107,14 +107,14 @@ def _set_rng_states(rng_states):
         pass
 
 
-def _repr_rng_state(rng_states):
+def _repr_rng_state(rng_states: List[Any]) -> str:
     from hashlib import md5
 
     out = " ".join([md5(str(list(s)).encode("utf-8")).hexdigest() for s in rng_states])
     return out
 
 
-def keep_random_state(func: Callable):
+def keep_random_state(func: Callable) -> Callable:
     """Helper decorator to keep random state of torch, numpy and random intact
     while executing a function. For more details on usage, please see :ref:`Dataflow synchronization`.
 
@@ -123,7 +123,7 @@ def keep_random_state(func: Callable):
     """
 
     @wraps(func)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: Any, **kwargs: Any) -> None:
         rng_states = _get_rng_states()
         func(*args, **kwargs)
         _set_rng_states(rng_states)
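Typing a decorator such as keep_random_state mostly amounts to annotating the outer function as Callable -> Callable and the inner wrapper's *args/**kwargs as Any. A rough, generic sketch of the same shape (names are illustrative, not the library's):

from functools import wraps
from typing import Any, Callable


def log_calls(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        # The wrapper is annotated -> None because, like the one above,
        # it deliberately discards func's return value.
        print("calling", func.__name__)
        func(*args, **kwargs)

    return wrapper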
@@ -181,16 +181,20 @@ def state_dict(self) -> OrderedDict:
         return state_dict
 
     def _init_run(self) -> None:
-        seed = torch.randint(0, int(1e9), (1,)).item()
-        self.state.seed = seed
+        self.state.seed = int(torch.randint(0, int(1e9), (1,)).item())
         if not hasattr(self.state, "rng_states"):
-            self.state.rng_states = None
+            self.state.rng_states = None  # type: ignore[attr-defined]
 
         if torch.cuda.is_available():
             torch.backends.cudnn.deterministic = True
             torch.backends.cudnn.benchmark = False
 
     def _setup_engine(self) -> None:
+        if self.state.dataloader is None:
+            raise RuntimeError(
+                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+            )
+
         self._dataloader_len = self._get_data_length(self.state.dataloader)
 
         # if input data is torch dataloader we replace batch sampler by a batch sampler
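State.dataloader is presumably declared as an Optional, so before it can be handed to _get_data_length mypy needs the None case ruled out; the explicit check with a RuntimeError both narrows the type and documents an internal invariant. A condensed sketch of that narrowing idiom (hypothetical types, not ignite's):

from typing import Iterable, Optional


class EngineState:
    def __init__(self, dataloader: Optional[Iterable] = None) -> None:
        self.dataloader = dataloader


def epoch_length(state: EngineState) -> int:
    if state.dataloader is None:
        raise RuntimeError("Internal error, dataloader is None.")
    # From here on mypy treats state.dataloader as Iterable, not Optional[Iterable].
    return len(list(state.dataloader))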
@@ -199,22 +203,24 @@ def _setup_engine(self) -> None:
         # attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
         can_patch_dataloader = True
         if hasattr(self.state.dataloader, "_dataset_kind"):
-            from torch.utils.data.dataloader import _DatasetKind
+            from torch.utils.data.dataloader import _DatasetKind  # type: ignore[attr-defined]
 
-            _dataloader_kind = self.state.dataloader._dataset_kind
+            _dataloader_kind = self.state.dataloader._dataset_kind  # type: ignore[attr-defined]
             can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
         if can_patch_dataloader:
-            if (self._dataloader_len is not None) and hasattr(self.state.dataloader.sampler, "epoch"):
+            if self._dataloader_len is not None and hasattr(
+                self.state.dataloader.sampler, "epoch"  # type: ignore[attr-defined]
+            ):
                 if self._dataloader_len != self.state.epoch_length:
                     warnings.warn(
                         "When defined engine's epoch length is different of input dataloader length, "
                         "distributed sampler indices can not be setup in a reproducible manner"
                     )
 
-            batch_sampler = self.state.dataloader.batch_sampler
+            batch_sampler = self.state.dataloader.batch_sampler  # type: ignore[attr-defined]
             if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
                 self.state.dataloader = update_dataloader(
-                    self.state.dataloader, ReproducibleBatchSampler(batch_sampler)
+                    self.state.dataloader, ReproducibleBatchSampler(batch_sampler)  # type: ignore[arg-type]
                 )
 
         iteration = self.state.iteration
@@ -228,28 +234,32 @@ def _setup_engine(self) -> None:
         # restore rng state if in the middle
         in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
         if (getattr(self.state, "rng_states", None) is not None) and in_the_middle:
-            _set_rng_states(self.state.rng_states)
-            self.state.rng_states = None
+            _set_rng_states(self.state.rng_states)  # type: ignore[attr-defined]
+            self.state.rng_states = None  # type: ignore[attr-defined]
 
     def _from_iteration(self, iteration: int) -> Iterator:
+        if self.state.dataloader is None:
+            raise RuntimeError(
+                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+            )
         data = self.state.dataloader
         if isinstance(data, DataLoader):
             try:
                 # following is unsafe for IterableDatasets
-                iteration %= len(data.batch_sampler)
+                iteration %= len(data.batch_sampler)  # type: ignore[attr-defined, arg-type]
                 # Synchronize dataflow according to state.iteration
                 self._setup_seed()
                 if iteration > 0:
                     # batch sampler is ReproducibleBatchSampler
-                    data.batch_sampler.start_iteration = iteration
+                    data.batch_sampler.start_iteration = iteration  # type: ignore[attr-defined, union-attr]
                 return iter(data)
             except TypeError as e:
                 # Probably we can do nothing with DataLoader built upon IterableDatasets
                 pass
 
         self.logger.info("Resuming from iteration for provided data will fetch data until required iteration ...")
         if hasattr(data, "__len__"):
-            iteration %= len(data)
+            iteration %= len(data)  # type: ignore[arg-type]
         # Synchronize dataflow from the begining
         self._setup_seed(iteration=0)
         data_iter = iter(data)
@@ -263,11 +273,11 @@ def _from_iteration(self, iteration: int) -> Iterator:
 
         return data_iter
 
-    def _setup_seed(self, _=None, iter_counter=None, iteration=None):
+    def _setup_seed(self, _: Any = None, iter_counter: Optional[int] = None, iteration: Optional[int] = None) -> None:
        if iter_counter is None:
            le = self._dataloader_len if self._dataloader_len is not None else 1
        else:
            le = iter_counter
        if iteration is None:
            iteration = self.state.iteration
-        manual_seed(self.state.seed + iteration // le)
+        manual_seed(self.state.seed + iteration // le)  # type: ignore[operator]
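The imports at the top of this file also gain typing.cast, the usual alternative to a "# type: ignore[operator]" when a value such as self.state.seed is typed Optional but known to be set at this point. A small, hypothetical comparison of the two options (not from the commit):

from typing import Optional, cast


def next_seed(seed: Optional[int], iteration: int) -> int:
    # Option 1: assert the runtime invariant and let mypy narrow the type.
    assert seed is not None
    return seed + iteration


def next_seed_cast(seed: Optional[int], iteration: int) -> int:
    # Option 2: cast tells mypy to trust the stated type; it does nothing at runtime.
    return cast(int, seed) + iteration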

ignite/engine/engine.py

Lines changed: 44 additions & 28 deletions
@@ -5,7 +5,9 @@
 import weakref
 from collections import OrderedDict, defaultdict
 from collections.abc import Mapping
-from typing import Any, Callable, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+
+from torch.utils.data import DataLoader
 
 from ignite._utils import _to_hours_mins_secs
 from ignite.base import Serializable
@@ -120,18 +122,18 @@ def compute_mean_std(engine, batch):
     _state_dict_one_of_opt_keys = ("iteration", "epoch")
 
     def __init__(self, process_function: Callable):
-        self._event_handlers = defaultdict(list)
+        self._event_handlers = defaultdict(list)  # type: Dict[Any, List]
         self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
         self._process_function = process_function
-        self.last_event_name = None
+        self.last_event_name = None  # type: Optional[Events]
         self.should_terminate = False
         self.should_terminate_single_epoch = False
         self.state = State()
-        self._state_dict_user_keys = []
-        self._allowed_events = []
+        self._state_dict_user_keys = []  # type: List[str]
+        self._allowed_events = []  # type: List[EventEnum]
 
-        self._dataloader_iter = None
-        self._init_iter = []
+        self._dataloader_iter = None  # type: Optional[Iterator[Any]]
+        self._init_iter = []  # type: List[int]
 
         self.register_events(*Events)
 
@@ -232,16 +234,16 @@ def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Cal
         # signature of the following wrapper will be inspected during registering to check if engine is necessary
         # we have to build a wrapper with relevant signature : solution is functools.wraps
         @functools.wraps(handler)
-        def wrapper(*args, **kwargs) -> Any:
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
             event = self.state.get_event_attrib_value(event_name)
             if event_filter(self, event):
                 return handler(*args, **kwargs)
 
         # setup input handler as parent to make has_event_handler work
-        wrapper._parent = weakref.ref(handler)
+        wrapper._parent = weakref.ref(handler)  # type: ignore[attr-defined]
         return wrapper
 
-    def add_event_handler(self, event_name: Any, handler: Callable, *args, **kwargs):
+    def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> RemovableEventHandle:
         """Add an event handler to be executed when the specified event is fired.
 
         Args:
@@ -312,7 +314,7 @@ def execute_something():
         return RemovableEventHandle(event_name, handler, self)
 
     @staticmethod
-    def _assert_non_filtered_event(event_name: Any):
+    def _assert_non_filtered_event(event_name: Any) -> None:
         if (
             isinstance(event_name, CallableEventWithFilter)
             and event_name.filter != CallableEventWithFilter.default_event_filter
@@ -321,7 +323,7 @@ def _assert_non_filtered_event(event_name: Any):
                 "Argument event_name should not be a filtered event, " "please use event without any event filtering"
             )
 
-    def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None):
+    def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
         """Check if the specified event has the specified handler.
 
         Args:
@@ -332,7 +334,7 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
         if event_name is not None:
             if event_name not in self._event_handlers:
                 return False
-            events = [event_name]
+            events = [event_name]  # type: Union[List[Any], Dict[Any, List]]
         else:
             events = self._event_handlers
         for e in events:
@@ -344,10 +346,10 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
     @staticmethod
     def _compare_handlers(user_handler: Callable, registered_handler: Callable) -> bool:
         if hasattr(registered_handler, "_parent"):
-            registered_handler = registered_handler._parent()
+            registered_handler = registered_handler._parent()  # type: ignore[attr-defined]
         return registered_handler == user_handler
 
-    def remove_event_handler(self, handler: Callable, event_name: Any):
+    def remove_event_handler(self, handler: Callable, event_name: Any) -> None:
         """Remove event handler `handler` from registered handlers of the engine
 
         Args:
@@ -367,7 +369,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any):
             raise ValueError("Input handler '{}' is not found among registered event handlers".format(handler))
         self._event_handlers[event_name] = new_event_handlers
 
-    def on(self, event_name, *args, **kwargs):
+    def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
         """Decorator shortcut for add_event_handler.
 
         Args:
@@ -398,7 +400,7 @@ def decorator(f: Callable) -> Callable:
 
         return decorator
 
-    def _fire_event(self, event_name: Any, *event_args, **event_kwargs) -> None:
+    def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) -> None:
         """Execute all the handlers associated with given event.
 
         This method executes all handlers associated with the event
@@ -460,7 +462,7 @@ def terminate_epoch(self) -> None:
             )
         self.should_terminate_single_epoch = True
 
-    def _handle_exception(self, e: Exception) -> None:
+    def _handle_exception(self, e: BaseException) -> None:
         if Events.EXCEPTION_RAISED in self._event_handlers:
             self._fire_event(Events.EXCEPTION_RAISED, e)
         else:
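Widening the parameter from Exception to BaseException lets the same handler receive things like KeyboardInterrupt, which derives from BaseException but not from Exception. A tiny illustration of the distinction (hypothetical handler, not ignite's API):

def describe(error: BaseException) -> str:
    # KeyboardInterrupt and SystemExit are BaseExceptions but not Exceptions,
    # so a parameter typed Exception would reject them under mypy.
    return type(error).__name__


print(describe(ValueError("bad value")))
print(describe(KeyboardInterrupt()))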
@@ -497,7 +499,7 @@ def save_engine(_):
             a dictionary containing engine's state
 
         """
-        keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
+        keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)  # type: Tuple[str, ...]
         keys += tuple(self._state_dict_user_keys)
         return OrderedDict([(k, getattr(self.state, k)) for k in keys])
 
@@ -555,9 +557,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
 
     @staticmethod
     def _is_done(state: State) -> bool:
-        return state.iteration == state.epoch_length * state.max_epochs
+        return state.iteration == state.epoch_length * state.max_epochs  # type: ignore[operator]
 
-    def set_data(self, data):
+    def set_data(self, data: Union[Iterable, DataLoader]) -> None:
         """Method to set data. After calling the method the next batch passed to `processing_function` is
         from newly provided data. Please, note that epoch length is not modified.
 
@@ -705,21 +707,25 @@ def switch_batch(engine):
         return self._internal_run()
 
     @staticmethod
-    def _init_timers(state: State):
+    def _init_timers(state: State) -> None:
         state.times[Events.EPOCH_COMPLETED.name] = 0.0
         state.times[Events.COMPLETED.name] = 0.0
 
-    def _get_data_length(self, data):
-        data_length = None
+    def _get_data_length(self, data: Iterable) -> Optional[int]:
         try:
             if hasattr(data, "__len__"):
-                data_length = len(data)
+                return len(data)  # type: ignore[arg-type]
         except TypeError:
             # _InfiniteConstantSampler can raise a TypeError on DataLoader length of a IterableDataset
             pass
-        return data_length
+        return None
 
     def _setup_engine(self) -> None:
+        if self.state.dataloader is None:
+            raise RuntimeError(
+                "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+            )
+
         iteration = self.state.iteration
         self._dataloader_iter = iter(self.state.dataloader)
 
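_get_data_length now advertises Optional[int]: an early return replaces the mutable local, and callers are expected to handle the None case (as the epoch-length logic presumably does). A short sketch of producing and consuming such a value (hypothetical helpers, not ignite's):

from typing import Optional, Sized


def data_length(data: object) -> Optional[int]:
    # Mirrors the helper above: return a length when one is known, else None.
    if isinstance(data, Sized):
        return len(data)
    return None


def epoch_length_or_default(data: object, default: int = 1000) -> int:
    length = data_length(data)
    return default if length is None else length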

@@ -734,7 +740,7 @@ def _internal_run(self) -> State:
         try:
             start_time = time.time()
             self._fire_event(Events.STARTED)
-            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
+            while self.state.epoch < self.state.max_epochs and not self.should_terminate:  # type: ignore[operator]
                 self.state.epoch += 1
                 self._fire_event(Events.EPOCH_STARTED)
 
@@ -785,6 +791,15 @@ def _run_once_on_dataset(self) -> float:
         iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
         should_exit = False
         try:
+            if self._dataloader_iter is None:
+                raise RuntimeError(
+                    "Internal error, self._dataloader_iter is None. Please, file an issue if you encounter this error."
+                )
+            if self.state.dataloader is None:
+                raise RuntimeError(
+                    "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
+                )
+
             while True:
                 try:
                     # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -808,7 +823,8 @@ def _run_once_on_dataset(self) -> float:
                         "Data iterator can not provide data anymore but required total number of "
                         "iterations to run is not reached. "
                         "Current iteration: {} vs Total iterations to run : {}".format(
-                            self.state.iteration, self.state.epoch_length * self.state.max_epochs
+                            self.state.iteration,
+                            self.state.epoch_length * self.state.max_epochs,  # type: ignore[operator]
                         )
                     )
                     break