Skip to content

Commit 5883ef6

Browse files
committed
mypy: Reformatting
1 parent e64fedd commit 5883ef6

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch_tensorrt.dynamo._defaults import (
66
DEBUG,
7+
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
78
MAX_AUX_STREAMS,
89
MIN_BLOCK_SIZE,
910
OPTIMIZATION_LEVEL,
@@ -13,7 +14,6 @@
1314
USE_PYTHON_RUNTIME,
1415
VERSION_COMPATIBLE,
1516
WORKSPACE_SIZE,
16-
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1717
)
1818

1919

py/torch_tensorrt/dynamo/compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch_tensorrt.dynamo import CompilationSettings
1414
from torch_tensorrt.dynamo._defaults import (
1515
DEBUG,
16+
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1617
MAX_AUX_STREAMS,
1718
MIN_BLOCK_SIZE,
1819
OPTIMIZATION_LEVEL,
@@ -22,7 +23,6 @@
2223
USE_PYTHON_RUNTIME,
2324
VERSION_COMPATIBLE,
2425
WORKSPACE_SIZE,
25-
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
2626
)
2727
from torch_tensorrt.dynamo.backend.backends import _compile_module
2828
from torch_tensorrt.dynamo.conversion import convert_module
@@ -63,7 +63,7 @@ def compile(
6363
version_compatible: bool = VERSION_COMPATIBLE,
6464
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
6565
use_python_runtime: bool = USE_PYTHON_RUNTIME,
66-
enable_experimental_decompositions=ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
66+
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
6767
**kwargs: Any,
6868
) -> torch.fx.GraphModule:
6969
if debug:

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
from typing import Any, Callable, Dict, Set
2-
import torch
31
import logging
4-
from torch._decomp import (
5-
register_decomposition,
6-
core_aten_decompositions,
7-
get_decompositions as get_torch_decompositions,
8-
)
2+
from typing import Any, Callable, Dict, Optional, Set
3+
4+
import torch
5+
from torch._decomp import core_aten_decompositions
6+
from torch._decomp import get_decompositions as get_torch_decompositions
7+
from torch._decomp import register_decomposition
98
from torch._ops import OpOverload
109

1110
aten = torch.ops.aten
1211

1312
_core_aten_decompositions: Dict[
14-
OpOverload, Callable
13+
OpOverload, Callable[[Any], Any]
1514
] = core_aten_decompositions()
1615
torch_enabled_decompositions: Set[OpOverload] = {
1716
aten._adaptive_avg_pool2d_backward,
@@ -179,19 +178,19 @@
179178
aten.full,
180179
aten.repeat,
181180
}
182-
torch_disabled_decompositions: Set[OpOverload] = {}
181+
torch_disabled_decompositions: Set[OpOverload] = set()
183182

184183

185184
ENABLED_TORCH_DECOMPOSITIONS: Dict[
186-
OpOverload, Callable
185+
OpOverload, Callable[[Any], Any]
187186
] = get_torch_decompositions(torch_enabled_decompositions)
188-
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable] = {}
187+
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
189188

190189

191190
logger = logging.getLogger(__name__)
192191

193192

194-
def check_decomp_set_invariants():
193+
def check_decomp_set_invariants() -> None:
195194
"""Validates no overlap between enabled and disabled decomposition sets"""
196195
overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
197196

@@ -206,7 +205,9 @@ def check_decomp_set_invariants():
206205
check_decomp_set_invariants()
207206

208207

209-
def register_torch_trt_decomposition(aten_op, registry=None):
208+
def register_torch_trt_decomposition(
209+
aten_op: OpOverload, registry: Optional[Any] = None
210+
) -> Callable[[Any], Any]:
210211
"""Checks if the decomposition already exists in one of the sets
211212
Registers the decomposition via the Torch utility
212213
@@ -235,7 +236,7 @@ def register_torch_trt_decomposition(aten_op, registry=None):
235236
"The custom implementation will take precedence."
236237
)
237238

238-
def register(fn: Callable) -> Callable:
239+
def register(fn: Callable[[Any], Any]) -> Any:
239240
return register_decomposition(aten_op=aten_op, registry=registry)(fn)
240241

241242
return register
@@ -248,7 +249,7 @@ def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
248249
"""
249250

250251
@register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
251-
def inplace_op(*args, **kwargs):
252+
def inplace_op(*args, **kwargs): # type: ignore
252253
out = outplace_op(*args, **kwargs)
253254
return args[0].copy_(out)
254255

@@ -271,17 +272,17 @@ def inplace_op(*args, **kwargs):
271272

272273

273274
@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
274-
def std_replacement(*args, **kwargs) -> torch.Tensor:
275+
def std_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore
275276
return torch.sqrt(torch.var(*args, **kwargs))
276277

277278

278279
@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
279-
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
280+
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore
280281
return torch.reciprocal(torch.sqrt(*args, **kwargs))
281282

282283

283284
@register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
284-
def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
285+
def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # type: ignore
285286
return torch.reshape(x, *args, **kwargs)
286287

287288

@@ -324,9 +325,9 @@ def reciprocal_replacement(
324325

325326
def get_decompositions(
326327
enable_experimental_decompositions: bool = False,
327-
) -> Dict[OpOverload, Callable]:
328+
) -> Dict[OpOverload, Callable[[Any], Any]]:
328329
if enable_experimental_decompositions:
329-
CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable] = {
330+
CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
330331
decomp: _core_aten_decompositions[decomp]
331332
for decomp in _core_aten_decompositions
332333
if decomp not in torch_disabled_decompositions

0 commit comments

Comments
 (0)