Skip to content

Commit c226e6f

Browse files
cccclaifacebook-github-bot
authored andcommitted
Add constant prop pass (#1146)
Summary: Pull Request resolved: #1146 Add a const prop pass for exported program Reviewed By: angelayi Differential Revision: D50961161 fbshipit-source-id: 5f4fc9e45c063d7697a81be4469b4c675ffac80d
1 parent 0fc4383 commit c226e6f

File tree

4 files changed

+225
-0
lines changed

4 files changed

+225
-0
lines changed

exir/passes/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ python_library(
7777
],
7878
)
7979

80+
python_library(
81+
name = "constant_prop_pass",
82+
srcs = [
83+
"constant_prop_pass.py",
84+
],
85+
deps = [
86+
"//caffe2:torch",
87+
],
88+
)
89+
8090
python_library(
8191
name = "remove_assert_async_pass",
8292
srcs = [

exir/passes/constant_prop_pass.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from torch._export import ExportedProgram
9+
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
10+
from torch._guards import detect_fake_mode
11+
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
12+
13+
14+
def is_const(arg, exported_program, const_data_list) -> bool:
15+
if isinstance(arg, (tuple, list)):
16+
return all(is_const(x, exported_program, const_data_list) for x in arg)
17+
elif isinstance(arg, dict):
18+
return all(is_const(x, exported_program, const_data_list) for x in arg.values())
19+
elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
20+
return False
21+
elif (
22+
is_param(exported_program, arg)
23+
or is_buffer(exported_program, arg)
24+
or arg.name in const_data_list
25+
):
26+
return True
27+
return False
28+
29+
30+
def get_data(exported_program, arg):
31+
if isinstance(arg, (tuple, list)):
32+
return [get_data(exported_program, x) for x in arg]
33+
elif is_param(exported_program, arg):
34+
return get_param(exported_program, arg)
35+
elif is_buffer(exported_program, arg):
36+
return get_buffer(exported_program, arg)
37+
return None
38+
39+
40+
def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
41+
"""
42+
This pass is for constant propagation for Exported Program with lifted parameters,
43+
as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
44+
"""
45+
if (
46+
len([node for node in exported_program.graph.nodes if node.op == "placeholder"])
47+
== 0
48+
):
49+
return exported_program
50+
51+
has_cond = [
52+
node
53+
for node in exported_program.graph.nodes
54+
if node.target == torch.ops.higher_order.cond
55+
]
56+
if len(has_cond) > 0:
57+
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
58+
59+
first_user_input = None
60+
for node in exported_program.graph.nodes:
61+
if (
62+
node.op == "placeholder"
63+
and node.name in exported_program.graph_signature.user_inputs
64+
):
65+
first_user_input = node
66+
break
67+
68+
buffers = exported_program.graph_signature.buffers
69+
prop_constant_data = []
70+
const_data_to_be_removed = set()
71+
72+
fake_mode = detect_fake_mode(
73+
tuple(
74+
node.meta["val"]
75+
for node in exported_program.graph.nodes
76+
if node.op == "placeholder"
77+
)
78+
)
79+
assert fake_mode is not None
80+
81+
for node in exported_program.graph.nodes:
82+
if node.op == "call_function":
83+
constant_data_name_list = [
84+
input_spec.target for input_spec in prop_constant_data
85+
]
86+
if is_const(node.args, exported_program, constant_data_name_list):
87+
args_data = [get_data(exported_program, arg) for arg in node.args]
88+
kwargs_data = node.kwargs
89+
const_data_to_be_removed.update(node.args)
90+
prop_constant_tensor = node.target(*args_data, **kwargs_data)
91+
prop_constant_tensor_fqn = f"_prop_tensor_constant{len(buffers)}"
92+
93+
with exported_program.graph.inserting_before(first_user_input):
94+
const_placeholder_node = exported_program.graph.placeholder(
95+
prop_constant_tensor_fqn
96+
)
97+
# Update the meta data of the new placeholder (buffer) node
98+
for k, v in node.meta.items():
99+
const_placeholder_node.meta[k] = v
100+
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
101+
prop_constant_tensor, static_shapes=True
102+
)
103+
const_placeholder_node.meta["val"].constant = prop_constant_tensor
104+
105+
node.replace_all_uses_with(const_placeholder_node)
106+
exported_program.graph.erase_node(node)
107+
prop_constant_node_input_spec = InputSpec(
108+
kind=InputKind.BUFFER,
109+
arg=TensorArgument(name=const_placeholder_node.name),
110+
target=prop_constant_tensor_fqn,
111+
)
112+
prop_constant_data.append(prop_constant_node_input_spec)
113+
buffers.append(prop_constant_tensor_fqn)
114+
exported_program.state_dict[
115+
prop_constant_tensor_fqn
116+
] = prop_constant_tensor
117+
exported_program.graph_signature.input_specs.append(
118+
prop_constant_node_input_spec
119+
)
120+
121+
# Remove the propogated buffer from the state dict
122+
for node in exported_program.graph.nodes:
123+
if (
124+
node.op == "placeholder"
125+
and node in const_data_to_be_removed
126+
and len(node.users) == 0
127+
):
128+
exported_program.state_dict.pop(node.name, None)
129+
exported_program.graph.erase_node(node)
130+
131+
exported_program.graph_module.recompile()
132+
return exported_program

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ python_unittest(
212212
"//executorch/exir/dialects/edge:lib",
213213
"//executorch/exir/emit:lib",
214214
"//executorch/exir/passes:const_prop_pass",
215+
"//executorch/exir/passes:constant_prop_pass",
215216
"//executorch/exir/passes:debug_handle_generator_pass",
216217
"//executorch/exir/passes:lib",
217218
"//executorch/exir/passes:remove_assert_async_pass",

exir/tests/test_passes.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ToOutVarPass,
3333
)
3434
from executorch.exir.passes.const_prop_pass import ConstPropPass
35+
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
3536
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
3637
from executorch.exir.passes.remove_assert_async_pass import RemoveAssertAsyncPass
3738
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
@@ -1057,3 +1058,84 @@ def forward(self, x):
10571058
FileCheck().check(
10581059
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
10591060
).run(gm.code)
1061+
1062+
def test_constant_prop_pass_for_add(self) -> None:
1063+
def add(x: torch.Tensor) -> torch.Tensor:
1064+
return x + 3
1065+
1066+
edge = exir.capture(add, (torch.ones(1),), exir.CaptureConfig(enable_aot=True))
1067+
edge = edge.transform(ScalarToTensorPass(), RemoveMixedTypeOperators())
1068+
edge.exported_program = lift_constant_tensor_pass(edge.exported_program)
1069+
1070+
# Check there is a lifted tensor followed by a to_copy node
1071+
FileCheck().check("_lifted_tensor_constant0").check(
1072+
"torch.ops.aten._to_copy.default"
1073+
).run(edge.exported_program.graph_module.code)
1074+
1075+
new_ep = constant_prop_pass(edge.exported_program)
1076+
1077+
# Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
1078+
FileCheck().check_not("_lifted_tensor_constant").check(
1079+
"_prop_tensor_constant1"
1080+
).check_not("torch.ops.aten._to_copy.default").run(new_ep.graph_module.code)
1081+
1082+
def test_constant_prop_pass_for_parameter(self) -> None:
1083+
def count_additions(gm: torch.fx.GraphModule) -> int:
1084+
return sum(
1085+
(node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
1086+
)
1087+
1088+
class M(torch.nn.Module):
1089+
def __init__(self):
1090+
super().__init__()
1091+
self.a = torch.nn.Parameter(torch.ones(1, 2, 3))
1092+
1093+
def forward(self, x):
1094+
b = self.a + self.a
1095+
c = torch.cat([self.a, b])
1096+
return (c + c) + x
1097+
1098+
edge = exir.capture(
1099+
M(),
1100+
(torch.zeros(2, 2, 3),),
1101+
exir.CaptureConfig(enable_aot=True),
1102+
)
1103+
self.assertEqual(count_additions(edge.exported_program.graph_module), 3)
1104+
edge.exported_program = constant_prop_pass(edge.exported_program)
1105+
self.assertEqual(count_additions(edge.exported_program.graph_module), 1)
1106+
1107+
def test_constant_prop_pass_for_control_flow(self) -> None:
1108+
class Module(torch.nn.Module):
1109+
def __init__(self):
1110+
super().__init__()
1111+
self.linear = torch.nn.Linear(3, 3)
1112+
1113+
def t(self, val):
1114+
return val + 1
1115+
1116+
def f(self, val):
1117+
return val - 1
1118+
1119+
def true_fn(self, val):
1120+
return self.linear(val) + self.t(val)
1121+
1122+
def false_fn(self, val):
1123+
return self.linear(val) - self.f(val)
1124+
1125+
def forward(self, pred, x):
1126+
return torch.ops.higher_order.cond(
1127+
pred, self.true_fn, self.false_fn, [x]
1128+
)
1129+
1130+
mod = Module()
1131+
x = torch.randn([3, 3])
1132+
pred = torch.tensor(x[0][0].item() < 0)
1133+
edge = exir.capture(mod, (pred, x), config=exir.CaptureConfig(enable_aot=True))
1134+
error_msg = r"constant_prop_pass for control flow is not supported yet."
1135+
1136+
# TODO(chenlai): enable constant prop pass for control flow
1137+
with self.assertRaisesRegex(
1138+
RuntimeError,
1139+
error_msg,
1140+
):
1141+
_ = constant_prop_pass(edge.exported_program)

0 commit comments

Comments
 (0)