11from typing import Callable , Dict , Set
22import torch
3+ import logging
34from torch ._decomp import (
45 register_decomposition ,
56 core_aten_decompositions ,
1112_core_aten_decompositions : Dict [
1213 torch ._ops .OpOverload , Callable
1314] = core_aten_decompositions ()
14- enabled_decompositions : Set [torch ._ops .OpOverload ] = {
15+ torch_enabled_decompositions : Set [torch ._ops .OpOverload ] = {
1516 aten ._adaptive_avg_pool2d_backward ,
1617 aten .addcdiv ,
1718 aten .addcdiv_ ,
171172 aten .zeros ,
172173 aten .zeros_like ,
173174}
174- disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
175+ torch_disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
176+
177+ assert torch_enabled_decompositions .intersection (torch_disabled_decompositions ) == 0
175178
176179ENABLED_TORCH_DECOMPOSITIONS : Dict [
177180 torch ._ops .OpOverload , Callable
178- ] = get_torch_decompositions (enabled_decompositions )
181+ ] = get_torch_decompositions (torch_enabled_decompositions )
179182TORCH_TRT_DECOMPOSITIONS : Dict [torch ._ops .OpOverload , Callable ] = {}
180183
181184
185+ logger = logging .getLogger (__name__ )
186+
187+
188+ def check_decomp_set_invariants ():
189+ """Validates no overlap between enabled and disabled decomposition sets"""
190+ overlap = torch_enabled_decompositions .intersection (torch_disabled_decompositions )
191+
192+ if overlap :
193+ raise AssertionError (
194+ f"Detected { overlap } registered in both torch_enabled_decompositions "
195+ "and torch_disabled_decompositions. Ensure all operator(s) are in "
196+ "at most one of the two sets."
197+ )
198+
199+
200+ check_decomp_set_invariants ()
201+
202+
203+ def register_torch_trt_decomposition (aten_op , registry = None ):
204+ """Checks if the decomposition already exists in one of the sets
205+ Registers the decomposition via the Torch utility
206+
207+ Alerts the user if the decomposition already exists, before registering
208+ Throws an AssertionError if the user attempts to register a decomposition
209+ which is present in the set of explicitly disabled decompositions
210+ """
211+ if aten_op in torch_enabled_decompositions :
212+ logger .warning (
213+ f"Detected custom decomposition for { aten_op } , which conflicts "
214+ "with an existing Torch decomposition in torch_enabled_decompositions. "
215+ "The custom implementation will take precedence."
216+ )
217+ elif aten_op in torch_disabled_decompositions :
218+ logger .info (
219+ f"Detected custom decomposition for { aten_op } , which is present "
220+ "in torch_disabled_decompositions."
221+ )
222+
223+ # Conflicts with _core_aten_decompositions will only occur if
224+ # enable_experimental_decompositions is True in get_decompositions
225+ if aten_op in _core_aten_decompositions :
226+ logger .debug (
227+ f"Detected custom decomposition for { aten_op } , which conflicts "
228+ "with an existing Torch decomposition in core_aten_decompositions. "
229+ "The custom implementation will take precedence."
230+ )
231+
232+ def register (fn : Callable ) -> Callable :
233+ return register_decomposition (aten_op = aten_op , registry = registry )(fn )
234+
235+ return register
236+
237+
182238def replace_inplace_op (aten_op , outplace_op ):
183239 """Replace inplace operation with functional equivalent
184240 Adapted from:
185241 https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
186242 """
187243
188- @register_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
244+ @register_torch_trt_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
189245 def inplace_op (* args , ** kwargs ):
190246 out = outplace_op (* args , ** kwargs )
191247 return args [0 ].copy_ (out )
@@ -208,34 +264,36 @@ def inplace_op(*args, **kwargs):
208264replace_inplace_op (aten .scatter_reduce_ , aten .scatter_reduce )
209265
210266
211- @register_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
267+ @register_torch_trt_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
212268def std_replacement (* args , ** kwargs ) -> torch .Tensor :
213269 return torch .sqrt (torch .var (* args , ** kwargs ))
214270
215271
216- @register_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
272+ @register_torch_trt_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
217273def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor :
218274 return torch .reciprocal (torch .sqrt (* args , ** kwargs ))
219275
220276
221- @register_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
277+ @register_torch_trt_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
222278def unsafe_view_replacement (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor :
223279 return torch .reshape (x , * args , ** kwargs )
224280
225281
226- @register_decomposition (
282+ @register_torch_trt_decomposition (
227283 torch .ops .aten .lift_fresh_copy , registry = TORCH_TRT_DECOMPOSITIONS
228284)
229285def lift_fresh_copy_replacement (x : torch .Tensor ) -> torch .Tensor :
230286 return x
231287
232288
233- @register_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
289+ @register_torch_trt_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
234290def alias_replacement (x : torch .Tensor ) -> torch .Tensor :
235291 return x
236292
237293
238- @register_decomposition (torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS )
294+ @register_torch_trt_decomposition (
295+ torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS
296+ )
239297def addmm_replacement (
240298 input_ : torch .Tensor , mat1 : torch .Tensor , mat2 : torch .Tensor , * , beta = 1 , alpha = 1
241299) -> torch .Tensor :
@@ -244,7 +302,7 @@ def addmm_replacement(
244302 )
245303
246304
247- @register_decomposition (
305+ @register_torch_trt_decomposition (
248306 torch .ops .aten .reciprocal .default , registry = TORCH_TRT_DECOMPOSITIONS
249307)
250308def reciprocal_replacement (
@@ -260,17 +318,8 @@ def get_decompositions(
260318 CORE_ATEN_DECOMPOSITIONS_FILTERED : Dict [torch ._ops .OpOverload , Callable ] = {
261319 decomp : _core_aten_decompositions [decomp ]
262320 for decomp in _core_aten_decompositions
263- if (
264- decomp not in TORCH_TRT_DECOMPOSITIONS
265- and decomp not in disabled_decompositions
266- )
321+ if decomp not in torch_disabled_decompositions
267322 }
268323 return {** CORE_ATEN_DECOMPOSITIONS_FILTERED , ** TORCH_TRT_DECOMPOSITIONS }
269324 else :
270- duplicate_registrations = set (ENABLED_TORCH_DECOMPOSITIONS .keys ()).intersection (
271- set (TORCH_TRT_DECOMPOSITIONS .keys ())
272- )
273- assert (
274- not duplicate_registrations
275- ), f"Detected duplicate decompositions on: { duplicate_registrations } "
276325 return {** ENABLED_TORCH_DECOMPOSITIONS , ** TORCH_TRT_DECOMPOSITIONS }
0 commit comments