4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
- import operator
8
7
from typing import Any , Dict , final , List , Optional
9
8
10
9
import executorch .backends .vulkan .serialization .vulkan_graph_schema as vk_graph_schema
11
10
12
11
import torch
12
+
13
+ from executorch .backends .vulkan .partitioner .supported_ops import enumerate_supported_ops
13
14
from executorch .backends .vulkan .vulkan_preprocess import VulkanBackend
14
15
from executorch .exir .backend .compile_spec_schema import CompileSpec
15
16
from executorch .exir .backend .partitioner import (
19
20
)
20
21
from executorch .exir .backend .utils import tag_constant_data
21
22
from executorch .exir .dialects ._ops import ops as exir_ops
23
+
24
+ from torch ._subclasses .fake_tensor import FakeTensor
22
25
from torch .export .exported_program import ExportedProgram
23
26
from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
24
27
25
28
from torch .fx .passes .operator_support import OperatorSupportBase
26
29
27
30
28
31
class VulkanSupportedOperators(OperatorSupportBase):
    """Operator-support oracle for the Vulkan delegate.

    Decides, node by node, whether the Vulkan backend can execute a
    ``torch.fx`` node, based on the op registry built by
    ``enumerate_supported_ops()`` plus per-node value/argument checks.
    """

    # Registry mapping op targets to their feature flags; built once at
    # class-definition time and shared by all instances.
    _ops = enumerate_supported_ops()

    def __init__(self, require_dynamic_shape: bool = False):
        super().__init__()
        # When True, only ops whose implementations support dynamic shapes
        # are accepted (see is_node_supported).
        self.require_dynamic_shapes = require_dynamic_shape

    def node_val_is_compatible(self, node_val: Any) -> bool:
        """Return True if a node's "val" metadata is representable in Vulkan."""
        # Skip nodes that don't have a value
        if node_val is None:
            return True

        # TODO(ssjia) support symbolic ints
        if isinstance(node_val, torch.SymInt):
            return False

        if isinstance(node_val, FakeTensor):
            # Vulkan currently only supports tensors of up to 4D
            if len(node_val.shape) > 4:
                return False

        # Multi-valued outputs: every element must itself be compatible.
        if isinstance(node_val, (list, tuple)):
            for item in node_val:
                if not self.node_val_is_compatible(item):
                    return False

        return True

    def all_args_compatible(self, node: torch.fx.Node) -> bool:
        """Check the node's own value and the value of every fx.Node argument."""
        node_val = node.meta.get("val", None)
        if not self.node_val_is_compatible(node_val):
            return False

        for arg in node.args:
            # Non-node arguments (constants, attrs) need no tensor check.
            if not isinstance(arg, torch.fx.Node):
                continue

            arg_val = arg.meta.get("val", None)
            if not self.node_val_is_compatible(arg_val):
                return False

        return True

    def is_linear_permute(self, node: torch.fx.Node) -> bool:
        """Return True for a transpose/permute whose single consumer is mm/addmm.

        Such nodes are accepted even if permute alone is not in the op
        registry, since they feed directly into a matrix multiplication.
        """
        if node.target not in [
            exir_ops.edge.aten.t_copy.default,
            exir_ops.edge.aten.permute_copy.default,
        ]:
            return False

        # Only valid when the permute has exactly one consumer.
        if len(node.users) != 1:
            return False

        # dicts iterate over keys, so next(iter(...)) yields the sole user
        # without materializing the whole view as a list.
        sole_user = next(iter(node.users))
        return sole_user.target in [
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.addmm.default,
        ]

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Entry point queried by the capability-based partitioner."""
        # Permutes that feed a matmul are accepted regardless of the registry.
        if self.is_linear_permute(node):
            return True

        if node.target not in VulkanSupportedOperators._ops:
            return False

        features = VulkanSupportedOperators._ops[node.target]

        if self.require_dynamic_shapes and not features.supports_dynamic_shape:
            return False

        return self.all_args_compatible(node)
+
106
+
107
def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]:
    """Convert Vulkan compile options into a list of CompileSpecs.

    Only VkStorageType / VkMemoryLayout values are encoded (as 4-byte
    little-endian integers); any other entry is silently skipped.
    """
    specs: List[CompileSpec] = []
    for option_name, option_value in compile_options.items():
        recognized = isinstance(
            option_value,
            (vk_graph_schema.VkStorageType, vk_graph_schema.VkMemoryLayout),
        )
        if recognized:
            encoded = int(option_value).to_bytes(4, byteorder="little")
            specs.append(CompileSpec(option_name, encoded))
        # Unhandled options are ignored

    return specs
120
99
121
100
122
@final
101
123
class VulkanPartitioner (Partitioner ):
102
124
def __init__ (self , compile_options : Optional [Dict [str , Any ]] = None ) -> None :
103
- compile_spec = parse_compile_options (compile_options )
125
+ self .options : Dict [str , Any ] = {}
126
+ if compile_options is not None :
127
+ self .options = compile_options
128
+
129
+ compile_spec = parse_compile_options (self .options )
104
130
self .delegation_spec = DelegationSpec (VulkanBackend .__name__ , compile_spec )
105
131
106
132
def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
@@ -110,7 +136,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
110
136
111
137
capability_partitioner = CapabilityBasedPartitioner (
112
138
exported_program .graph_module ,
113
- VulkanSupportedOperators (),
139
+ VulkanSupportedOperators (self . options . get ( "require_dynamic_shapes" , False ) ),
114
140
allows_single_node_partition = True ,
115
141
)
116
142
partition_list = capability_partitioner .propose_partitions ()
0 commit comments