# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+ from collections import OrderedDict
8
+ from typing import cast , Mapping , Optional
9
+
7
10
import torch
8
- from torch ._export .utils import get_buffer , get_param , is_buffer , is_param
11
+ from executorch .exir .dialects ._ops import ops as exir_ops
12
+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
13
+ from torch ._export .utils import (
14
+ get_buffer ,
15
+ get_lifted_tensor_constant ,
16
+ get_param ,
17
+ is_buffer ,
18
+ is_lifted_tensor_constant ,
19
+ is_param ,
20
+ )
9
21
from torch ._guards import detect_fake_mode
10
22
from torch .export import ExportedProgram
11
23
from torch .export .exported_program import InputKind , InputSpec , TensorArgument
24
+ from torch .utils import _pytree as pytree
25
+
26
+
27
# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
# Propagating aten.full can significantly increase compiled model size:
# the op materializes a whole tensor from a scalar, so folding it embeds
# that tensor in the program instead of a cheap fill at runtime.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
12
30
31
# Leaf types that always count as "constant" when deciding whether a node's
# arguments are fully constant (used by `is_const` / `get_data` below).
_PRIMITIVE_TYPES = (
    float,
    int,
    bool,
    str,
    torch.Tensor,
    torch.device,
    torch.dtype,
    torch.layout,
)
13
41
14
def is_const(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
) -> bool:
    """Return True iff `arg` is known to be constant.

    Containers (tuple/list/dict) are constant when every element/value is
    constant; primitive leaves are always constant; an fx Node is constant
    only if it already appears in `const_node_to_tensor`.
    """
    # Recurse into sequence containers.
    if isinstance(arg, (tuple, list)):
        return all(
            is_const(item, exported_program, const_node_to_tensor) for item in arg
        )
    # Recurse into mapping values (keys are never Nodes here).
    if isinstance(arg, dict):
        return all(
            is_const(value, exported_program, const_node_to_tensor)
            for value in arg.values()
        )
    # Scalar-like leaves are constant by definition.
    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    # Anything that is neither a leaf nor a graph node is unknown -> not const.
    if not isinstance(arg, torch.fx.Node):
        return False
    # A graph node is constant iff it was already propagated.
    return arg in const_node_to_tensor
28
61
29
62
30
def get_data(
    arg,
    exported_program: ExportedProgram,
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
):
    """Return the concrete value backing `arg`, or None when it is unknown.

    Sequences are rebuilt with their original container type; primitive
    leaves are returned as-is; propagated nodes resolve to their tensor.
    """
    if isinstance(arg, (tuple, list)):
        # Preserve the container type (tuple stays tuple, list stays list).
        resolved = (
            get_data(item, exported_program, const_node_to_tensor) for item in arg
        )
        return type(arg)(resolved)
    if isinstance(arg, _PRIMITIVE_TYPES):
        return arg
    if arg in const_node_to_tensor:
        return const_node_to_tensor[arg]
    # Non-constant / unrecognized argument: no data available.
    return None
38
77
39
78
40
- def constant_prop_pass (exported_program : ExportedProgram ) -> ExportedProgram :
79
+ def get_constant_placeholder_dict (
80
+ exported_program : ExportedProgram ,
81
+ ) -> OrderedDict [torch .fx .Node , torch .Tensor ]:
41
82
"""
42
- This pass is for constant propagation for Exported Program with lifted parameters,
43
- as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
83
+ Returns a dictionary of placeholder node -> constant tensor.
44
84
"""
45
- if (
46
- len ([node for node in exported_program .graph .nodes if node .op == "placeholder" ])
47
- == 0
48
- ):
49
- return exported_program
85
+ const_node_to_tensor : OrderedDict [torch .fx .Node , torch .Tensor ] = OrderedDict ()
86
+ for node in exported_program .graph .nodes :
87
+ if node .op != "placeholder" :
88
+ continue
89
+
90
+ if is_param (exported_program , node ):
91
+ const_node_to_tensor [node ] = cast (
92
+ torch .Tensor , get_param (exported_program , node )
93
+ )
94
+ elif is_buffer (exported_program , node ):
95
+ const_node_to_tensor [node ] = cast (
96
+ torch .Tensor , get_buffer (exported_program , node )
97
+ )
98
+ elif is_lifted_tensor_constant (exported_program , node ):
99
+ const_node_to_tensor [node ] = cast (
100
+ torch .Tensor , get_lifted_tensor_constant (exported_program , node )
101
+ )
102
+ return const_node_to_tensor
50
103
51
- has_cond = [
52
- node
53
- for node in exported_program .graph .nodes
54
- if node .target == torch .ops .higher_order .cond
55
- ]
56
- if len (has_cond ) > 0 :
57
- raise RuntimeError ("constant_prop_pass for control flow is not supported yet." )
58
104
105
def get_propagated_const_tensor_dict(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]],
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """
    Propagates constants and returns a dictionary of node->constant tensors.

    Args:
        exported_program: program whose graph is traversed (the graph itself
            is not mutated here; only the returned mapping is built).
        custom_skip_targets: extra op targets that must never be folded,
            merged with `_DEFAULT_SKIP_TARGETS`.

    Returns:
        Ordered mapping from graph node to the constant tensor it evaluates
        to, seeded with all constant placeholders (params/buffers/lifted
        tensor constants).
    """
    # Initialize dict with all constant placeholders.
    const_node_to_tensor = get_constant_placeholder_dict(exported_program)

    all_skip_targets: set[EdgeOpOverload] = set()
    # Default set of targets to skip.
    all_skip_targets.update(_DEFAULT_SKIP_TARGETS)
    if custom_skip_targets is not None:
        all_skip_targets.update(custom_skip_targets)

    for node in exported_program.graph.nodes:
        if node.op != "call_function" or node.target in all_skip_targets:
            continue

        # Fix: check kwargs for const-ness as well as args. Previously only
        # `node.args` was checked, but kwargs are also resolved via `get_data`
        # below, so a non-constant Node in kwargs would silently become None.
        if not is_const(
            (node.args, node.kwargs),
            exported_program,
            const_node_to_tensor,
        ):
            continue

        args_data, kwargs_data = pytree.tree_map(
            lambda x: get_data(x, exported_program, const_node_to_tensor),
            (node.args, node.kwargs),
        )

        # Execute the `node.target` and create a new propagated constant
        # tensor. Run under no_grad so results computed from parameters
        # (requires_grad=True) do not drag an autograd graph into the
        # stored constants.
        with torch.no_grad():
            prop_constant_tensor = node.target(*args_data, **kwargs_data)

        # Only single-tensor results can later be lifted into constant
        # placeholders; skip ops returning tuples/lists (multi-output ops)
        # rather than recording a value that would break
        # `replace_with_constant_node` downstream.
        if not isinstance(prop_constant_tensor, torch.Tensor):
            continue

        const_node_to_tensor[node] = prop_constant_tensor

    return const_node_to_tensor
142
+
143
+
144
+ def get_first_user_input (exported_program : ExportedProgram ) -> torch .fx .Node :
145
+ """Returns the first user input node in the graph."""
59
146
first_user_input = None
60
147
for node in exported_program .graph .nodes :
61
148
if (
@@ -64,11 +151,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64
151
):
65
152
first_user_input = node
66
153
break
154
+ return first_user_input
155
+
156
+
157
+ def replace_with_constant_node (
158
+ node : torch .fx .Node ,
159
+ prop_constant_tensor : torch .Tensor ,
160
+ first_user_input : torch .fx .Node ,
161
+ fake_mode ,
162
+ exported_program : ExportedProgram ,
163
+ ) -> tuple [torch .fx .Node , str ]:
164
+ # Add `prop_constant_tensor` to program.state_dict.
165
+ prop_constant_tensor_fqn = f"_prop_tensor_constant{ len (exported_program .constants )} "
166
+ exported_program .constants [prop_constant_tensor_fqn ] = prop_constant_tensor
167
+
168
+ # Insert a new placeholder node for the propagated constant tensor.
169
+ with exported_program .graph .inserting_before (first_user_input ):
170
+ const_placeholder_node = exported_program .graph .placeholder (
171
+ prop_constant_tensor_fqn
172
+ )
173
+
174
+ # Update the meta data of the new placeholder (buffer) node.
175
+ for k , v in node .meta .items ():
176
+ const_placeholder_node .meta [k ] = v
177
+ const_placeholder_node .meta ["val" ] = fake_mode .from_tensor (
178
+ prop_constant_tensor , static_shapes = True
179
+ )
180
+ const_placeholder_node .meta ["val" ].constant = prop_constant_tensor
181
+
182
+ # Replace the original node with the new constant node.
183
+ node .replace_all_uses_with (const_placeholder_node )
184
+ exported_program .graph .erase_node (node )
185
+
186
+ return const_placeholder_node , prop_constant_tensor_fqn
67
187
68
- buffers = exported_program .graph_signature .buffers
69
- prop_constant_data = []
70
- const_data_to_be_removed = set ()
71
188
189
+ def get_fake_mode (exported_program : ExportedProgram ):
72
190
fake_mode = detect_fake_mode (
73
191
tuple (
74
192
node .meta ["val" ]
@@ -77,57 +195,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77
195
)
78
196
)
79
197
assert fake_mode is not None
198
+ return fake_mode
80
199
200
+
201
+ def erase_constant_node (
202
+ exported_program : ExportedProgram ,
203
+ node : torch .fx .Node ,
204
+ ) -> None :
205
+ # Remove corresponding tensor from param/constants dict.
206
+ signature = exported_program .graph_signature
207
+ if name := signature .inputs_to_parameters .pop (node .name , None ):
208
+ exported_program .state_dict .pop (name , None )
209
+ elif name := signature .inputs_to_lifted_tensor_constants .pop (node .name , None ):
210
+ exported_program .constants .pop (name , None )
211
+ elif name := signature .inputs_to_buffers .pop (node .name , None ):
212
+ exported_program .constants .pop (name , None )
213
+ exported_program .state_dict .pop (name , None )
214
+
215
+ # Remove from graph.
216
+ exported_program .graph .erase_node (node )
217
+
218
+
219
def create_constant_nodes_and_return_specs(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> dict[str, InputSpec]:
    """
    Creates constant nodes for all entries in `const_node_to_tensor` and returns a node.name -> InputSpec dict.

    Mutates `exported_program` in place: fully-folded constant nodes are
    erased, and each remaining propagated constant call_function node is
    replaced with a new constant placeholder whose InputSpec is returned
    for the caller to merge into the graph signature.
    """
    name_to_spec_dict: dict[str, InputSpec] = {}

    fake_mode = get_fake_mode(exported_program)
    first_user_input = get_first_user_input(exported_program)

    # Iterate over nodes in reverse order so that a node's users are visited
    # (and possibly erased/replaced) before the node itself; erasing a node
    # requires it to have no remaining users.
    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
        if all(x in const_node_to_tensor for x in node.users):
            # All users of this constant node are also constant, so we don't need to create a new constant node.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            # Existing constant placeholders (params/buffers/lifted tensor
            # constants) that still have non-constant users stay untouched.
            continue

        # Lift the propagated tensor into a new constant placeholder that
        # replaces this call_function node.
        const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
            node, prop_constant_tensor, first_user_input, fake_mode, exported_program
        )

        # Create input spec for lifted constant.
        name_to_spec_dict[const_placeholder_node.name] = InputSpec(
            kind=InputKind.CONSTANT_TENSOR,
            arg=TensorArgument(name=const_placeholder_node.name),
            target=prop_constant_tensor_fqn,
            persistent=True,
        )
    return name_to_spec_dict
253
+
254
+
255
def constant_prop_pass(
    exported_program: ExportedProgram,
    custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
) -> ExportedProgram:
    """
    This pass is for constant propagation for Exported Program with lifted parameters,
    as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.

    Args:
        exported_program: The ExportedProgram to perform constant propagation on.
        custom_skip_targets: Optional set of EdgeOpOverload targets to skip during constant propagation.

    Returns:
        The modified ExportedProgram with constant propagation applied.
    """
    graph = exported_program.graph

    # Nothing to do for a graph without placeholders.
    if not any(node.op == "placeholder" for node in graph.nodes):
        return exported_program

    # Control flow is not supported yet.
    if any(node.target == torch.ops.higher_order.cond for node in graph.nodes):
        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")

    const_node_to_tensor = get_propagated_const_tensor_dict(
        exported_program, custom_skip_targets
    )

    # Start from the existing input specs, then fold in specs for the newly
    # created constant placeholders.
    name_to_spec_dict = {
        spec.arg.name: spec for spec in exported_program.graph_signature.input_specs
    }
    name_to_spec_dict.update(
        create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
    )

    # Rebuild the input specs in current placeholder order.
    exported_program.graph_signature.input_specs = [
        name_to_spec_dict[node.name]
        for node in graph.nodes
        if node.op == "placeholder"
    ]

    # Cleanup the graph.
    graph.eliminate_dead_code()
    exported_program.graph_module.recompile()

    return exported_program