
Commit bf50527

fix bug with sequential backends
Differential Revision: D74226258
Pull Request resolved: #10708
1 parent: 6e959be

File tree: 4 files changed (+136, -24 lines)


exir/backend/backend_api.py (+12, -6)
@@ -204,12 +204,16 @@ def _insert_lowered_submodule(
     owning_graph_module = call_submodule_node.graph.owning_module
     # call delegate args should only use user_inputs
     call_delegate_args = []
-    # Preserve input order as user_inputs
-    for inp_name in submodule_program.graph_signature.user_inputs:
-        for inp_node in call_submodule_node.all_input_nodes:
-            if inp_node.name == inp_name:
-                call_delegate_args.append(inp_node)
-                break
+    # names of input_specs to delete
+    input_specs_to_delete = toplevel_input_specs_to_delete
+    # Delete owned constants from the call_submodule_node args
+    for call_sm_input in call_submodule_node.args:
+        if (
+            isinstance(call_sm_input, torch.fx.Node)
+            and call_sm_input.name in input_specs_to_delete.keys()
+        ):
+            continue
+        call_delegate_args.append(call_sm_input)

 def generate_debug_handle(ep: ExportedProgram) -> int:
     """
@@ -324,6 +328,7 @@ def _partition_and_lower_one_graph_module(
             toplevel_input_specs_to_delete,
             toplevel_output_specs_to_delete,
         )
+        owning_program._validate()

     return tagged_graph_module
@@ -742,6 +747,7 @@ def to_backend(
     for method_name in method_to_edge_program.keys():
         if method_name in method_to_tagged_exported_program:
             tagged_exported_program = method_to_tagged_exported_program[method_name]
+            tagged_exported_program._validate()
             partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                 root=tagged_exported_program.graph_module,
                 graph=tagged_exported_program.graph_module.graph,
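
Note on the change above: the old loop rebuilt call_delegate_args by matching
names from submodule_program.graph_signature.user_inputs, which could drop or
reorder arguments once an earlier lowering had already rewritten the call site.
The new loop walks call_submodule_node.args in their existing order and skips
only inputs naming constants the delegate now owns. A minimal standalone sketch
of that filtering idea (SimpleNode is a hypothetical stand-in for
torch.fx.Node, not an ExecuTorch type):

    from dataclasses import dataclass

    @dataclass
    class SimpleNode:  # hypothetical stand-in for torch.fx.Node
        name: str

    def filter_call_args(call_args, input_specs_to_delete):
        """Keep args in their original order; drop only delegate-owned constants."""
        kept = []
        for arg in call_args:
            if isinstance(arg, SimpleNode) and arg.name in input_specs_to_delete:
                continue  # constant was folded into the lowered submodule
            kept.append(arg)
        return kept

    args = [SimpleNode("x"), SimpleNode("c_const_copy_0"), SimpleNode("a")]
    print([n.name for n in filter_call_args(args, {"c_const_copy_0"})])
    # -> ['x', 'a']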

exir/backend/test/backend_with_preprocess_all_demo.py (+43, -12)
@@ -21,10 +21,30 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


+def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(exp_prog, node)
+        or is_buffer(exp_prog, node)
+        or is_lifted_tensor_constant(exp_prog, node)
+    )
+
+
+def get_total_num_ops_in_ep(edge_programs, supported_ops):
+    total_number_of_ops = 0
+    for edge_program in edge_programs.values():
+        for partitioned_program in edge_program:
+            for node in partitioned_program.graph.nodes:
+                if node.op == "call_function":
+                    if node.target in supported_ops:
+                        total_number_of_ops += 1
+    return total_number_of_ops
+
+
 def _preprocess_multimethod(
     edge_programs: Dict[str, List[ExportedProgram]],
     compile_specs: Dict[str, List[List[CompileSpec]]],
@@ -37,13 +57,7 @@ def _preprocess_multimethod(
     in testing for a partitioner which tags different partitions for different backends
     to be lowered to
     """
-    total_number_of_ops = 0
-    for edge_program in edge_programs.values():
-        for partitioned_program in edge_program:
-            for node in partitioned_program.graph.nodes:
-                if node.op == "call_function":
-                    if node.target in supported_ops:
-                        total_number_of_ops += 1
+    total_number_of_ops = get_total_num_ops_in_ep(edge_programs, supported_ops)
     all_processed_results = {key: [] for key in edge_programs.keys()}

     for method_name, partitioned_programs in edge_programs.items():
@@ -67,6 +81,8 @@ def _preprocess_multimethod(
                     raise RuntimeError(
                         f"{node.op} {node.target.__name__} is not supported in backend {backend_name}"
                     )
+                if is_param_node(partitioned_program, node):
+                    processed_bytes += f"CONST{node.name}:"

             processed_bytes += "#"
             for cs in compile_spec_for_partition:
@@ -171,14 +187,30 @@ def preprocess_multimethod(


 class AddSinOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        return node.op == "call_function" and node.target in [
+        supported_targets = [
             exir_ops.edge.aten.add.Tensor,
             exir_ops.edge.aten.sin.default,
         ]
+        if node.op == "call_function" and node.target in supported_targets:
+            return True
+
+        if node.op == "placeholder" and is_param_node(self.original_program, node):
+            for user in node.users.keys():
+                if user.target in supported_targets:
+                    return True
+        return False


 class SubCosOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
             exir_ops.edge.aten.sub.Tensor,
@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
     """

     def __init__(self) -> None:
-        self.add_sin_support = any_chain(AddSinOperatorSupport())
-        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__
-
-        self.sub_cos_support = any_chain(SubCosOperatorSupport())
         self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__
+        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__

     def _partition_graph_module(
         self,
@@ -260,6 +289,8 @@ def _partition_graph_module(
         return partition_tags, start_idx_for_submodules

     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        self.add_sin_support = any_chain(AddSinOperatorSupport(exported_program))
+        self.sub_cos_support = any_chain(SubCosOperatorSupport(exported_program))
         partition_tags, _ = self._partition_graph_module(exported_program.graph_module)
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
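
The support classes above now take the owning program because a placeholder can
only be classified as a parameter, buffer, or lifted constant relative to its
ExportedProgram; that is also why the partitioner defers building its support
objects to partition(). A hedged usage sketch (assumes the demo classes above
are importable; AddModule is made up for illustration):

    import torch
    from executorch.exir import to_edge

    class AddModule(torch.nn.Module):
        def forward(self, x):
            return x + x

    ep = to_edge(torch.export.export(AddModule(), (torch.ones(1),))).exported_program()
    partitioner = BackendWithPreprocessAllPartitioner()
    result = partitioner.partition(ep)  # support objects are created here
    for tag, spec in result.partition_tags.items():
        print(tag, spec.backend_id)  # which backend each tagged partition targets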

exir/backend/test/test_to_backend_multi_method.py (+71)
@@ -392,6 +392,77 @@ def forward(self, x):
         }
         self._test(test_set)

+    def test_multi_method_to_backend_sequential_delegates(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + a
+                b = b + z + a
+                b = b + y + a
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_edgeir_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_edgeir": (
+                seq_edgeir_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#5#aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
+    def test_multi_method_to_backend_constants(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.zeros(1)
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z * self.const
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + self.const + a
+                b = z + a + b
+                b = y + a + b
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_const_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_const": (
+                seq_const_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#6#CONSTc_const_copy_0:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
     def test_multi_method_to_backend_not_found(self):
         class SinModule(torch.nn.Module):
             def __init__(self):
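
About the expected strings in these tests: the demo backend serializes each
partition as backend name, count of supported call_function ops, the op (or
CONST-prefixed constant) entries, then one compile-spec key:value pair per op,
all joined with "#". A hypothetical helper that rebuilds the simpler,
constant-free blobs, inferred from the assertions above and not part of any
ExecuTorch API:

    def expected_blob(backend_id, op_names, spec_key, spec_val):
        ops = "".join(f"{o}:" for o in op_names)
        specs = "".join(f"{spec_key}:{spec_val!r};" for _ in op_names)
        return f"{backend_id}#{len(op_names)}#{ops}#{specs}"

    print(expected_blob("SecondBackendWithPreprocessAll",
                        ["aten.sub.Tensor"] * 3, "sub", b"\x02"))
    # -> SecondBackendWithPreprocessAll#3#aten.sub.Tensor:...#sub:b'\x02';...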

exir/lowered_backend_module.py (+10, -6)
@@ -381,7 +381,7 @@ def _fixup_output_node(gm: torch.fx.GraphModule) -> None:


 def arrange_graph_placeholders(
-    gm: torch.fx.GraphModule, owning_program: ExportedProgram
+    gm: torch.fx.GraphModule, owning_program: ExportedProgram, tag
 ) -> torch.fx.GraphModule:
     """
     Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
@@ -411,9 +411,15 @@
         if node.op != "placeholder":
             continue

-        if node.name in graph_sign.inputs_to_parameters:
+        if (
+            node.name in graph_sign.inputs_to_parameters
+            and node.meta.get("delegation_tag", None) == tag
+        ):
             param_nodes.append(node)
-        elif node.name in graph_sign.inputs_to_buffers:
+        elif (
+            node.name in graph_sign.inputs_to_buffers
+            and node.meta.get("delegation_tag", None) == tag
+        ):
             buffer_nodes.append(node)
         else:
             input_nodes.append(node)
@@ -694,7 +700,7 @@ def create_exported_program_from_submodule(
     removed from the toplevel ExportedProgram.
     """
     # Arrange the submodule's placeholders in order
-    submodule = arrange_graph_placeholders(submodule, owning_program)
+    submodule = arrange_graph_placeholders(submodule, owning_program, tag)

     # TODO: we probably need to arrange the outputs wrt buffer mutations.

@@ -958,5 +964,3 @@ def _unsafe_adjust_original_program(  # noqa: C901
                 if user_idx > idx:
                     user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                 break
-
-    original_program._validate()
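
Why the tag check matters: with sequential delegates, a parameter or buffer
placeholder can appear in a submodule while being tagged for a different
delegate; without filtering on delegation_tag it would wrongly be hoisted as
this submodule's own parameter. A minimal sketch of the grouping (Stub is a
hypothetical stand-in for torch.fx.Node; the real code reads the owning
program's graph signature):

    class Stub:  # hypothetical stand-in for torch.fx.Node
        def __init__(self, name, meta=None):
            self.name, self.meta = name, meta or {}

    def order_placeholders(placeholders, inputs_to_parameters, inputs_to_buffers, tag):
        params, buffers, user_inputs = [], [], []
        for node in placeholders:
            owned = node.meta.get("delegation_tag", None) == tag
            if node.name in inputs_to_parameters and owned:
                params.append(node)
            elif node.name in inputs_to_buffers and owned:
                buffers.append(node)
            else:
                # constants tagged for another delegate stay plain user inputs
                user_inputs.append(node)
        return params + buffers + user_inputs

    p = Stub("p_weight", {"delegation_tag": "backend_A"})
    q = Stub("p_other", {"delegation_tag": "backend_B"})
    x = Stub("x")
    order = order_placeholders([x, p, q], {"p_weight", "p_other"}, set(), "backend_A")
    print([n.name for n in order])  # -> ['p_weight', 'x', 'p_other']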
