
Commit d345278

Authored by SS-JIA, committed by facebook-github-bot
Fix stories model export (#3622)
Summary: Pull Request resolved: #3622

## Context

With the number of newly added operators, exporting the Stories model via the `export_llama` script is currently broken. This changeset mainly upgrades `VulkanPartitioner` to allow finer-grained control over which nodes are partitioned:

1. Update how operators are listed, and allow specifying features for specific operators.
2. Check a node's arguments and output to verify that they are valid before marking the node as supported.
3. Allow the partitioner to select only operators that support dynamic shapes.

ghstack-source-id: 226482740
Reviewed By: copyrightly, jorgep31415
Differential Revision: D57385871
fbshipit-source-id: 0ac6f3c1394541d2a5b3db9a443cca4b48dd89eb
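As a rough illustration of how the new `require_dynamic_shapes` knob is meant to be used, here is a minimal sketch of a standard ExecuTorch lowering flow; the toy module is a hypothetical stand-in for the Stories model, and this is not code from `export_llama` itself:

```python
import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge


class TinyModel(torch.nn.Module):
    # Hypothetical toy model standing in for the Stories model.
    def forward(self, x, y):
        return torch.relu(x + y)


inputs = (torch.randn(1, 4), torch.randn(1, 4))
edge = to_edge(torch.export.export(TinyModel(), inputs))

# With require_dynamic_shapes=True, the partitioner only claims nodes whose
# OpFeatures entry reports supports_dynamic_shape=True.
edge = edge.to_backend(VulkanPartitioner({"require_dynamic_shapes": True}))
```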
Parent commit: 6b6c1fa

File tree: 5 files changed, +248 −63 lines

backends/vulkan/partitioner/TARGETS (+1)

```diff
@@ -5,6 +5,7 @@ oncall("executorch")
 runtime.python_library(
     name = "vulkan_partitioner",
     srcs = [
+        "supported_ops.py",
         "vulkan_partitioner.py",
     ],
     visibility = [
```
backends/vulkan/partitioner/supported_ops.py (new file, +148)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

from executorch.exir.dialects._ops import ops as exir_ops


class OpFeatures:
    __slots__ = ["supports_texture", "supports_buffer", "supports_dynamic_shape"]

    def __init__(
        self,
        supports_dynamic_shape: bool = False,
        supports_buffer: bool = False,
        supports_texture: bool = True,
    ):
        self.supports_dynamic_shape = supports_dynamic_shape
        self.supports_texture = supports_texture
        self.supports_buffer = supports_buffer


class OpList:
    def __init__(self):
        self._ops = {}

    def __getitem__(self, op):
        if op not in self._ops:
            self._ops[op] = OpFeatures()
        return self._ops[op]

    def __contains__(self, op):
        return op in self._ops


PRIM_OPS = [
    operator.getitem,
]

BINARY_OPS = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.div.Tensor_mode,
    exir_ops.edge.aten.pow.Tensor_Tensor,
]

UNARY_OPS = [
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.gelu.default,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.tanh.default,
]

MATMUL_OPS = [
    exir_ops.edge.aten.bmm.default,
    exir_ops.edge.aten.mm.default,
    exir_ops.edge.aten.addmm.default,
    exir_ops.edge.aten.linear.default,
]

POOLING_OPS = [
    exir_ops.edge.aten.max_pool2d_with_indices.default,
]

CONVOLUTION_OPS = [
    exir_ops.edge.aten.convolution.default,
]

REDUCTION_OPS = [
    exir_ops.edge.aten.sum.dim_IntList,
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten._log_softmax.default,
]

NORMALIZATION_OPS = [
    exir_ops.edge.aten.native_layer_norm.default,
]

SHAPE_MANIPULATION_OPS = [
    exir_ops.edge.aten.unsqueeze_copy.default,
    exir_ops.edge.aten.view_copy.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.t_copy.default,
]

INDEXING_OPS = [
    exir_ops.edge.aten.select_copy.int,
    exir_ops.edge.aten.slice_copy.Tensor,
]

ORCHESTRATION_OPS = [
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.split.Tensor,
    exir_ops.edge.aten.repeat.default,
]

CREATION_OPS = [
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.full.default,
]


def register_prim_ops(ops: OpList):
    for op in PRIM_OPS:
        ops[op].supports_texture = True
        ops[op].supports_buffer = True
        ops[op].supports_dynamic_shape = True


def register_no_dynamic_shape_ops(ops: OpList):
    for op in [
        *REDUCTION_OPS,
        *NORMALIZATION_OPS,
        *SHAPE_MANIPULATION_OPS,
        *INDEXING_OPS,
        *ORCHESTRATION_OPS,
        *CREATION_OPS,
    ]:
        ops[op].supports_dynamic_shape = False


def register_dynamic_shape_ops(ops: OpList):
    for op in [
        *BINARY_OPS,
        *UNARY_OPS,
        *MATMUL_OPS,
        *POOLING_OPS,
        *CONVOLUTION_OPS,
    ]:
        ops[op].supports_dynamic_shape = True


def enumerate_supported_ops():
    ops = OpList()
    register_prim_ops(ops)
    register_no_dynamic_shape_ops(ops)
    register_dynamic_shape_ops(ops)
    return ops
```
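To make the registry's behavior concrete, here is a short usage sketch of the module above. Note the caveat baked into `OpList.__getitem__`: indexing an unknown op silently registers it with default `OpFeatures`, so membership should be tested with `in` first, exactly as the partitioner does:

```python
from executorch.backends.vulkan.partitioner.supported_ops import enumerate_supported_ops
from executorch.exir.dialects._ops import ops as exir_ops

ops = enumerate_supported_ops()

# Membership check, mirroring the partitioner's first test.
target = exir_ops.edge.aten.add.Tensor
assert target in ops

# add.Tensor is registered through register_dynamic_shape_ops, so:
features = ops[target]
assert features.supports_dynamic_shape  # True
assert features.supports_texture        # True (the OpFeatures default)

# native_layer_norm goes through register_no_dynamic_shape_ops instead:
assert not ops[exir_ops.edge.aten.native_layer_norm.default].supports_dynamic_shape
```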

backends/vulkan/partitioner/vulkan_partitioner.py (+86 −60)

```diff
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import operator
 from typing import Any, Dict, final, List, Optional
 
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
 
 import torch
+
+from executorch.backends.vulkan.partitioner.supported_ops import enumerate_supported_ops
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
@@ -19,88 +20,113 @@
 )
 from executorch.exir.backend.utils import tag_constant_data
 from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch._subclasses.fake_tensor import FakeTensor
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 
 from torch.fx.passes.operator_support import OperatorSupportBase
 
 
 class VulkanSupportedOperators(OperatorSupportBase):
-    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        supported = node.op == "call_function" and node.target in [
-            # Binary arithmetic operators
-            exir_ops.edge.aten.add.Tensor,
-            exir_ops.edge.aten.sub.Tensor,
-            exir_ops.edge.aten.mul.Tensor,
-            exir_ops.edge.aten.div.Tensor,
-            exir_ops.edge.aten.div.Tensor_mode,
-            exir_ops.edge.aten.pow.Tensor_Tensor,
-            # Unary operators
-            exir_ops.edge.aten.abs.default,
-            exir_ops.edge.aten.clamp.default,
-            exir_ops.edge.aten.gelu.default,
-            exir_ops.edge.aten.hardtanh.default,
-            exir_ops.edge.aten.relu.default,
-            exir_ops.edge.aten.sigmoid.default,
-            exir_ops.edge.aten.sqrt.default,
-            exir_ops.edge.aten.tanh.default,
-            # Matrix multiplication operators
-            exir_ops.edge.aten.bmm.default,
+    _ops = enumerate_supported_ops()
+
+    def __init__(self, require_dynamic_shape: bool = False):
+        super().__init__()
+        self.require_dynamic_shapes = require_dynamic_shape
+
+    def node_val_is_compatible(self, node_val: Any) -> bool:
+        # Skip nodes that don't have a value
+        if node_val is None:
+            return True
+
+        # TODO(ssjia) support symbolic ints
+        if isinstance(node_val, torch.SymInt):
+            return False
+
+        if isinstance(node_val, FakeTensor):
+            # Vulkan currently only supports tensors of up to 4D
+            if len(node_val.shape) > 4:
+                return False
+
+        if isinstance(node_val, (list, tuple)):
+            for item in node_val:
+                if not self.node_val_is_compatible(item):
+                    return False
+
+        return True
+
+    def all_args_compatible(self, node: torch.fx.Node) -> bool:
+        node_val = node.meta.get("val", None)
+        if not self.node_val_is_compatible(node_val):
+            return False
+
+        for arg in node.args:
+            if not isinstance(arg, torch.fx.Node):
+                continue
+
+            arg_val = arg.meta.get("val", None)
+            if not self.node_val_is_compatible(arg_val):
+                return False
+
+        return True
+
+    def is_linear_permute(self, node: torch.fx.Node) -> bool:
+        if node.target not in [
+            exir_ops.edge.aten.t_copy.default,
+            exir_ops.edge.aten.permute_copy.default,
+        ]:
+            return False
+
+        if len(node.users) != 1:
+            return False
+
+        if list(node.users.keys())[0].target in [
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.addmm.default,
-            # Pooling operators
-            exir_ops.edge.aten.max_pool2d_with_indices.default,
-            # Sum
-            exir_ops.edge.aten.sum.dim_IntList,
-            # Convolution operators
-            exir_ops.edge.aten.convolution.default,
-            # Normalization
-            exir_ops.edge.aten.native_layer_norm.default,
-            # Shape-related operators
-            exir_ops.edge.aten.select_copy.int,
-            exir_ops.edge.aten.unsqueeze_copy.default,
-            exir_ops.edge.aten.view_copy.default,
-            # Copy-releated operators
-            exir_ops.edge.aten.permute_copy.default,
-            exir_ops.edge.aten.clone.default,
-            exir_ops.edge.aten.cat.default,
-            exir_ops.edge.aten.split_with_sizes_copy.default,
-            exir_ops.edge.aten.split.Tensor,
-            exir_ops.edge.aten.slice_copy.Tensor,
-            exir_ops.edge.aten.repeat.default,
-            # Softmax
-            exir_ops.edge.aten._softmax.default,
-            exir_ops.edge.aten._log_softmax.default,
-            # Other
-            operator.getitem,
-            exir_ops.edge.aten.full.default,
-        ]
-        return supported
-
-
-def parse_compile_options(
-    compile_options: Optional[Dict[str, Any]] = None
-) -> List[CompileSpec]:
+        ]:
+            return True
+
+        return False
+
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        if self.is_linear_permute(node):
+            return True
+
+        if node.target not in VulkanSupportedOperators._ops:
+            return False
+
+        features = VulkanSupportedOperators._ops[node.target]
+
+        if self.require_dynamic_shapes and not features.supports_dynamic_shape:
+            return False
+
+        return self.all_args_compatible(node)
+
+
+def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]:
     compile_specs = []
-    if compile_options is None:
-        return compile_specs
 
     for key, value in compile_options.items():
         if isinstance(
             value, (vk_graph_schema.VkStorageType, vk_graph_schema.VkMemoryLayout)
         ):
             value_bytes = int(value).to_bytes(4, byteorder="little")
             compile_specs.append(CompileSpec(key, value_bytes))
-        else:
-            raise RuntimeError(f"Invalid compile option {key} with type {type(value)}")
+
+        # Unhandled options are ignored
 
     return compile_specs
 
 
 @final
 class VulkanPartitioner(Partitioner):
     def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None:
-        compile_spec = parse_compile_options(compile_options)
+        self.options: Dict[str, Any] = {}
+        if compile_options is not None:
+            self.options = compile_options
+
+        compile_spec = parse_compile_options(self.options)
         self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
@@ -110,7 +136,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
 
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            VulkanSupportedOperators(),
+            VulkanSupportedOperators(self.options.get("require_dynamic_shapes", False)),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
```
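Two behavioral notes on the hunk above. First, `is_node_supported` accepts `t_copy`/`permute_copy` nodes whose single user is an `mm`/`addmm` via `is_linear_permute`, presumably so they can be handled together with the matmul they feed even though they are not dynamic-shape ops on their own. Second, `parse_compile_options` now silently skips unrecognized keys instead of raising, which is what allows `require_dynamic_shapes` to ride in through the same options dict. A quick sketch of the serialization path (the option key names here are hypothetical examples; the enum member is picked generically so the sketch does not depend on specific member names):

```python
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
from executorch.backends.vulkan.partitioner.vulkan_partitioner import parse_compile_options

# Pick some VkMemoryLayout member; which one is irrelevant for this sketch.
layout = next(iter(vk_graph_schema.VkMemoryLayout))

options = {
    # Enum-typed values become 4-byte little-endian CompileSpecs.
    "memory_layout_override": layout,
    # Non-enum values such as this flag are now silently skipped rather
    # than raising a RuntimeError as before.
    "require_dynamic_shapes": True,
}

for spec in parse_compile_options(options):
    print(spec.key, spec.value)  # only memory_layout_override is emitted
```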

backends/vulkan/serialization/vulkan_graph_builder.py (+11)

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import operator
 from types import NoneType
 from typing import cast, List, Optional, Union
@@ -36,6 +37,9 @@ def __init__(self, program: ExportedProgram) -> None:
         # Mapping from Node to VkValue id
         self.node_to_value_ids = {}
 
+        # For logging
+        self.seen_ops = set()
+
     @staticmethod
     def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
         if torch_dtype == torch.bool:
@@ -230,6 +234,7 @@ def get_or_create_value_for(self, arg: _Argument):
             or isinstance(arg, torch.device)
             or isinstance(arg, torch.dtype)
             or isinstance(arg, torch.layout)
+            or isinstance(arg, torch.memory_format)
         ):
             return self.create_null_value()
         elif isinstance(arg, _ScalarType):
@@ -271,6 +276,8 @@ def process_getitem_node(self, node: Node) -> None:
     def process_call_function_node(self, node) -> None:
         operator_call_args = []
 
+        self.seen_ops.add(node.target)
+
         for i, schema_arg in enumerate(node.target._schema.arguments):
             if not schema_arg.kwarg_only and i < len(node.args):
                 function_arg = node.args[i]
@@ -325,6 +332,10 @@ def build_graph(self) -> vk_graph_schema.VkGraph:
         for node in self.program.graph_module.graph.nodes:
             self.process_node(node)
 
+        logging.info("Operators included in this Vulkan partition: ")
+        for op in self.seen_ops:
+            logging.info(f"    {op.__name__}")
+
         return vk_graph_schema.VkGraph(
             version="0",
             chain=self.chain,
```
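Since the new operator summary goes through the standard `logging` module at INFO level, it only shows up if the host process opts in; a minimal sketch (the operator names in the sample output are illustrative):

```python
import logging

# Enable INFO logs before lowering so the Vulkan graph builder prints its
# per-partition operator summary from build_graph().
logging.basicConfig(level=logging.INFO)

# ... run the usual export + to_backend(VulkanPartitioner(...)) flow here ...
#
# Expected output, roughly:
#   INFO:root:Operators included in this Vulkan partition:
#   INFO:root:    add.Tensor
#   INFO:root:    mm.default
```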
