1
- from typing import Any , Callable , Dict , Set
2
- import torch
3
1
import logging
4
- from torch ._decomp import (
5
- register_decomposition ,
6
- core_aten_decompositions ,
7
- get_decompositions as get_torch_decompositions ,
8
- )
2
+ from typing import Any , Callable , Dict , Optional , Set
3
+
4
+ import torch
5
+ from torch ._decomp import core_aten_decompositions
6
+ from torch ._decomp import get_decompositions as get_torch_decompositions
7
+ from torch ._decomp import register_decomposition
9
8
from torch ._ops import OpOverload
10
9
11
10
aten = torch .ops .aten
12
11
13
12
_core_aten_decompositions : Dict [
14
- OpOverload , Callable
13
+ OpOverload , Callable [[ Any ], Any ]
15
14
] = core_aten_decompositions ()
16
15
torch_enabled_decompositions : Set [OpOverload ] = {
17
16
aten ._adaptive_avg_pool2d_backward ,
179
178
aten .full ,
180
179
aten .repeat ,
181
180
}
182
- torch_disabled_decompositions : Set [OpOverload ] = {}
181
+ torch_disabled_decompositions : Set [OpOverload ] = set ()
183
182
184
183
185
184
ENABLED_TORCH_DECOMPOSITIONS : Dict [
186
- OpOverload , Callable
185
+ OpOverload , Callable [[ Any ], Any ]
187
186
] = get_torch_decompositions (torch_enabled_decompositions )
188
- TORCH_TRT_DECOMPOSITIONS : Dict [OpOverload , Callable ] = {}
187
+ TORCH_TRT_DECOMPOSITIONS : Dict [OpOverload , Callable [[ Any ], Any ] ] = {}
189
188
190
189
191
190
logger = logging .getLogger (__name__ )
192
191
193
192
194
- def check_decomp_set_invariants ():
193
+ def check_decomp_set_invariants () -> None :
195
194
"""Validates no overlap between enabled and disabled decomposition sets"""
196
195
overlap = torch_enabled_decompositions .intersection (torch_disabled_decompositions )
197
196
@@ -206,7 +205,9 @@ def check_decomp_set_invariants():
206
205
check_decomp_set_invariants ()
207
206
208
207
209
- def register_torch_trt_decomposition (aten_op , registry = None ):
208
+ def register_torch_trt_decomposition (
209
+ aten_op : OpOverload , registry : Optional [Any ] = None
210
+ ) -> Callable [[Any ], Any ]:
210
211
"""Checks if the decomposition already exists in one of the sets
211
212
Registers the decomposition via the Torch utility
212
213
@@ -235,7 +236,7 @@ def register_torch_trt_decomposition(aten_op, registry=None):
235
236
"The custom implementation will take precedence."
236
237
)
237
238
238
- def register (fn : Callable ) -> Callable :
239
+ def register (fn : Callable [[ Any ], Any ] ) -> Any :
239
240
return register_decomposition (aten_op = aten_op , registry = registry )(fn )
240
241
241
242
return register
@@ -248,7 +249,7 @@ def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
248
249
"""
249
250
250
251
@register_torch_trt_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
251
- def inplace_op (* args , ** kwargs ):
252
+ def inplace_op (* args , ** kwargs ): # type: ignore
252
253
out = outplace_op (* args , ** kwargs )
253
254
return args [0 ].copy_ (out )
254
255
@@ -271,17 +272,17 @@ def inplace_op(*args, **kwargs):
271
272
272
273
273
274
@register_torch_trt_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
274
- def std_replacement (* args , ** kwargs ) -> torch .Tensor :
275
+ def std_replacement (* args , ** kwargs ) -> torch .Tensor : # type: ignore
275
276
return torch .sqrt (torch .var (* args , ** kwargs ))
276
277
277
278
278
279
@register_torch_trt_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
279
- def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor :
280
+ def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor : # type: ignore
280
281
return torch .reciprocal (torch .sqrt (* args , ** kwargs ))
281
282
282
283
283
284
@register_torch_trt_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
284
- def unsafe_view_replacement (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor :
285
+ def unsafe_view_replacement (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor : # type: ignore
285
286
return torch .reshape (x , * args , ** kwargs )
286
287
287
288
@@ -324,9 +325,9 @@ def reciprocal_replacement(
324
325
325
326
def get_decompositions (
326
327
enable_experimental_decompositions : bool = False ,
327
- ) -> Dict [OpOverload , Callable ]:
328
+ ) -> Dict [OpOverload , Callable [[ Any ], Any ] ]:
328
329
if enable_experimental_decompositions :
329
- CORE_ATEN_DECOMPOSITIONS_FILTERED : Dict [OpOverload , Callable ] = {
330
+ CORE_ATEN_DECOMPOSITIONS_FILTERED : Dict [OpOverload , Callable [[ Any ], Any ] ] = {
330
331
decomp : _core_aten_decompositions [decomp ]
331
332
for decomp in _core_aten_decompositions
332
333
if decomp not in torch_disabled_decompositions
0 commit comments