Skip to content

Commit 806b348

Browse files
committed
feat: Add registration-time warnings for decompositions.
- Add a decorator wrapper to perform import-time checks on decompositions and alert the user if any custom decompositions conflict with existing registered or specified operators
- Simplify code logic for dictionary merging in the `get_decompositions` function
- Add safety logic to ensure invariants about the decompositions are not violated
1 parent 368a20e commit 806b348

File tree

1 file changed

+69
-21
lines changed

1 file changed

+69
-21
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Callable, Dict, Optional, Set
22
import torch
3+
import logging
34
from torch._decomp import (
45
register_decomposition,
56
core_aten_decompositions,
@@ -11,7 +12,7 @@
1112
_core_aten_decompositions: Dict[
1213
torch._ops.OpOverload, Callable
1314
] = core_aten_decompositions()
14-
enabled_decompositions: Set[torch._ops.OpOverload] = {
15+
torch_enabled_decompositions: Set[torch._ops.OpOverload] = {
1516
aten._adaptive_avg_pool2d_backward,
1617
aten.addcdiv,
1718
aten.addcdiv_,
@@ -171,21 +172,75 @@
171172
aten.zeros,
172173
aten.zeros_like,
173174
}
174-
disabled_decompositions: Set[torch._ops.OpOverload] = {}
175+
torch_disabled_decompositions: Set[torch._ops.OpOverload] = {}
176+
175177

176178
ENABLED_TORCH_DECOMPOSITIONS: Dict[
177179
torch._ops.OpOverload, Callable
178-
] = get_torch_decompositions(enabled_decompositions)
180+
] = get_torch_decompositions(torch_enabled_decompositions)
179181
TORCH_TRT_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}
180182

181183

184+
logger = logging.getLogger(__name__)
185+
186+
187+
def check_decomp_set_invariants(
    enabled: Optional[Set[torch._ops.OpOverload]] = None,
    disabled: Optional[Set[torch._ops.OpOverload]] = None,
) -> None:
    """Validate that no operator is both enabled and disabled.

    Args:
        enabled: Set of enabled decompositions. Defaults to the module-level
            ``torch_enabled_decompositions``.
        disabled: Set of disabled decompositions. Defaults to the module-level
            ``torch_disabled_decompositions``.

    Raises:
        AssertionError: If any operator appears in both sets.
    """
    if enabled is None:
        enabled = torch_enabled_decompositions
    if disabled is None:
        disabled = torch_disabled_decompositions

    overlap = enabled & disabled

    if overlap:
        raise AssertionError(
            f"Detected {overlap} registered in both torch_enabled_decompositions "
            "and torch_disabled_decompositions. Ensure all operator(s) are in "
            "at most one of the two sets."
        )
197+
198+
199+
check_decomp_set_invariants()
200+
201+
202+
def register_torch_trt_decomposition(aten_op, registry=None):
    """Register a custom Torch-TensorRT decomposition for ``aten_op``.

    Checks whether the decomposition already exists in one of the
    decomposition sets, then registers it via the Torch utility
    ``register_decomposition``.

    Alerts the user (via logging) if the decomposition already exists in
    ``torch_enabled_decompositions``, ``torch_disabled_decompositions``, or
    ``_core_aten_decompositions``. In each case the custom implementation
    takes precedence; nothing is raised here.
    (NOTE(review): an earlier docstring claimed an AssertionError was thrown
    for explicitly disabled decompositions — the code only logs.)

    Args:
        aten_op: Operator overload (or overload packet) to decompose.
        registry: Optional registry dict forwarded to ``register_decomposition``.

    Returns:
        A decorator that registers the wrapped function as the decomposition.
    """
    # Lazy %-style logging args avoid building messages unless emitted.
    if aten_op in torch_enabled_decompositions:
        logger.warning(
            "Detected custom decomposition for %s, which conflicts "
            "with an existing Torch decomposition in torch_enabled_decompositions. "
            "The custom implementation will take precedence.",
            aten_op,
        )
    elif aten_op in torch_disabled_decompositions:
        logger.info(
            "Detected custom decomposition for %s, which is present "
            "in torch_disabled_decompositions.",
            aten_op,
        )

    # Conflicts with _core_aten_decompositions will only occur if
    # enable_experimental_decompositions is True in get_decompositions
    if aten_op in _core_aten_decompositions:
        logger.debug(
            "Detected custom decomposition for %s, which conflicts "
            "with an existing Torch decomposition in core_aten_decompositions. "
            "The custom implementation will take precedence.",
            aten_op,
        )

    def register(fn: Callable) -> Callable:
        return register_decomposition(aten_op=aten_op, registry=registry)(fn)

    return register
235+
236+
182237
def replace_inplace_op(aten_op, outplace_op):
    """Register a decomposition mapping an in-place op onto its functional form.

    Adapted from:
    https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
    """

    @register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
    def inplace_op(*args, **kwargs):
        # Compute the functional result, then write it back into the first argument.
        return args[0].copy_(outplace_op(*args, **kwargs))
@@ -208,34 +263,36 @@ def inplace_op(*args, **kwargs):
208263
replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
209264

210265

211-
@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
def std_replacement(*args, **kwargs) -> torch.Tensor:
    """Lower aten.std as the square root of the variance."""
    variance = torch.var(*args, **kwargs)
    return torch.sqrt(variance)
214269

215270

216-
@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
    """Lower aten.rsqrt as the reciprocal of the square root."""
    root = torch.sqrt(*args, **kwargs)
    return torch.reciprocal(root)
219274

220275

221-
@register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Lower aten._unsafe_view to a plain reshape of the input tensor."""
    reshaped = torch.reshape(x, *args, **kwargs)
    return reshaped
224279

225280

226-
@register_torch_trt_decomposition(
    torch.ops.aten.lift_fresh_copy, registry=TORCH_TRT_DECOMPOSITIONS
)
def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
    """Treat aten.lift_fresh_copy as an identity: return the input tensor unchanged."""
    return x
231286

232287

233-
@register_torch_trt_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
def alias_replacement(x: torch.Tensor) -> torch.Tensor:
    """Treat aten.alias as an identity: return the input tensor unchanged."""
    return x
236291

237292

238-
@register_decomposition(torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS)
293+
@register_torch_trt_decomposition(
294+
torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
295+
)
239296
def addmm_replacement(
240297
input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
241298
) -> torch.Tensor:
@@ -244,7 +301,7 @@ def addmm_replacement(
244301
)
245302

246303

247-
@register_decomposition(
304+
@register_torch_trt_decomposition(
248305
torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
249306
)
250307
def reciprocal_replacement(
@@ -260,17 +317,8 @@ def get_decompositions(
260317
CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[torch._ops.OpOverload, Callable] = {
261318
decomp: _core_aten_decompositions[decomp]
262319
for decomp in _core_aten_decompositions
263-
if (
264-
decomp not in TORCH_TRT_DECOMPOSITIONS
265-
and decomp not in disabled_decompositions
266-
)
320+
if decomp not in torch_disabled_decompositions
267321
}
268322
return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
269323
else:
270-
duplicate_registrations = set(ENABLED_TORCH_DECOMPOSITIONS.keys()).intersection(
271-
set(TORCH_TRT_DECOMPOSITIONS.keys())
272-
)
273-
assert (
274-
not duplicate_registrations
275-
), f"Detected duplicate decompositions on: {duplicate_registrations}"
276324
return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}

0 commit comments

Comments
 (0)