3
3
import logging
4
4
import random
5
5
import warnings
6
- from typing import Any , Callable , Optional , Tuple , Type , Union , cast
6
+ from typing import Any , Callable , Optional , Tuple , Type , TypeVar , Union , cast
7
7
8
8
import torch
9
9
@@ -164,8 +164,11 @@ def manual_seed(seed: int) -> None:
164
164
pass
165
165
166
166
167
- def deprecated (deprecated_in , removed_in = None , reasons = [], raiseWarning = False ):
168
- def decorator (func ):
167
+ def deprecated (deprecated_in : str , removed_in : str = "" , reasons : list = [], raiseWarning : bool = False ) -> Callable :
168
+
169
+ F = TypeVar ("F" , bound = Callable [..., Any ])
170
+
171
+ def decorator (func : F ) -> F :
169
172
func_doc = func .__doc__ if func .__doc__ else ""
170
173
deprecation_warning = (
171
174
f"This function has been deprecated since version `{ deprecated_in } `"
@@ -174,7 +177,7 @@ def decorator(func):
174
177
)
175
178
176
179
@functools .wraps (func )
177
- def wrapper (* args , ** kwargs ) :
180
+ def wrapper (* args : int , ** kwargs : float ) -> Callable :
178
181
if raiseWarning :
179
182
warnings .warn (deprecation_warning , DeprecationWarning , stacklevel = 2 )
180
183
return func (* args , ** kwargs )
@@ -184,6 +187,6 @@ def wrapper(*args, **kwargs):
184
187
for reason in reasons :
185
188
appended_doc += "\n \t - " + reason
186
189
wrapper .__doc__ = "**Deprecated function**." + "\n \n " + func_doc + appended_doc
187
- return wrapper
190
+ return cast ( F , wrapper )
188
191
189
192
return decorator
0 commit comments