11from typing import Callable , Dict , Set
22import torch
3+ import logging
34from torch ._decomp import (
45 register_decomposition ,
56 core_aten_decompositions ,
1112_core_aten_decompositions : Dict [
1213 torch ._ops .OpOverload , Callable
1314] = core_aten_decompositions ()
14- enabled_decompositions : Set [torch ._ops .OpOverload ] = {
15+ torch_enabled_decompositions : Set [torch ._ops .OpOverload ] = {
1516 aten ._adaptive_avg_pool2d_backward ,
1617 aten .addcdiv ,
1718 aten .addcdiv_ ,
171172 aten .zeros ,
172173 aten .zeros_like ,
173174}
174- disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
175+ torch_disabled_decompositions : Set [torch ._ops .OpOverload ] = {}
176+
175177
176178ENABLED_TORCH_DECOMPOSITIONS : Dict [
177179 torch ._ops .OpOverload , Callable
178- ] = get_torch_decompositions (enabled_decompositions )
180+ ] = get_torch_decompositions (torch_enabled_decompositions )
179181TORCH_TRT_DECOMPOSITIONS : Dict [torch ._ops .OpOverload , Callable ] = {}
180182
181183
184+ logger = logging .getLogger (__name__ )
185+
186+
187+ def check_decomp_set_invariants ():
188+ """Validates no overlap between enabled and disabled decomposition sets"""
189+ overlap = torch_enabled_decompositions .intersection (torch_disabled_decompositions )
190+
191+ if overlap :
192+ raise AssertionError (
193+ f"Detected { overlap } registered in both torch_enabled_decompositions "
194+ "and torch_disabled_decompositions. Ensure all operator(s) are in "
195+ "at most one of the two sets."
196+ )
197+
198+
199+ check_decomp_set_invariants ()
200+
201+
202+ def register_torch_trt_decomposition (aten_op , registry = None ):
203+ """Checks if the decomposition already exists in one of the sets
204+ Registers the decomposition via the Torch utility
205+
206+ Alerts the user if the decomposition already exists, before registering
207+ Throws an AssertionError if the user attempts to register a decomposition
208+ which is present in the set of explicitly disabled decompositions
209+ """
210+ if aten_op in torch_enabled_decompositions :
211+ logger .warning (
212+ f"Detected custom decomposition for { aten_op } , which conflicts "
213+ "with an existing Torch decomposition in torch_enabled_decompositions. "
214+ "The custom implementation will take precedence."
215+ )
216+ elif aten_op in torch_disabled_decompositions :
217+ logger .info (
218+ f"Detected custom decomposition for { aten_op } , which is present "
219+ "in torch_disabled_decompositions."
220+ )
221+
222+ # Conflicts with _core_aten_decompositions will only occur if
223+ # enable_experimental_decompositions is True in get_decompositions
224+ if aten_op in _core_aten_decompositions :
225+ logger .debug (
226+ f"Detected custom decomposition for { aten_op } , which conflicts "
227+ "with an existing Torch decomposition in core_aten_decompositions. "
228+ "The custom implementation will take precedence."
229+ )
230+
231+ def register (fn : Callable ) -> Callable :
232+ return register_decomposition (aten_op = aten_op , registry = registry )(fn )
233+
234+ return register
235+
236+
182237def replace_inplace_op (aten_op , outplace_op ):
183238 """Replace inplace operation with functional equivalent
184239 Adapted from:
185240 https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
186241 """
187242
188- @register_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
243+ @register_torch_trt_decomposition (aten_op , registry = TORCH_TRT_DECOMPOSITIONS )
189244 def inplace_op (* args , ** kwargs ):
190245 out = outplace_op (* args , ** kwargs )
191246 return args [0 ].copy_ (out )
@@ -208,34 +263,36 @@ def inplace_op(*args, **kwargs):
208263replace_inplace_op (aten .scatter_reduce_ , aten .scatter_reduce )
209264
210265
211- @register_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
266+ @register_torch_trt_decomposition (aten .std , registry = TORCH_TRT_DECOMPOSITIONS )
212267def std_replacement (* args , ** kwargs ) -> torch .Tensor :
213268 return torch .sqrt (torch .var (* args , ** kwargs ))
214269
215270
216- @register_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
271+ @register_torch_trt_decomposition (aten .rsqrt , registry = TORCH_TRT_DECOMPOSITIONS )
217272def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor :
218273 return torch .reciprocal (torch .sqrt (* args , ** kwargs ))
219274
220275
221- @register_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
276+ @register_torch_trt_decomposition (aten ._unsafe_view , registry = TORCH_TRT_DECOMPOSITIONS )
222277def unsafe_view_replacement (x : torch .Tensor , * args , ** kwargs ) -> torch .Tensor :
223278 return torch .reshape (x , * args , ** kwargs )
224279
225280
226- @register_decomposition (
281+ @register_torch_trt_decomposition (
227282 torch .ops .aten .lift_fresh_copy , registry = TORCH_TRT_DECOMPOSITIONS
228283)
229284def lift_fresh_copy_replacement (x : torch .Tensor ) -> torch .Tensor :
230285 return x
231286
232287
233- @register_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
288+ @register_torch_trt_decomposition (aten .alias , registry = TORCH_TRT_DECOMPOSITIONS )
234289def alias_replacement (x : torch .Tensor ) -> torch .Tensor :
235290 return x
236291
237292
238- @register_decomposition (torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS )
293+ @register_torch_trt_decomposition (
294+ torch .ops .aten .addmm , registry = TORCH_TRT_DECOMPOSITIONS
295+ )
239296def addmm_replacement (
240297 input_ : torch .Tensor , mat1 : torch .Tensor , mat2 : torch .Tensor , * , beta = 1 , alpha = 1
241298) -> torch .Tensor :
@@ -244,7 +301,7 @@ def addmm_replacement(
244301 )
245302
246303
247- @register_decomposition (
304+ @register_torch_trt_decomposition (
248305 torch .ops .aten .reciprocal .default , registry = TORCH_TRT_DECOMPOSITIONS
249306)
250307def reciprocal_replacement (
@@ -260,17 +317,8 @@ def get_decompositions(
260317 CORE_ATEN_DECOMPOSITIONS_FILTERED : Dict [torch ._ops .OpOverload , Callable ] = {
261318 decomp : _core_aten_decompositions [decomp ]
262319 for decomp in _core_aten_decompositions
263- if (
264- decomp not in TORCH_TRT_DECOMPOSITIONS
265- and decomp not in disabled_decompositions
266- )
320+ if decomp not in torch_disabled_decompositions
267321 }
268322 return {** CORE_ATEN_DECOMPOSITIONS_FILTERED , ** TORCH_TRT_DECOMPOSITIONS }
269323 else :
270- duplicate_registrations = set (ENABLED_TORCH_DECOMPOSITIONS .keys ()).intersection (
271- set (TORCH_TRT_DECOMPOSITIONS .keys ())
272- )
273- assert (
274- not duplicate_registrations
275- ), f"Detected duplicate decompositions on: { duplicate_registrations } "
276324 return {** ENABLED_TORCH_DECOMPOSITIONS , ** TORCH_TRT_DECOMPOSITIONS }
0 commit comments