Skip to content

Commit bebb821

Browse files
committed
Add static typing, fix mypy errors
1 parent e7521f4 commit bebb821

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

ignite/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import random
55
import warnings
6-
from typing import Any, Callable, Optional, Tuple, Type, Union, cast
6+
from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union, cast
77

88
import torch
99

@@ -164,8 +164,11 @@ def manual_seed(seed: int) -> None:
164164
pass
165165

166166

167-
def deprecated(deprecated_in, removed_in=None, reasons=[], raiseWarning=False):
168-
def decorator(func):
167+
def deprecated(deprecated_in: str, removed_in: str = "", reasons: list = [], raiseWarning: bool = False) -> Callable:
168+
169+
F = TypeVar("F", bound=Callable[..., Any])
170+
171+
def decorator(func: F) -> F:
169172
func_doc = func.__doc__ if func.__doc__ else ""
170173
deprecation_warning = (
171174
f"This function has been deprecated since version `{deprecated_in}`"
@@ -174,7 +177,7 @@ def decorator(func):
174177
)
175178

176179
@functools.wraps(func)
177-
def wrapper(*args, **kwargs):
180+
def wrapper(*args: int, **kwargs: float) -> Callable:
178181
if raiseWarning:
179182
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2)
180183
return func(*args, **kwargs)
@@ -184,6 +187,6 @@ def wrapper(*args, **kwargs):
184187
for reason in reasons:
185188
appended_doc += "\n\t- " + reason
186189
wrapper.__doc__ = "**Deprecated function**." + "\n\n " + func_doc + appended_doc
187-
return wrapper
190+
return cast(F, wrapper)
188191

189192
return decorator

0 commit comments

Comments
 (0)