Commit 794bce6

Update to fix strict errors
1 parent 81d0401 commit 794bce6

3 files changed, 11 insertions(+), 15 deletions(-)

3 files changed

+11
-15
lines changed
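Note: the "strict errors" in the commit message presumably refer to errors reported by mypy's strict mode, which rejects untyped function definitions. The diffs below add parameter and return annotations, a typing.cast, and a widened event_name type so that the inline # type: ignore[arg-type] workarounds can be dropped. A minimal illustrative sketch of the pattern, with a hypothetical handler name that is not taken from this commit:

# Hypothetical sketch: under `mypy --strict`, a handler needs explicit
# parameter and return annotations, mirroring the signatures added in this commit.
from ignite.engine import Engine


def log_epoch(engine: Engine) -> None:
    print(f"epoch {engine.state.epoch} completed")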

ignite/contrib/engines/common.py

Lines changed: 9 additions & 13 deletions
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -184,9 +184,7 @@ def _setup_common_training_handlers(
         checkpoint_handler = Checkpoint(
             to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
         )
-        trainer.add_event_handler(
-            Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler
-        )  # type: ignore[arg-type]
+        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)
 
     if with_gpu_stats:
         GpuInfo().attach(
@@ -195,7 +193,7 @@ def _setup_common_training_handlers(
 
     if output_names is not None:
 
-        def output_transform(x, index, name):
+        def output_transform(x: Any, index: int, name: str) -> Any:
             if isinstance(x, Mapping):
                 return x[name]
             elif isinstance(x, Sequence):
@@ -216,9 +214,7 @@ def output_transform(x, index, name):
     if with_pbars:
         if with_pbar_on_iters:
             ProgressBar(persist=False).attach(
-                trainer,
-                metric_names="all",
-                event_name=Events.ITERATION_COMPLETED(every=log_every_iters),  # type: ignore[arg-type]
+                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
             )
 
         ProgressBar(persist=True, bar_format="").attach(
@@ -266,18 +262,18 @@ def _setup_common_distrib_training_handlers(
             raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
 
         @trainer.on(Events.EPOCH_STARTED)
-        def distrib_set_epoch(engine):
-            train_sampler.set_epoch(engine.state.epoch - 1)
+        def distrib_set_epoch(engine: Engine) -> None:
+            cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)
 
 
-def empty_cuda_cache(_) -> None:
+def empty_cuda_cache(_: Engine) -> None:
     torch.cuda.empty_cache()
     import gc
 
     gc.collect()
 
 
-def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:
+def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:  # type: ignore
     raise DeprecationWarning(
         "ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0. and will be remove in 0.6.0. "
         "Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
@@ -549,7 +545,7 @@ def setup_trains_logging(
 
 
 def get_default_score_fn(metric_name: str) -> Any:
-    def wrapper(engine: Engine):
+    def wrapper(engine: Engine) -> Any:
         score = engine.state.metrics[metric_name]
         return score
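The cast(DistributedSampler, train_sampler) change above only narrows the type for the checker; typing.cast returns its argument unchanged at runtime. A small illustrative sketch of the same idea (the function and variable names here are hypothetical, not from the diff):

# Illustrative only: cast() performs no runtime conversion, it just tells mypy
# that this particular Sampler is a DistributedSampler, which has set_epoch().
from typing import cast

from torch.utils.data import DistributedSampler, Sampler


def start_epoch(sampler: Sampler, epoch: int) -> None:
    cast(DistributedSampler, sampler).set_epoch(epoch)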

ignite/contrib/engines/tbptt.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def create_supervised_tbptt_trainer(
 
     """
 
-    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
+    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
         loss_list = []
         hidden = None
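The added -> float annotation records that the TBPTT update step returns a scalar loss value. A simplified stand-in with the same shape of signature, not the library's actual implementation:

# Simplified stand-in for an update function returning a scalar loss; the real
# _update in tbptt.py runs truncated backpropagation through time over the batch.
from typing import Sequence

import torch
from ignite.engine import Engine


def _update_stub(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
    loss_list = [float(t.sum()) for t in batch]  # dummy per-chunk "losses"
    return sum(loss_list) / max(len(loss_list), 1)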

ignite/contrib/handlers/tqdm_logger.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def attach(
         engine: Engine,
         metric_names: Optional[str] = None,
         output_transform: Optional[Callable] = None,
-        event_name: Events = Events.ITERATION_COMPLETED,
+        event_name: Union[CallableEventWithFilter, Events] = Events.ITERATION_COMPLETED,
         closing_event_name: Events = Events.EPOCH_COMPLETED,
     ):
         """

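With event_name widened to Union[CallableEventWithFilter, Events], calls that pass a filtered event, such as the ProgressBar attachment in common.py above, now type-check without an ignore comment. A hedged usage sketch (the Engine built here is a trivial placeholder, and every=10 is an arbitrary choice):

# Placeholder trainer; the attach() call below is the pattern the widened
# `event_name` annotation is meant to cover.
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)
ProgressBar(persist=False).attach(
    trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=10)
)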