Commit 8b12de5

feat: Add registration-time warnings for decomp.
- Add a decorator wrapper to perform import-time checks on decompositions and alert the user if any custom decompositions conflict with existing registered or specified operators
- Simplify the dictionary-merging logic in the `get_decompositions` function
- Add safety logic to ensure invariants about the decompositions are not violated
1 parent 368a20e commit 8b12de5
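
For context, a minimal usage sketch (not part of this commit) of the decorator described above. It mirrors the aten.std registration in the diff below and assumes the names are importable from torch_tensorrt.dynamo.lowering._decompositions as laid out in this change:

import torch
from torch_tensorrt.dynamo.lowering._decompositions import (
    TORCH_TRT_DECOMPOSITIONS,
    register_torch_trt_decomposition,
)

aten = torch.ops.aten


# Registration succeeds even if aten.std is already listed in
# torch_enabled_decompositions; the decorator logs a warning noting that the
# custom implementation takes precedence, then defers to
# torch._decomp.register_decomposition for the actual registration.
@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
def std_replacement(*args, **kwargs) -> torch.Tensor:
    return torch.sqrt(torch.var(*args, **kwargs))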

1 file changed (+70 -21)

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 70 additions & 21 deletions
@@ -1,5 +1,6 @@
 from typing import Callable, Dict, Set
 import torch
+import logging
 from torch._decomp import (
     register_decomposition,
     core_aten_decompositions,
@@ -11,7 +12,7 @@
 _core_aten_decompositions: Dict[
     torch._ops.OpOverload, Callable
 ] = core_aten_decompositions()
-enabled_decompositions: Set[torch._ops.OpOverload] = {
+torch_enabled_decompositions: Set[torch._ops.OpOverload] = {
     aten._adaptive_avg_pool2d_backward,
     aten.addcdiv,
     aten.addcdiv_,
@@ -171,21 +172,76 @@
     aten.zeros,
     aten.zeros_like,
 }
-disabled_decompositions: Set[torch._ops.OpOverload] = {}
+torch_disabled_decompositions: Set[torch._ops.OpOverload] = {}
+
+assert len(torch_enabled_decompositions.intersection(torch_disabled_decompositions)) == 0

 ENABLED_TORCH_DECOMPOSITIONS: Dict[
     torch._ops.OpOverload, Callable
-] = get_torch_decompositions(enabled_decompositions)
+] = get_torch_decompositions(torch_enabled_decompositions)
 TORCH_TRT_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}


+logger = logging.getLogger(__name__)
+
+
+def check_decomp_set_invariants():
+    """Validates no overlap between enabled and disabled decomposition sets"""
+    overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
+
+    if overlap:
+        raise AssertionError(
+            f"Detected {overlap} registered in both torch_enabled_decompositions "
+            "and torch_disabled_decompositions. Ensure all operator(s) are in "
+            "at most one of the two sets."
+        )
+
+
+check_decomp_set_invariants()
+
+
+def register_torch_trt_decomposition(aten_op, registry=None):
+    """Checks if the decomposition already exists in one of the sets
+    Registers the decomposition via the Torch utility
+
+    Alerts the user if the decomposition already exists, before registering
+    Logs a message if the user attempts to register a decomposition
+    which is present in the set of explicitly disabled decompositions
+    """
+    if aten_op in torch_enabled_decompositions:
+        logger.warning(
+            f"Detected custom decomposition for {aten_op}, which conflicts "
+            "with an existing Torch decomposition in torch_enabled_decompositions. "
+            "The custom implementation will take precedence."
+        )
+    elif aten_op in torch_disabled_decompositions:
+        logger.info(
+            f"Detected custom decomposition for {aten_op}, which is present "
+            "in torch_disabled_decompositions."
+        )
+
+    # Conflicts with _core_aten_decompositions will only occur if
+    # enable_experimental_decompositions is True in get_decompositions
+    if aten_op in _core_aten_decompositions:
+        logger.debug(
+            f"Detected custom decomposition for {aten_op}, which conflicts "
+            "with an existing Torch decomposition in core_aten_decompositions. "
+            "The custom implementation will take precedence."
+        )
+
+    def register(fn: Callable) -> Callable:
+        return register_decomposition(aten_op=aten_op, registry=registry)(fn)
+
+    return register
+
+
 def replace_inplace_op(aten_op, outplace_op):
     """Replace inplace operation with functional equivalent
     Adapted from:
     https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
     """

-    @register_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
+    @register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
     def inplace_op(*args, **kwargs):
         out = outplace_op(*args, **kwargs)
         return args[0].copy_(out)
@@ -208,34 +264,36 @@ def inplace_op(*args, **kwargs):
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)


-@register_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
 def std_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.sqrt(torch.var(*args, **kwargs))


-@register_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
 def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.reciprocal(torch.sqrt(*args, **kwargs))


-@register_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
 def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     return torch.reshape(x, *args, **kwargs)


-@register_decomposition(
+@register_torch_trt_decomposition(
     torch.ops.aten.lift_fresh_copy, registry=TORCH_TRT_DECOMPOSITIONS
 )
 def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
     return x


-@register_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
 def alias_replacement(x: torch.Tensor) -> torch.Tensor:
     return x


-@register_decomposition(torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS)
+@register_torch_trt_decomposition(
+    torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def addmm_replacement(
     input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
 ) -> torch.Tensor:
@@ -244,7 +302,7 @@ def addmm_replacement(
     )


-@register_decomposition(
+@register_torch_trt_decomposition(
     torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
 def reciprocal_replacement(
@@ -260,17 +318,8 @@ def get_decompositions(
         CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[torch._ops.OpOverload, Callable] = {
             decomp: _core_aten_decompositions[decomp]
             for decomp in _core_aten_decompositions
-            if (
-                decomp not in TORCH_TRT_DECOMPOSITIONS
-                and decomp not in disabled_decompositions
-            )
+            if decomp not in torch_disabled_decompositions
         }
         return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
     else:
-        duplicate_registrations = set(ENABLED_TORCH_DECOMPOSITIONS.keys()).intersection(
-            set(TORCH_TRT_DECOMPOSITIONS.keys())
-        )
-        assert (
-            not duplicate_registrations
-        ), f"Detected duplicate decompositions on: {duplicate_registrations}"
         return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}
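
A short usage sketch (not part of this commit) of the simplified merge above. It assumes get_decompositions accepts the enable_experimental_decompositions flag referenced in the decorator's comments, and that the names below are importable from this module:

from torch_tensorrt.dynamo.lowering._decompositions import (
    TORCH_TRT_DECOMPOSITIONS,
    get_decompositions,
)

# Default path: the curated Torch decompositions merged with the Torch-TRT
# registry; unpacking TORCH_TRT_DECOMPOSITIONS last means custom
# implementations override overlapping Torch entries.
decomps = get_decompositions()
for op, fn in TORCH_TRT_DECOMPOSITIONS.items():
    assert decomps[op] is fn

# Experimental path: the full core ATen table filtered only by
# torch_disabled_decompositions, again merged with the custom registry last.
experimental_decomps = get_decompositions(enable_experimental_decompositions=True)
for op, fn in TORCH_TRT_DECOMPOSITIONS.items():
    assert experimental_decomps[op] is fn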
