1
1
from typing import Callable , Dict , Set
2
2
import torch
3
+ import logging
3
4
from torch ._decomp import (
4
5
register_decomposition ,
5
6
core_aten_decompositions ,
11
12
_core_aten_decompositions : Dict [
12
13
torch ._ops .OpOverload , Callable
13
14
] = core_aten_decompositions ()
14
- enabled_decompositions : Set [torch ._ops .OpOverload ] = {
15
+ torch_enabled_decompositions : Set [torch ._ops .OpOverload ] = {
15
16
aten ._adaptive_avg_pool2d_backward ,
16
17
aten .addcdiv ,
17
18
aten .addcdiv_ ,
171
172
aten .zeros ,
172
173
aten .zeros_like ,
173
174
}
174
- disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
175
+ torch_disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
176
+
177
+ assert torch_enabled_decompositions .intersection (torch_disabled_decompositions ) == 0
175
178
176
179
ENABLED_TORCH_DECOMPOSITIONS : Dict [
177
180
torch ._ops .OpOverload , Callable
178
- ] = get_torch_decompositions (enabled_decompositions )
181
+ ] = get_torch_decompositions (torch_enabled_decompositions )
179
182
TORCH_TRT_DECOMPOSITIONS : Dict [torch ._ops .OpOverload , Callable ] = {}
180
183
181
184
185
+ logger = logging .getLogger (__name__ )
186
+
187
+
188
+ def check_decomp_set_invariants ():
189
+ """Validates no overlap between enabled and disabled decomposition sets"""
190
+ overlap = torch_enabled_decompositions .intersection (torch_disabled_decompositions )
191
+
192
+ if overlap :
193
+ raise AssertionError (
194
+ f"Detected { overlap } registered in both torch_enabled_decompositions "
195
+ "and torch_disabled_decompositions. Ensure all operator(s) are in "
196
+ "at most one of the two sets."
197
+ )
198
+
199
+
200
+ check_decomp_set_invariants ()
201
+
202
+
203
+ def register_torch_trt_decomposition (aten_op , registry = None ):
204
+ """Checks if the decomposition already exists in one of the sets
205
+ Registers the decomposition via the Torch utility
206
+
207
+ Alerts the user if the decomposition already exists, before registering
208
+ Throws an AssertionError if the user attempts to register a decomposition
209
+ which is present in the set of explicitly disabled decompositions
210
+ """
211
+ if aten_op in torch_enabled_decompositions :
212
+ logger .warning (
213
+ f"Detected custom decomposition for { aten_op } , which conflicts "
214
+ "with an existing Torch decomposition in torch_enabled_decompositions. "
215
+ "The custom implementation will take precedence."
216
+ )
217
+ elif aten_op in torch_disabled_decompositions :
218
+ logger .info (
219
+ f"Detected custom decomposition for { aten_op } , which is present "
220
+ "in torch_disabled_decompositions."
221
+ )
222
+
223
+ # Conflicts with _core_aten_decompositions will only occur if
224
+ # enable_experimental_decompositions is True in get_decompositions
225
+ if aten_op in _core_aten_decompositions :
226
+ logger .debug (
227
+ f"Detected custom decomposition for { aten_op } , which conflicts "
228
+ "with an existing Torch decomposition in core_aten_decompositions. "
229
+ "The custom implementation will take precedence."
230
+ )
231
+
232
+ def register (fn : Callable ) -> Callable :
233
+ return register_decomposition (aten_op = aten_op , registry = registry )(fn )
234
+
235
+ return register
236
+
237
+
182
238
def replace_inplace_op (aten_op , outplace_op ):
183
239
"""Replace inplace operation with functional equivalent
184
240
Adapted from:
185
241
https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
186
242
"""
187
243
188
- @register_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
244
+ @register_torch_trt_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
189
245
def inplace_op (* args , ** kwargs ):
190
246
out = outplace_op (* args , ** kwargs )
191
247
return args [0 ].copy_ (out )
@@ -208,34 +264,36 @@ def inplace_op(*args, **kwargs):
208
264
replace_inplace_op (aten .scatter_reduce_ , aten .scatter_reduce )
209
265
210
266
211
- @register_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
267
+ @register_torch_trt_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
212
268
def std_replacement (* args , ** kwargs ) -> torch .Tensor :
213
269
return torch .sqrt (torch .var (* args , ** kwargs ))
214
270
215
271
216
- @register_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
272
+ @register_torch_trt_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
217
273
def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor :
218
274
return torch .reciprocal (torch .sqrt (* args , ** kwargs ))
219
275
220
276
221
- @register_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
277
+ @register_torch_trt_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
222
278
def unsafe_view_replacement (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor :
223
279
return torch .reshape (x , * args , ** kwargs )
224
280
225
281
226
- @register_decomposition (
282
+ @register_torch_trt_decomposition (
227
283
torch .ops .aten .lift_fresh_copy , registry = TORCH_TRT_DECOMPOSITIONS
228
284
)
229
285
def lift_fresh_copy_replacement (x : torch .Tensor ) -> torch .Tensor :
230
286
return x
231
287
232
288
233
- @register_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
289
+ @register_torch_trt_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
234
290
def alias_replacement (x : torch .Tensor ) -> torch .Tensor :
235
291
return x
236
292
237
293
238
- @register_decomposition (torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS )
294
+ @register_torch_trt_decomposition (
295
+ torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS
296
+ )
239
297
def addmm_replacement (
240
298
input_ : torch .Tensor , mat1 : torch .Tensor , mat2 : torch .Tensor , * , beta = 1 , alpha = 1
241
299
) -> torch .Tensor :
@@ -244,7 +302,7 @@ def addmm_replacement(
244
302
)
245
303
246
304
247
- @register_decomposition (
305
+ @register_torch_trt_decomposition (
248
306
torch .ops .aten .reciprocal .default , registry = TORCH_TRT_DECOMPOSITIONS
249
307
)
250
308
def reciprocal_replacement (
@@ -260,17 +318,8 @@ def get_decompositions(
260
318
CORE_ATEN_DECOMPOSITIONS_FILTERED : Dict [torch ._ops .OpOverload , Callable ] = {
261
319
decomp : _core_aten_decompositions [decomp ]
262
320
for decomp in _core_aten_decompositions
263
- if (
264
- decomp not in TORCH_TRT_DECOMPOSITIONS
265
- and decomp not in disabled_decompositions
266
- )
321
+ if decomp not in torch_disabled_decompositions
267
322
}
268
323
return {** CORE_ATEN_DECOMPOSITIONS_FILTERED , ** TORCH_TRT_DECOMPOSITIONS }
269
324
else :
270
- duplicate_registrations = set (ENABLED_TORCH_DECOMPOSITIONS .keys ()).intersection (
271
- set (TORCH_TRT_DECOMPOSITIONS .keys ())
272
- )
273
- assert (
274
- not duplicate_registrations
275
- ), f"Detected duplicate decompositions on: { duplicate_registrations } "
276
325
return {** ENABLED_TORCH_DECOMPOSITIONS , ** TORCH_TRT_DECOMPOSITIONS }
0 commit comments