@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -184,9 +184,7 @@ def _setup_common_training_handlers(
         checkpoint_handler = Checkpoint(
             to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
         )
-        trainer.add_event_handler(
-            Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler
-        )  # type: ignore[arg-type]
+        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)
 
     if with_gpu_stats:
         GpuInfo().attach(
@@ -195,7 +193,7 @@ def _setup_common_training_handlers(
 
     if output_names is not None:
 
-        def output_transform(x, index, name):
+        def output_transform(x: Any, index: int, name: str) -> Any:
             if isinstance(x, Mapping):
                 return x[name]
             elif isinstance(x, Sequence):
@@ -216,9 +214,7 @@ def output_transform(x, index, name):
     if with_pbars:
         if with_pbar_on_iters:
             ProgressBar(persist=False).attach(
-                trainer,
-                metric_names="all",
-                event_name=Events.ITERATION_COMPLETED(every=log_every_iters),  # type: ignore[arg-type]
+                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
             )
 
         ProgressBar(persist=True, bar_format="").attach(
@@ -266,18 +262,18 @@ def _setup_common_distrib_training_handlers(
             raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
 
         @trainer.on(Events.EPOCH_STARTED)
-        def distrib_set_epoch(engine):
-            train_sampler.set_epoch(engine.state.epoch - 1)
+        def distrib_set_epoch(engine: Engine) -> None:
+            cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)
 
 
-def empty_cuda_cache(_) -> None:
+def empty_cuda_cache(_: Engine) -> None:
     torch.cuda.empty_cache()
     import gc
 
     gc.collect()
 
 
-def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:
+def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:  # type: ignore
     raise DeprecationWarning(
         "ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0. and will be remove in 0.6.0. "
         "Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
@@ -549,7 +545,7 @@ def setup_trains_logging(
 
 
 def get_default_score_fn(metric_name: str) -> Any:
-    def wrapper(engine: Engine):
+    def wrapper(engine: Engine) -> Any:
         score = engine.state.metrics[metric_name]
         return score
 