+ from typing import Callable, Dict, Set

import torch
- from torch._decomp import register_decomposition, core_aten_decompositions
+ from torch._decomp import (
+     register_decomposition,
+     core_aten_decompositions,
+     get_decompositions as get_torch_decompositions,
+ )

+ aten = torch.ops.aten

- DECOMPOSITIONS = {**core_aten_decompositions()}
+ _core_aten_decompositions: Dict[
+     torch._ops.OpOverload, Callable
+ ] = core_aten_decompositions()
+ enabled_decompositions: Set[torch._ops.OpOverload] = {
+     aten._adaptive_avg_pool2d_backward,
+     aten.addcdiv,
+     aten.addcdiv_,
+     aten.addcmul,
+     aten.addcmul_,
+     aten.addr,
+     aten.aminmax,
+     aten.arange.default,
+     aten.arange.start,
+     aten.avg_pool2d_backward,
+     aten.binary_cross_entropy,
+     aten.binary_cross_entropy_backward,
+     aten.binary_cross_entropy_with_logits,
+     aten.celu,
+     aten.col2im,
+     aten.count_nonzero,
+     aten.cudnn_batch_norm,
+     aten.cudnn_batch_norm_backward,
+     aten.deg2rad,
+     aten.detach,
+     aten.diag_embed,
+     aten.diagonal_backward,
+     aten.dot,
+     aten.elu,
+     aten.elu_backward,
+     aten._embedding_bag,
+     aten.embedding_dense_backward,
+     aten._euclidean_dist.default,
+     aten.expand_as,
+     aten.eye,
+     aten.fill,
+     aten.frac,
+     aten._fused_moving_avg_obs_fq_helper,
+     aten.gelu,
+     aten.gelu_backward,
+     aten.glu_backward,
+     aten.grid_sampler_2d,
+     aten.hardshrink,
+     aten.hardshrink_backward,
+     aten.hardsigmoid,
+     aten.hardsigmoid_backward,
+     aten.hardswish,
+     aten.hardswish_,
+     aten.hardswish_backward,
+     aten.hardtanh,
+     aten.hardtanh_,
+     aten.hardtanh_backward,
+     aten.heaviside,
+     aten.huber_loss,
+     aten.huber_loss_backward,
+     aten.im2col,
+     aten.index_add,
+     aten.index_add_,
+     aten.index_copy,
+     aten.index_copy_,
+     aten.index_fill,
+     aten.index_fill_,
+     aten.index_select,
+     aten.isneginf,
+     aten.isposinf,
+     aten.l1_loss,
+     aten.leaky_relu,
+     aten.leaky_relu_,
+     aten.leaky_relu_backward,
+     aten.lerp,
+     aten.linspace,
+     aten.logaddexp,
+     aten.logaddexp2,
+     aten.logit,
+     aten.logit_backward,
+     aten.log_sigmoid_backward,
+     aten.log_sigmoid_forward,
+     aten._log_softmax,
+     aten._log_softmax_backward_data,
+     aten.logspace,
+     aten.logsumexp.default,
+     aten.masked_fill,
+     aten.masked_fill_,
+     aten.max_pool2d_with_indices_backward,
+     aten.mish,
+     aten.mse_loss,
+     aten.mse_loss_backward,
+     aten.mv,
+     aten.mvlgamma,
+     aten.nansum,
+     aten.nan_to_num,
+     aten.narrow,
+     # TODO: Disable the below operators once freezing is done
+     aten.native_batch_norm,
+     aten.native_batch_norm_backward,
+     aten._native_batch_norm_legit,
+     aten._native_batch_norm_legit_functional,
+     aten._native_batch_norm_legit_no_training,
+     aten.native_dropout_backward,
+     aten.native_group_norm,
+     aten.native_group_norm_backward,
+     aten.native_layer_norm,
+     aten.native_layer_norm_backward,
+     aten.new_empty,
+     aten.new_full,
+     aten.new_ones,
+     aten.new_zeros,
+     aten.nll_loss_backward,
+     aten.nll_loss_forward,
+     aten.norm,
+     aten.ones,
+     aten.ones_like,
+     aten._prelu_kernel,
+     aten._prelu_kernel_backward,
+     aten._reshape_alias,
+     aten.rad2deg,
+     aten.renorm,
+     aten.renorm_,
+     aten.rot90,
+     aten.rsub.Scalar,
+     aten.rsub.Tensor,
+     aten.select_backward,
+     aten.select_scatter,
+     aten.sgn,
+     aten.sigmoid_backward,
+     aten.silu,
+     aten.silu_,
+     aten.silu_backward,
+     aten.sinc,
+     aten.slice_backward,
+     aten.smooth_l1_loss,
+     aten.smooth_l1_loss_backward,
+     aten.soft_margin_loss,
+     aten.soft_margin_loss_backward,
+     aten._softmax,
+     aten._softmax_backward_data,
+     aten.softplus,
+     aten.softplus_backward,
+     aten.softshrink,
+     aten.softshrink_backward,
+     aten.special_entr,
+     aten.special_log_ndtr,
+     aten.special_xlog1py,
+     aten.stack,
+     aten.t,
+     aten.tanh_backward,
+     aten.threshold,
+     aten.threshold_backward,
+     aten.trace,
+     aten.transpose.int,
+     aten.tril.default,
+     aten.triu.default,
+     aten.unfold,
+     aten.unfold_backward,
+     aten.unfold_copy,
+     aten.upsample_bilinear2d,
+     aten.upsample_bilinear2d.vec,
+     aten.upsample_nearest2d_backward,
+     aten.xlogy,
+     aten.zero,
+     aten.zero_,
+     aten.zeros,
+     aten.zeros_like,
+ }
+ disabled_decompositions: Set[torch._ops.OpOverload] = {}

- aten = torch.ops.aten
+ TORCH_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = get_torch_decompositions(
+     enabled_decompositions
+ )
+ TORCH_EXPERIMENTAL_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {
+     decomp: _core_aten_decompositions[decomp]
+     for decomp in _core_aten_decompositions
+     if decomp not in disabled_decompositions
+ }
+ CUSTOM_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}


def replace_inplace_op(aten_op, outplace_op):
@@ -13,7 +190,7 @@ def replace_inplace_op(aten_op, outplace_op):
    https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
    """

-     @register_decomposition(aten_op, registry=DECOMPOSITIONS)
+     @register_decomposition(aten_op, registry=CUSTOM_DECOMPOSITIONS)
    def inplace_op(*args, **kwargs):
        out = outplace_op(*args, **kwargs)
        return args[0].copy_(out)
@@ -36,32 +213,32 @@ def inplace_op(*args, **kwargs):
replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)


- @register_decomposition(aten.std, registry=DECOMPOSITIONS)
+ @register_decomposition(aten.std, registry=CUSTOM_DECOMPOSITIONS)
def std_replacement(*args, **kwargs) -> torch.Tensor:
    return torch.sqrt(torch.var(*args, **kwargs))


- @register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)
+ @register_decomposition(aten.rsqrt, registry=CUSTOM_DECOMPOSITIONS)
def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
    return torch.reciprocal(torch.sqrt(*args, **kwargs))


- @register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS)
+ @register_decomposition(aten._unsafe_view, registry=CUSTOM_DECOMPOSITIONS)
def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    return torch.reshape(x, *args, **kwargs)


- @register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS)
+ @register_decomposition(torch.ops.aten.lift_fresh_copy, registry=CUSTOM_DECOMPOSITIONS)
def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
    return x


- @register_decomposition(aten.alias, registry=DECOMPOSITIONS)
+ @register_decomposition(aten.alias, registry=CUSTOM_DECOMPOSITIONS)
def alias_replacement(x: torch.Tensor) -> torch.Tensor:
    return x


- @register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
+ @register_decomposition(torch.ops.aten.addmm, registry=CUSTOM_DECOMPOSITIONS)
def addmm_replacement(
    input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
) -> torch.Tensor:
@@ -70,12 +247,31 @@ def addmm_replacement(
    )


- @register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS)
+ @register_decomposition(
+     torch.ops.aten.reciprocal.default, registry=CUSTOM_DECOMPOSITIONS
+ )
def reciprocal_replacement(
    input_: torch.Tensor,
) -> torch.Tensor:
    return torch.div(1, input_)


- def get_decompositions():
-     return DECOMPOSITIONS
+ def get_decompositions(
+     enable_experimental_decompositions: bool = False,
+ ) -> Dict[torch._ops.OpOverload, Callable]:
+     if enable_experimental_decompositions:
+         duplicate_registrations = set(
+             TORCH_EXPERIMENTAL_DECOMPOSITIONS.keys()
+         ).intersection(set(CUSTOM_DECOMPOSITIONS.keys()))
+         assert (
+             not duplicate_registrations
+         ), f"Detected duplicate decompositions on: {duplicate_registrations}"
+         return {**TORCH_EXPERIMENTAL_DECOMPOSITIONS, **CUSTOM_DECOMPOSITIONS}
+     else:
+         duplicate_registrations = set(TORCH_DECOMPOSITIONS.keys()).intersection(
+             set(CUSTOM_DECOMPOSITIONS.keys())
+         )
+         assert (
+             not duplicate_registrations
+         ), f"Detected duplicate decompositions on: {duplicate_registrations}"
+         return {**TORCH_DECOMPOSITIONS, **CUSTOM_DECOMPOSITIONS}
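
A quick way to see the merged table in action: `make_fx` accepts a `decomposition_table`, so tracing a function through it shows the custom replacements being applied. This is a standalone sketch for illustration only; the Dynamo backend wires the table in elsewhere, and having `get_decompositions` from this module in scope is assumed.

```python
# Illustrative only, not part of this change. Assumes get_decompositions
# (defined above) is in scope.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x: torch.Tensor) -> torch.Tensor:
    return torch.std(x)


# Tracing with the table applies std_replacement: the resulting graph
# contains aten.var / aten.sqrt nodes rather than a single aten.std node.
gm = make_fx(f, decomposition_table=get_decompositions())(torch.randn(8))
print(gm.graph)
```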
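The duplicate-registration asserts are new in this change. A hypothetical demonstration of what they catch, by mutating `CUSTOM_DECOMPOSITIONS` directly (which user code normally would not do): `aten.gelu` is in `enabled_decompositions`, so its overload should already be a key of `TORCH_DECOMPOSITIONS`.

```python
# Hypothetical demonstration, not part of this change: adding a custom
# decomposition for an op that already has a stock one trips the assert.
CUSTOM_DECOMPOSITIONS[aten.gelu.default] = lambda x, approximate="none": x
try:
    get_decompositions()
except AssertionError as err:
    print(err)  # Detected duplicate decompositions on: {... gelu.default ...}
finally:
    # Restore module state so later callers are unaffected.
    del CUSTOM_DECOMPOSITIONS[aten.gelu.default]
```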