5
5
import weakref
6
6
from collections import OrderedDict , defaultdict
7
7
from collections .abc import Mapping
8
- from typing import Any , Callable , Iterable , List , Optional , Union
8
+ from typing import Any , Callable , Dict , Iterable , Iterator , List , Optional , Tuple , Union
9
+
10
+ from torch .utils .data import DataLoader
9
11
10
12
from ignite ._utils import _to_hours_mins_secs
11
13
from ignite .base import Serializable
@@ -120,18 +122,18 @@ def compute_mean_std(engine, batch):
120
122
_state_dict_one_of_opt_keys = ("iteration" , "epoch" )
121
123
122
124
def __init__ (self , process_function : Callable ):
123
- self ._event_handlers = defaultdict (list )
125
+ self ._event_handlers = defaultdict (list ) # type: Dict[Any, List]
124
126
self .logger = logging .getLogger (__name__ + "." + self .__class__ .__name__ )
125
127
self ._process_function = process_function
126
- self .last_event_name = None
128
+ self .last_event_name = None # type: Optional[Events]
127
129
self .should_terminate = False
128
130
self .should_terminate_single_epoch = False
129
131
self .state = State ()
130
- self ._state_dict_user_keys = []
131
- self ._allowed_events = []
132
+ self ._state_dict_user_keys = [] # type: List[str]
133
+ self ._allowed_events = [] # type: List[EventEnum]
132
134
133
- self ._dataloader_iter = None
134
- self ._init_iter = []
135
+ self ._dataloader_iter = None # type: Optional[Iterator[Any]]
136
+ self ._init_iter = [] # type: List[int]
135
137
136
138
self .register_events (* Events )
137
139
@@ -232,16 +234,16 @@ def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Cal
232
234
# signature of the following wrapper will be inspected during registering to check if engine is necessary
233
235
# we have to build a wrapper with relevant signature : solution is functools.wraps
234
236
@functools .wraps (handler )
235
- def wrapper (* args , ** kwargs ) -> Any :
237
+ def wrapper (* args : Any , ** kwargs : Any ) -> Any :
236
238
event = self .state .get_event_attrib_value (event_name )
237
239
if event_filter (self , event ):
238
240
return handler (* args , ** kwargs )
239
241
240
242
# setup input handler as parent to make has_event_handler work
241
- wrapper ._parent = weakref .ref (handler )
243
+ wrapper ._parent = weakref .ref (handler ) # type: ignore[attr-defined]
242
244
return wrapper
243
245
244
- def add_event_handler (self , event_name : Any , handler : Callable , * args , ** kwargs ) :
246
+ def add_event_handler (self , event_name : Any , handler : Callable , * args : Any , ** kwargs : Any ) -> RemovableEventHandle :
245
247
"""Add an event handler to be executed when the specified event is fired.
246
248
247
249
Args:
@@ -312,7 +314,7 @@ def execute_something():
312
314
return RemovableEventHandle (event_name , handler , self )
313
315
314
316
@staticmethod
315
- def _assert_non_filtered_event (event_name : Any ):
317
+ def _assert_non_filtered_event (event_name : Any ) -> None :
316
318
if (
317
319
isinstance (event_name , CallableEventWithFilter )
318
320
and event_name .filter != CallableEventWithFilter .default_event_filter
@@ -321,7 +323,7 @@ def _assert_non_filtered_event(event_name: Any):
321
323
"Argument event_name should not be a filtered event, " "please use event without any event filtering"
322
324
)
323
325
324
- def has_event_handler (self , handler : Callable , event_name : Optional [Any ] = None ):
326
+ def has_event_handler (self , handler : Callable , event_name : Optional [Any ] = None ) -> bool :
325
327
"""Check if the specified event has the specified handler.
326
328
327
329
Args:
@@ -332,7 +334,7 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
332
334
if event_name is not None :
333
335
if event_name not in self ._event_handlers :
334
336
return False
335
- events = [event_name ]
337
+ events = [event_name ] # type: Union[List[Any], Dict[Any, List]]
336
338
else :
337
339
events = self ._event_handlers
338
340
for e in events :
@@ -344,10 +346,10 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
344
346
@staticmethod
345
347
def _compare_handlers (user_handler : Callable , registered_handler : Callable ) -> bool :
346
348
if hasattr (registered_handler , "_parent" ):
347
- registered_handler = registered_handler ._parent ()
349
+ registered_handler = registered_handler ._parent () # type: ignore[attr-defined]
348
350
return registered_handler == user_handler
349
351
350
- def remove_event_handler (self , handler : Callable , event_name : Any ):
352
+ def remove_event_handler (self , handler : Callable , event_name : Any ) -> None :
351
353
"""Remove event handler `handler` from registered handlers of the engine
352
354
353
355
Args:
@@ -367,7 +369,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any):
367
369
raise ValueError ("Input handler '{}' is not found among registered event handlers" .format (handler ))
368
370
self ._event_handlers [event_name ] = new_event_handlers
369
371
370
- def on (self , event_name , * args , ** kwargs ) :
372
+ def on (self , event_name : Any , * args : Any , ** kwargs : Any ) -> Callable :
371
373
"""Decorator shortcut for add_event_handler.
372
374
373
375
Args:
@@ -398,7 +400,7 @@ def decorator(f: Callable) -> Callable:
398
400
399
401
return decorator
400
402
401
- def _fire_event (self , event_name : Any , * event_args , ** event_kwargs ) -> None :
403
+ def _fire_event (self , event_name : Any , * event_args : Any , ** event_kwargs : Any ) -> None :
402
404
"""Execute all the handlers associated with given event.
403
405
404
406
This method executes all handlers associated with the event
@@ -460,7 +462,7 @@ def terminate_epoch(self) -> None:
460
462
)
461
463
self .should_terminate_single_epoch = True
462
464
463
- def _handle_exception (self , e : Exception ) -> None :
465
+ def _handle_exception (self , e : BaseException ) -> None :
464
466
if Events .EXCEPTION_RAISED in self ._event_handlers :
465
467
self ._fire_event (Events .EXCEPTION_RAISED , e )
466
468
else :
@@ -497,7 +499,7 @@ def save_engine(_):
497
499
a dictionary containing engine's state
498
500
499
501
"""
500
- keys = self ._state_dict_all_req_keys + (self ._state_dict_one_of_opt_keys [0 ],)
502
+ keys = self ._state_dict_all_req_keys + (self ._state_dict_one_of_opt_keys [0 ],) # type: Tuple[str, ...]
501
503
keys += tuple (self ._state_dict_user_keys )
502
504
return OrderedDict ([(k , getattr (self .state , k )) for k in keys ])
503
505
@@ -555,9 +557,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:
555
557
556
558
@staticmethod
557
559
def _is_done (state : State ) -> bool :
558
- return state .iteration == state .epoch_length * state .max_epochs
560
+ return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
559
561
560
- def set_data (self , data ) :
562
+ def set_data (self , data : Union [ Iterable , DataLoader ]) -> None :
561
563
"""Method to set data. After calling the method the next batch passed to `processing_function` is
562
564
from newly provided data. Please, note that epoch length is not modified.
563
565
@@ -705,21 +707,25 @@ def switch_batch(engine):
705
707
return self ._internal_run ()
706
708
707
709
@staticmethod
708
- def _init_timers (state : State ):
710
+ def _init_timers (state : State ) -> None :
709
711
state .times [Events .EPOCH_COMPLETED .name ] = 0.0
710
712
state .times [Events .COMPLETED .name ] = 0.0
711
713
712
- def _get_data_length (self , data ):
713
- data_length = None
714
+ def _get_data_length (self , data : Iterable ) -> Optional [int ]:
714
715
try :
715
716
if hasattr (data , "__len__" ):
716
- data_length = len (data )
717
+ return len (data ) # type: ignore[arg-type]
717
718
except TypeError :
718
719
# _InfiniteConstantSampler can raise a TypeError on DataLoader length of a IterableDataset
719
720
pass
720
- return data_length
721
+ return None
721
722
722
723
def _setup_engine (self ) -> None :
724
+ if self .state .dataloader is None :
725
+ raise RuntimeError (
726
+ "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
727
+ )
728
+
723
729
iteration = self .state .iteration
724
730
self ._dataloader_iter = iter (self .state .dataloader )
725
731
@@ -734,7 +740,7 @@ def _internal_run(self) -> State:
734
740
try :
735
741
start_time = time .time ()
736
742
self ._fire_event (Events .STARTED )
737
- while self .state .epoch < self .state .max_epochs and not self .should_terminate :
743
+ while self .state .epoch < self .state .max_epochs and not self .should_terminate : # type: ignore[operator]
738
744
self .state .epoch += 1
739
745
self ._fire_event (Events .EPOCH_STARTED )
740
746
@@ -785,6 +791,15 @@ def _run_once_on_dataset(self) -> float:
785
791
iter_counter = self ._init_iter .pop () if len (self ._init_iter ) > 0 else 0
786
792
should_exit = False
787
793
try :
794
+ if self ._dataloader_iter is None :
795
+ raise RuntimeError (
796
+ "Internal error, self._dataloader_iter is None. Please, file an issue if you encounter this error."
797
+ )
798
+ if self .state .dataloader is None :
799
+ raise RuntimeError (
800
+ "Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
801
+ )
802
+
788
803
while True :
789
804
try :
790
805
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -808,7 +823,8 @@ def _run_once_on_dataset(self) -> float:
808
823
"Data iterator can not provide data anymore but required total number of "
809
824
"iterations to run is not reached. "
810
825
"Current iteration: {} vs Total iterations to run : {}" .format (
811
- self .state .iteration , self .state .epoch_length * self .state .max_epochs
826
+ self .state .iteration ,
827
+ self .state .epoch_length * self .state .max_epochs , # type: ignore[operator]
812
828
)
813
829
)
814
830
break
0 commit comments