Skip to content

Commit 5ef8427

Browse files
hsharma35 authored and facebook-github-bot committed
Extend constant prop pass to work with int/float/etc scalars and fix input specs. (#2950)
Summary: Pull Request resolved: #2950 1. Cleanup / Refactor constant prop pass. 2. Enable constant propagation for ops with constant scalar arguments -- int/float/dtype/bool/str. Nodes of type `Op(constant_tensor, some_int, some_float, some_dtype, ...)` can now be constant propagated. 3. Fix order of input spec to match the expected spec in `ExportGraphSignature` class. parameters->buffers->constants->user_inputs. Before this diff, input_specs for the newly added constant tensors were appended to graph_signature, which would cause failures. Reviewed By: dulinriley Differential Revision: D55891278 fbshipit-source-id: fe1867cb6a99d0140d6a2e076027688cb1ddc0cd
1 parent 3b727a7 commit 5ef8427

File tree

3 files changed

+362
-84
lines changed

3 files changed

+362
-84
lines changed

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ python_library(
9292
],
9393
deps = [
9494
"//caffe2:torch",
95+
"//executorch/exir/dialects:lib",
96+
"//executorch/exir/dialects/edge:lib",
9597
],
9698
)
9799

exir/passes/constant_prop_pass.py

Lines changed: 259 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,145 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from collections import OrderedDict
8+
from typing import cast, Mapping, Optional
9+
710
import torch
8-
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
13+
from torch._export.utils import (
14+
get_buffer,
15+
get_lifted_tensor_constant,
16+
get_param,
17+
is_buffer,
18+
is_lifted_tensor_constant,
19+
is_param,
20+
)
921
from torch._guards import detect_fake_mode
1022
from torch.export import ExportedProgram
1123
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
24+
from torch.utils import _pytree as pytree
25+
26+
27+
# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
28+
# Propagating aten.full can significantly increase compiled model size.
29+
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
1230

31+
_PRIMITIVE_TYPES = (
32+
float,
33+
int,
34+
bool,
35+
str,
36+
torch.Tensor,
37+
torch.device,
38+
torch.dtype,
39+
torch.layout,
40+
)
1341

14-
def is_const(arg, exported_program, const_data_list) -> bool:
42+
43+
def is_const(
44+
arg,
45+
exported_program: ExportedProgram,
46+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
47+
) -> bool:
1548
if isinstance(arg, (tuple, list)):
16-
return all(is_const(x, exported_program, const_data_list) for x in arg)
49+
return all(is_const(x, exported_program, const_node_to_tensor) for x in arg)
1750
elif isinstance(arg, dict):
18-
return all(is_const(x, exported_program, const_data_list) for x in arg.values())
19-
elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
51+
return all(
52+
is_const(x, exported_program, const_node_to_tensor) for x in arg.values()
53+
)
54+
elif isinstance(arg, _PRIMITIVE_TYPES):
55+
return True
56+
elif not isinstance(arg, torch.fx.Node):
2057
return False
21-
elif (
22-
is_param(exported_program, arg)
23-
or is_buffer(exported_program, arg)
24-
or arg.name in const_data_list
25-
):
58+
elif arg in const_node_to_tensor:
2659
return True
2760
return False
2861

2962

30-
def get_data(exported_program, arg):
63+
def get_data(
64+
arg,
65+
exported_program: ExportedProgram,
66+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
67+
):
3168
if isinstance(arg, (tuple, list)):
32-
return [get_data(exported_program, x) for x in arg]
33-
elif is_param(exported_program, arg):
34-
return get_param(exported_program, arg)
35-
elif is_buffer(exported_program, arg):
36-
return get_buffer(exported_program, arg)
69+
return type(arg)(
70+
get_data(x, exported_program, const_node_to_tensor) for x in arg
71+
)
72+
elif isinstance(arg, _PRIMITIVE_TYPES):
73+
return arg
74+
elif arg in const_node_to_tensor:
75+
return const_node_to_tensor[arg]
3776
return None
3877

3978

40-
def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
79+
def get_constant_placeholder_dict(
80+
exported_program: ExportedProgram,
81+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
4182
"""
42-
This pass is for constant propagation for Exported Program with lifted parameters,
43-
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
83+
Returns a dictionary of placeholder node -> constant tensor.
4484
"""
45-
if (
46-
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
47-
== 0
48-
):
49-
return exported_program
85+
const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
86+
for node in exported_program.graph.nodes:
87+
if node.op != "placeholder":
88+
continue
89+
90+
if is_param(exported_program, node):
91+
const_node_to_tensor[node] = cast(
92+
torch.Tensor, get_param(exported_program, node)
93+
)
94+
elif is_buffer(exported_program, node):
95+
const_node_to_tensor[node] = cast(
96+
torch.Tensor, get_buffer(exported_program, node)
97+
)
98+
elif is_lifted_tensor_constant(exported_program, node):
99+
const_node_to_tensor[node] = cast(
100+
torch.Tensor, get_lifted_tensor_constant(exported_program, node)
101+
)
102+
return const_node_to_tensor
50103

51-
has_cond = [
52-
node
53-
for node in exported_program.graph.nodes
54-
if node.target == torch.ops.higher_order.cond
55-
]
56-
if len(has_cond) > 0:
57-
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
58104

105+
def get_propagated_const_tensor_dict(
106+
exported_program: ExportedProgram,
107+
custom_skip_targets: Optional[set[EdgeOpOverload]],
108+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
109+
"""
110+
Propagates constants and returns a dictionary of node->constant tensors.
111+
"""
112+
# Initialize dict with all constant placeholders.
113+
const_node_to_tensor = get_constant_placeholder_dict(exported_program)
114+
115+
all_skip_targets: set[EdgeOpOverload] = set()
116+
# Default set of targets to skip.
117+
all_skip_targets.update(_DEFAULT_SKIP_TARGETS)
118+
if custom_skip_targets is not None:
119+
all_skip_targets.update(custom_skip_targets)
120+
121+
for node in exported_program.graph.nodes:
122+
if node.op != "call_function" or node.target in all_skip_targets:
123+
continue
124+
125+
if not is_const(
126+
node.args,
127+
exported_program,
128+
const_node_to_tensor,
129+
):
130+
continue
131+
132+
args_data, kwargs_data = pytree.tree_map(
133+
lambda x: get_data(x, exported_program, const_node_to_tensor),
134+
(node.args, node.kwargs),
135+
)
136+
137+
# Execute the `node.target` and create a new propagated constant tensor.
138+
prop_constant_tensor = node.target(*args_data, **kwargs_data)
139+
const_node_to_tensor[node] = prop_constant_tensor
140+
141+
return const_node_to_tensor
142+
143+
144+
def get_first_user_input(exported_program: ExportedProgram) -> torch.fx.Node:
145+
"""Returns the first user input node in the graph."""
59146
first_user_input = None
60147
for node in exported_program.graph.nodes:
61148
if (
@@ -64,11 +151,42 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
64151
):
65152
first_user_input = node
66153
break
154+
return first_user_input
155+
156+
157+
def replace_with_constant_node(
158+
node: torch.fx.Node,
159+
prop_constant_tensor: torch.Tensor,
160+
first_user_input: torch.fx.Node,
161+
fake_mode,
162+
exported_program: ExportedProgram,
163+
) -> tuple[torch.fx.Node, str]:
164+
# Add `prop_constant_tensor` to program.state_dict.
165+
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(exported_program.constants)}"
166+
exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
167+
168+
# Insert a new placeholder node for the propagated constant tensor.
169+
with exported_program.graph.inserting_before(first_user_input):
170+
const_placeholder_node = exported_program.graph.placeholder(
171+
prop_constant_tensor_fqn
172+
)
173+
174+
# Update the meta data of the new placeholder (buffer) node.
175+
for k, v in node.meta.items():
176+
const_placeholder_node.meta[k] = v
177+
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
178+
prop_constant_tensor, static_shapes=True
179+
)
180+
const_placeholder_node.meta["val"].constant = prop_constant_tensor
181+
182+
# Replace the original node with the new constant node.
183+
node.replace_all_uses_with(const_placeholder_node)
184+
exported_program.graph.erase_node(node)
185+
186+
return const_placeholder_node, prop_constant_tensor_fqn
67187

68-
buffers = exported_program.graph_signature.buffers
69-
prop_constant_data = []
70-
const_data_to_be_removed = set()
71188

189+
def get_fake_mode(exported_program: ExportedProgram):
72190
fake_mode = detect_fake_mode(
73191
tuple(
74192
node.meta["val"]
@@ -77,57 +195,115 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
77195
)
78196
)
79197
assert fake_mode is not None
198+
return fake_mode
80199

200+
201+
def erase_constant_node(
202+
exported_program: ExportedProgram,
203+
node: torch.fx.Node,
204+
) -> None:
205+
# Remove corresponding tensor from param/constants dict.
206+
signature = exported_program.graph_signature
207+
if name := signature.inputs_to_parameters.pop(node.name, None):
208+
exported_program.state_dict.pop(name, None)
209+
elif name := signature.inputs_to_lifted_tensor_constants.pop(node.name, None):
210+
exported_program.constants.pop(name, None)
211+
elif name := signature.inputs_to_buffers.pop(node.name, None):
212+
exported_program.constants.pop(name, None)
213+
exported_program.state_dict.pop(name, None)
214+
215+
# Remove from graph.
216+
exported_program.graph.erase_node(node)
217+
218+
219+
def create_constant_nodes_and_return_specs(
220+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
221+
exported_program: ExportedProgram,
222+
) -> dict[str, InputSpec]:
223+
"""
224+
Creates constant nodes for all entries in `const_node_to_tensor` and returns a node.name -> InputSpec dict.
225+
"""
226+
name_to_spec_dict: dict[str, InputSpec] = {}
227+
228+
fake_mode = get_fake_mode(exported_program)
229+
first_user_input = get_first_user_input(exported_program)
230+
231+
# Iterate over nodes in reverse order.
232+
for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
233+
if all(x in const_node_to_tensor for x in node.users):
234+
# All users of this constant node are also constant, so we don't need to create a new constant node.
235+
erase_constant_node(exported_program, node)
236+
continue
237+
238+
if node.op == "placeholder":
239+
continue
240+
241+
const_placeholder_node, prop_constant_tensor_fqn = replace_with_constant_node(
242+
node, prop_constant_tensor, first_user_input, fake_mode, exported_program
243+
)
244+
245+
# Create input spec for lifted constant.
246+
name_to_spec_dict[const_placeholder_node.name] = InputSpec(
247+
kind=InputKind.CONSTANT_TENSOR,
248+
arg=TensorArgument(name=const_placeholder_node.name),
249+
target=prop_constant_tensor_fqn,
250+
persistent=True,
251+
)
252+
return name_to_spec_dict
253+
254+
255+
def constant_prop_pass(
256+
exported_program: ExportedProgram,
257+
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
258+
) -> ExportedProgram:
259+
"""
260+
This pass is for constant propagation for Exported Program with lifted parameters,
261+
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
262+
263+
Args:
264+
exported_program: The ExportedProgram to perform constant propagation on.
265+
custom_skip_targets: Optional set of EdgeOpOverload targets to skip during constant propagation.
266+
267+
Returns:
268+
The modified ExportedProgram with constant propagation applied.
269+
"""
270+
if (
271+
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
272+
== 0
273+
):
274+
return exported_program
275+
276+
has_control_flow = [
277+
node
278+
for node in exported_program.graph.nodes
279+
if node.target == torch.ops.higher_order.cond
280+
]
281+
if len(has_control_flow) > 0:
282+
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
283+
284+
const_node_to_tensor = get_propagated_const_tensor_dict(
285+
exported_program, custom_skip_targets
286+
)
287+
288+
# Get old input specs.
289+
name_to_spec_dict = {
290+
s.arg.name: s for s in exported_program.graph_signature.input_specs
291+
}
292+
# Add the new constants to input specs dict.
293+
name_to_spec_dict.update(
294+
create_constant_nodes_and_return_specs(const_node_to_tensor, exported_program)
295+
)
296+
297+
# Generate new input spec.
298+
new_input_specs = []
81299
for node in exported_program.graph.nodes:
82-
if node.op == "call_function":
83-
constant_data_name_list = [
84-
input_spec.target for input_spec in prop_constant_data
85-
]
86-
if is_const(node.args, exported_program, constant_data_name_list):
87-
args_data = [get_data(exported_program, arg) for arg in node.args]
88-
kwargs_data = node.kwargs
89-
const_data_to_be_removed.update(node.args)
90-
prop_constant_tensor = node.target(*args_data, **kwargs_data)
91-
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(buffers)}"
92-
93-
with exported_program.graph.inserting_before(first_user_input):
94-
const_placeholder_node = exported_program.graph.placeholder(
95-
prop_constant_tensor_fqn
96-
)
97-
# Update the meta data of the new placeholder (buffer) node
98-
for k, v in node.meta.items():
99-
const_placeholder_node.meta[k] = v
100-
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
101-
prop_constant_tensor, static_shapes=True
102-
)
103-
const_placeholder_node.meta["val"].constant = prop_constant_tensor
104-
105-
node.replace_all_uses_with(const_placeholder_node)
106-
exported_program.graph.erase_node(node)
107-
prop_constant_node_input_spec = InputSpec(
108-
kind=InputKind.BUFFER,
109-
arg=TensorArgument(name=const_placeholder_node.name),
110-
target=prop_constant_tensor_fqn,
111-
persistent=True,
112-
)
113-
prop_constant_data.append(prop_constant_node_input_spec)
114-
buffers.append(prop_constant_tensor_fqn)
115-
exported_program.state_dict[prop_constant_tensor_fqn] = (
116-
prop_constant_tensor
117-
)
118-
exported_program.graph_signature.input_specs.append(
119-
prop_constant_node_input_spec
120-
)
121-
122-
# Remove the propagated buffer from the state dict
122-
for node in exported_program.graph.nodes:
124-
if (
125-
node.op == "placeholder"
126-
and node in const_data_to_be_removed
127-
and len(node.users) == 0
128-
):
129-
exported_program.state_dict.pop(node.name, None)
130-
exported_program.graph.erase_node(node)
300+
if node.op != "placeholder":
301+
continue
302+
new_input_specs.append(name_to_spec_dict[node.name])
303+
exported_program.graph_signature.input_specs = new_input_specs
131304

305+
# Cleanup the graph.
306+
exported_program.graph.eliminate_dead_code()
132307
exported_program.graph_module.recompile()
308+
133309
return exported_program

0 commit comments

Comments
 (0)