
Commit bdb06d8

fix: Add enabled/disabled sets for decompositions

- Add sets to selectively enable or disable decompositions in Torch
- Add new runtime argument `enable_experimental_decompositions` to enable all core ATen decompositions, or a pre-selected subset thereof
- Improve documentation of compilation settings overall

1 parent 0527edd · commit bdb06d8

File tree

5 files changed: +238 additions, −15 deletions
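Before the per-file diffs, a minimal usage sketch of the new argument. This is illustrative only: it assumes the `torch_tensorrt.dynamo.compile` entry point modified in this commit, and the model and input shapes are placeholders.

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder module
inputs = [torch.randn(1, 3, 224, 224).cuda()]  # placeholder input

# Default: only the curated subset of Torch decompositions is applied
trt_module = torch_tensorrt.dynamo.compile(model, inputs)

# Opt in to all core ATen decompositions (minus any explicitly disabled ones)
trt_module = torch_tensorrt.dynamo.compile(
    model, inputs, enable_experimental_decompositions=True
)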

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions

@@ -10,3 +10,4 @@
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
 TRUNCATE_LONG_AND_DOUBLE = False
+ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 23 additions & 0 deletions

@@ -12,11 +12,33 @@
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
     TRUNCATE_LONG_AND_DOUBLE,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
 )
 
 
 @dataclass
 class CompilationSettings:
+    """Compilation settings for Torch-TensorRT Dynamo paths
+
+    Args:
+        precision (torch.dtype): Model layer precision
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT engine block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization level 0-5; higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use the Python runtime or the C++ runtime. To
+            auto-select a runtime based on C++ dependency presence (preferring the C++ runtime if available),
+            leave the argument as None
+        truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
+        enable_experimental_decompositions (bool): Whether to enable all core ATen decompositions
+            or only a selected subset of them
+    """
+
     precision: torch.dtype = PRECISION
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
@@ -28,3 +50,4 @@ class CompilationSettings:
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
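Since `CompilationSettings` is a dataclass whose fields all have module-level defaults, a settings object can be constructed with only the overrides of interest. A sketch, assuming the import path matches the file above:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Unspecified fields fall back to the defaults in _defaults.py
settings = CompilationSettings(
    debug=True,
    min_block_size=5,
    enable_experimental_decompositions=True,
)
print(settings.enable_experimental_decompositions)  # True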

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def aot_torch_tensorrt_aten_backend(
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
+        decompositions=get_decompositions(settings.enable_experimental_decompositions),
     )
 
 
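This one-line change is the switch between the two decomposition tables. A sketch of the effect, assuming `get_decompositions` is importable as the backend file implies (the private module path is an assumption):

from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions

# Curated subset of Torch decompositions plus Torch-TensorRT's custom replacements
default_table = get_decompositions(False)

# All core ATen decompositions (minus disabled ones) plus the custom replacements
experimental_table = get_decompositions(True)

# Both are Dict[torch._ops.OpOverload, Callable] mappings handed to AOTAutograd
print(len(default_table), len(experimental_table))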

py/torch_tensorrt/dynamo/compile.py

Lines changed: 4 additions & 1 deletion

@@ -31,6 +31,7 @@
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
     TRUNCATE_LONG_AND_DOUBLE,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
 )
 
 
@@ -64,6 +65,7 @@ def compile(
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
     use_python_runtime=USE_PYTHON_RUNTIME,
+    enable_experimental_decompositions=ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs,
 ):
     if debug:
@@ -73,7 +75,7 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "torch_executed_ops, pass_through_build_failures, enable_experimental_decompositions}"
     )
 
     if not isinstance(inputs, collections.abc.Sequence):
@@ -111,6 +113,7 @@ def compile(
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
     }
 
     settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 209 additions & 13 deletions

@@ -1,10 +1,187 @@
+from typing import Callable, Dict, Set
 import torch
-from torch._decomp import register_decomposition, core_aten_decompositions
+from torch._decomp import (
+    register_decomposition,
+    core_aten_decompositions,
+    get_decompositions as get_torch_decompositions,
+)
 
+aten = torch.ops.aten
 
-DECOMPOSITIONS = {**core_aten_decompositions()}
+_core_aten_decompositions: Dict[
+    torch._ops.OpOverload, Callable
+] = core_aten_decompositions()
+enabled_decompositions: Set[torch._ops.OpOverload] = {
+    aten._adaptive_avg_pool2d_backward,
+    aten.addcdiv,
+    aten.addcdiv_,
+    aten.addcmul,
+    aten.addcmul_,
+    aten.addr,
+    aten.aminmax,
+    aten.arange.default,
+    aten.arange.start,
+    aten.avg_pool2d_backward,
+    aten.binary_cross_entropy,
+    aten.binary_cross_entropy_backward,
+    aten.binary_cross_entropy_with_logits,
+    aten.celu,
+    aten.col2im,
+    aten.count_nonzero,
+    aten.cudnn_batch_norm,
+    aten.cudnn_batch_norm_backward,
+    aten.deg2rad,
+    aten.detach,
+    aten.diag_embed,
+    aten.diagonal_backward,
+    aten.dot,
+    aten.elu,
+    aten.elu_backward,
+    aten._embedding_bag,
+    aten.embedding_dense_backward,
+    aten._euclidean_dist.default,
+    aten.expand_as,
+    aten.eye,
+    aten.fill,
+    aten.frac,
+    aten._fused_moving_avg_obs_fq_helper,
+    aten.gelu,
+    aten.gelu_backward,
+    aten.glu_backward,
+    aten.grid_sampler_2d,
+    aten.hardshrink,
+    aten.hardshrink_backward,
+    aten.hardsigmoid,
+    aten.hardsigmoid_backward,
+    aten.hardswish,
+    aten.hardswish_,
+    aten.hardswish_backward,
+    aten.hardtanh,
+    aten.hardtanh_,
+    aten.hardtanh_backward,
+    aten.heaviside,
+    aten.huber_loss,
+    aten.huber_loss_backward,
+    aten.im2col,
+    aten.index_add,
+    aten.index_add_,
+    aten.index_copy,
+    aten.index_copy_,
+    aten.index_fill,
+    aten.index_fill_,
+    aten.index_select,
+    aten.isneginf,
+    aten.isposinf,
+    aten.l1_loss,
+    aten.leaky_relu,
+    aten.leaky_relu_,
+    aten.leaky_relu_backward,
+    aten.lerp,
+    aten.linspace,
+    aten.logaddexp,
+    aten.logaddexp2,
+    aten.logit,
+    aten.logit_backward,
+    aten.log_sigmoid_backward,
+    aten.log_sigmoid_forward,
+    aten._log_softmax,
+    aten._log_softmax_backward_data,
+    aten.logspace,
+    aten.logsumexp.default,
+    aten.masked_fill,
+    aten.masked_fill_,
+    aten.max_pool2d_with_indices_backward,
+    aten.mish,
+    aten.mse_loss,
+    aten.mse_loss_backward,
+    aten.mv,
+    aten.mvlgamma,
+    aten.nansum,
+    aten.nan_to_num,
+    aten.narrow,
+    # TODO: Disable the below operators once freezing is done
+    aten.native_batch_norm,
+    aten.native_batch_norm_backward,
+    aten._native_batch_norm_legit,
+    aten._native_batch_norm_legit_functional,
+    aten._native_batch_norm_legit_no_training,
+    aten.native_dropout_backward,
+    aten.native_group_norm,
+    aten.native_group_norm_backward,
+    aten.native_layer_norm,
+    aten.native_layer_norm_backward,
+    aten.new_empty,
+    aten.new_full,
+    aten.new_ones,
+    aten.new_zeros,
+    aten.nll_loss_backward,
+    aten.nll_loss_forward,
+    aten.norm,
+    aten.ones,
+    aten.ones_like,
+    aten._prelu_kernel,
+    aten._prelu_kernel_backward,
+    aten._reshape_alias,
+    aten.rad2deg,
+    aten.renorm,
+    aten.renorm_,
+    aten.rot90,
+    aten.rsub.Scalar,
+    aten.rsub.Tensor,
+    aten.select_backward,
+    aten.select_scatter,
+    aten.sgn,
+    aten.sigmoid_backward,
+    aten.silu,
+    aten.silu_,
+    aten.silu_backward,
+    aten.sinc,
+    aten.slice_backward,
+    aten.smooth_l1_loss,
+    aten.smooth_l1_loss_backward,
+    aten.soft_margin_loss,
+    aten.soft_margin_loss_backward,
+    aten._softmax,
+    aten._softmax_backward_data,
+    aten.softplus,
+    aten.softplus_backward,
+    aten.softshrink,
+    aten.softshrink_backward,
+    aten.special_entr,
+    aten.special_log_ndtr,
+    aten.special_xlog1py,
+    aten.stack,
+    aten.t,
+    aten.tanh_backward,
+    aten.threshold,
+    aten.threshold_backward,
+    aten.trace,
+    aten.transpose.int,
+    aten.tril.default,
+    aten.triu.default,
+    aten.unfold,
+    aten.unfold_backward,
+    aten.unfold_copy,
+    aten.upsample_bilinear2d,
+    aten.upsample_bilinear2d.vec,
+    aten.upsample_nearest2d_backward,
+    aten.xlogy,
+    aten.zero,
+    aten.zero_,
+    aten.zeros,
+    aten.zeros_like,
+}
+disabled_decompositions: Set[torch._ops.OpOverload] = {}
 
-aten = torch.ops.aten
+TORCH_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = get_torch_decompositions(
+    enabled_decompositions
+)
+TORCH_EXPERIMENTAL_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {
+    decomp: _core_aten_decompositions[decomp]
+    for decomp in _core_aten_decompositions
+    if decomp not in disabled_decompositions
+}
+CUSTOM_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}
 
 
 def replace_inplace_op(aten_op, outplace_op):
@@ -13,7 +190,7 @@ def replace_inplace_op(aten_op, outplace_op):
     https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
     """
 
-    @register_decomposition(aten_op, registry=DECOMPOSITIONS)
+    @register_decomposition(aten_op, registry=CUSTOM_DECOMPOSITIONS)
     def inplace_op(*args, **kwargs):
         out = outplace_op(*args, **kwargs)
         return args[0].copy_(out)
@@ -36,32 +213,32 @@ def inplace_op(*args, **kwargs):
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
 
 
-@register_decomposition(aten.std, registry=DECOMPOSITIONS)
+@register_decomposition(aten.std, registry=CUSTOM_DECOMPOSITIONS)
 def std_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.sqrt(torch.var(*args, **kwargs))
 
 
-@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)
+@register_decomposition(aten.rsqrt, registry=CUSTOM_DECOMPOSITIONS)
 def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
 
 
-@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS)
+@register_decomposition(aten._unsafe_view, registry=CUSTOM_DECOMPOSITIONS)
 def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
     return torch.reshape(x, *args, **kwargs)
 
 
-@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS)
+@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=CUSTOM_DECOMPOSITIONS)
 def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(aten.alias, registry=DECOMPOSITIONS)
+@register_decomposition(aten.alias, registry=CUSTOM_DECOMPOSITIONS)
 def alias_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
+@register_decomposition(torch.ops.aten.addmm, registry=CUSTOM_DECOMPOSITIONS)
 def addmm_replacement(
     input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
 ) -> torch.Tensor:
@@ -70,12 +247,31 @@ def addmm_replacement(
     )
 
 
-@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS)
+@register_decomposition(
+    torch.ops.aten.reciprocal.default, registry=CUSTOM_DECOMPOSITIONS
+)
 def reciprocal_replacement(
     input_: torch.Tensor,
 ) -> torch.Tensor:
     return torch.div(1, input_)
 
 
-def get_decompositions():
-    return DECOMPOSITIONS
+def get_decompositions(
+    enable_experimental_decompositions: bool = False,
+) -> Dict[torch._ops.OpOverload, Callable]:
+    if enable_experimental_decompositions:
+        duplicate_registrations = set(
+            TORCH_EXPERIMENTAL_DECOMPOSITIONS.keys()
+        ).intersection(set(CUSTOM_DECOMPOSITIONS.keys()))
+        assert (
+            not duplicate_registrations
+        ), f"Detected duplicate decompositions on: {duplicate_registrations}"
+        return {**TORCH_EXPERIMENTAL_DECOMPOSITIONS, **CUSTOM_DECOMPOSITIONS}
+    else:
+        duplicate_registrations = set(TORCH_DECOMPOSITIONS.keys()).intersection(
+            set(CUSTOM_DECOMPOSITIONS.keys())
+        )
+        assert (
+            not duplicate_registrations
+        ), f"Detected duplicate decompositions on: {duplicate_registrations}"
+        return {**TORCH_DECOMPOSITIONS, **CUSTOM_DECOMPOSITIONS}
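Both branches of `get_decompositions` follow the same pattern: assert that the Torch-provided table and the custom registry are disjoint, then combine them with a dict merge. A sketch of that invariant (the names come from the diff above; importing them from the private module is an assumption):

from torch_tensorrt.dynamo.lowering._decompositions import (
    CUSTOM_DECOMPOSITIONS,
    TORCH_DECOMPOSITIONS,
    get_decompositions,
)

table = get_decompositions(enable_experimental_decompositions=False)

# The merged table is exactly the union of the two registries; the assert
# inside get_decompositions fires if a custom decomposition ever shadows
# a Torch-provided one.
assert set(table) == set(TORCH_DECOMPOSITIONS) | set(CUSTOM_DECOMPOSITIONS)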
