1
1
from typing import Callable , Dict , Set
2
2
import torch
3
+ import logging
3
4
from torch ._decomp import (
4
5
register_decomposition ,
5
6
core_aten_decompositions ,
11
12
_core_aten_decompositions : Dict [
12
13
torch ._ops .OpOverload , Callable
13
14
] = core_aten_decompositions ()
14
- enabled_decompositions : Set [torch ._ops .OpOverload ] = {
15
+ torch_enabled_decompositions : Set [torch ._ops .OpOverload ] = {
15
16
aten ._adaptive_avg_pool2d_backward ,
16
17
aten .addcdiv ,
17
18
aten .addcdiv_ ,
171
172
aten .zeros ,
172
173
aten .zeros_like ,
173
174
}
174
- disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
175
+ torch_disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
176
+
175
177
176
178
ENABLED_TORCH_DECOMPOSITIONS : Dict [
177
179
torch ._ops .OpOverload , Callable
178
- ] = get_torch_decompositions (enabled_decompositions )
180
+ ] = get_torch_decompositions (torch_enabled_decompositions )
179
181
TORCH_TRT_DECOMPOSITIONS : Dict [torch ._ops .OpOverload , Callable ] = {}
180
182
181
183
184
+ logger = logging .getLogger (__name__ )
185
+
186
+
187
+ def check_decomp_set_invariants ():
188
+ """Validates no overlap between enabled and disabled decomposition sets"""
189
+ overlap = torch_enabled_decompositions .intersection (torch_disabled_decompositions )
190
+
191
+ if overlap :
192
+ raise AssertionError (
193
+ f"Detected { overlap } registered in both torch_enabled_decompositions "
194
+ "and torch_disabled_decompositions. Ensure all operator(s) are in "
195
+ "at most one of the two sets."
196
+ )
197
+
198
+
199
+ check_decomp_set_invariants ()
200
+
201
+
202
+ def register_torch_trt_decomposition (aten_op , registry = None ):
203
+ """Checks if the decomposition already exists in one of the sets
204
+ Registers the decomposition via the Torch utility
205
+
206
+ Alerts the user if the decomposition already exists, before registering
207
+ Throws an AssertionError if the user attempts to register a decomposition
208
+ which is present in the set of explicitly disabled decompositions
209
+ """
210
+ if aten_op in torch_enabled_decompositions :
211
+ logger .warning (
212
+ f"Detected custom decomposition for { aten_op } , which conflicts "
213
+ "with an existing Torch decomposition in torch_enabled_decompositions. "
214
+ "The custom implementation will take precedence."
215
+ )
216
+ elif aten_op in torch_disabled_decompositions :
217
+ logger .info (
218
+ f"Detected custom decomposition for { aten_op } , which is present "
219
+ "in torch_disabled_decompositions."
220
+ )
221
+
222
+ # Conflicts with _core_aten_decompositions will only occur if
223
+ # enable_experimental_decompositions is True in get_decompositions
224
+ if aten_op in _core_aten_decompositions :
225
+ logger .debug (
226
+ f"Detected custom decomposition for { aten_op } , which conflicts "
227
+ "with an existing Torch decomposition in core_aten_decompositions. "
228
+ "The custom implementation will take precedence."
229
+ )
230
+
231
+ def register (fn : Callable ) -> Callable :
232
+ return register_decomposition (aten_op = aten_op , registry = registry )(fn )
233
+
234
+ return register
235
+
236
+
182
237
def replace_inplace_op (aten_op , outplace_op ):
183
238
"""Replace inplace operation with functional equivalent
184
239
Adapted from:
185
240
https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
186
241
"""
187
242
188
- @register_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
243
+ @register_torch_trt_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
189
244
def inplace_op (* args , ** kwargs ):
190
245
out = outplace_op (* args , ** kwargs )
191
246
return args [0 ].copy_ (out )
@@ -208,34 +263,36 @@ def inplace_op(*args, **kwargs):
208
263
replace_inplace_op (aten .scatter_reduce_ , aten .scatter_reduce )
209
264
210
265
211
- @register_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
266
+ @register_torch_trt_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
212
267
def std_replacement (* args , ** kwargs ) -> torch .Tensor :
213
268
return torch .sqrt (torch .var (* args , ** kwargs ))
214
269
215
270
216
- @register_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
271
+ @register_torch_trt_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
217
272
def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor :
218
273
return torch .reciprocal (torch .sqrt (* args , ** kwargs ))
219
274
220
275
221
- @register_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
276
+ @register_torch_trt_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
222
277
def unsafe_view_replacement (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor :
223
278
return torch .reshape (x , * args , ** kwargs )
224
279
225
280
226
- @register_decomposition (
281
+ @register_torch_trt_decomposition (
227
282
torch .ops .aten .lift_fresh_copy , registry = TORCH_TRT_DECOMPOSITIONS
228
283
)
229
284
def lift_fresh_copy_replacement (x : torch .Tensor ) -> torch .Tensor :
230
285
return x
231
286
232
287
233
- @register_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
288
+ @register_torch_trt_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
234
289
def alias_replacement (x : torch .Tensor ) -> torch .Tensor :
235
290
return x
236
291
237
292
238
- @register_decomposition (torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS )
293
+ @register_torch_trt_decomposition (
294
+ torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS
295
+ )
239
296
def addmm_replacement (
240
297
input_ : torch .Tensor , mat1 : torch .Tensor , mat2 : torch .Tensor , * , beta = 1 , alpha = 1
241
298
) -> torch .Tensor :
@@ -244,7 +301,7 @@ def addmm_replacement(
244
301
)
245
302
246
303
247
- @register_decomposition (
304
+ @register_torch_trt_decomposition (
248
305
torch .ops .aten .reciprocal .default , registry = TORCH_TRT_DECOMPOSITIONS
249
306
)
250
307
def reciprocal_replacement (
@@ -260,17 +317,8 @@ def get_decompositions(
260
317
CORE_ATEN_DECOMPOSITIONS_FILTERED : Dict [torch ._ops .OpOverload , Callable ] = {
261
318
decomp : _core_aten_decompositions [decomp ]
262
319
for decomp in _core_aten_decompositions
263
- if (
264
- decomp not in TORCH_TRT_DECOMPOSITIONS
265
- and decomp not in disabled_decompositions
266
- )
320
+ if decomp not in torch_disabled_decompositions
267
321
}
268
322
return {** CORE_ATEN_DECOMPOSITIONS_FILTERED , ** TORCH_TRT_DECOMPOSITIONS }
269
323
else :
270
- duplicate_registrations = set (ENABLED_TORCH_DECOMPOSITIONS .keys ()).intersection (
271
- set (TORCH_TRT_DECOMPOSITIONS .keys ())
272
- )
273
- assert (
274
- not duplicate_registrations
275
- ), f"Detected duplicate decompositions on: { duplicate_registrations } "
276
324
return {** ENABLED_TORCH_DECOMPOSITIONS , ** TORCH_TRT_DECOMPOSITIONS }
0 commit comments