
Commit 0754c6e

yzhang93 and vivian authored
Update model annotation to take vulkan configs (huggingface#495)
Co-authored-by: vivian <[email protected]>
1 parent 7b1f04d commit 0754c6e

1 file changed (+33, -11 lines)


shark/model_annotation.py

Lines changed: 33 additions & 11 deletions
@@ -12,6 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""
+Usage:
+This function takes the model mlir file and the tuned config file as input,
+and outputs a new mlir file with lowering configs annotated on certain ops.
+There are two ways to utilize the function:
+1. Call the model_annotation function within another python script:
+   from shark.model_annotation import model_annotation
+   with create_context() as ctx:
+       module = model_annotation(ctx, input_contents=..., config_path=..., search_op=...)
+2. Run model_annotation.py directly:
+   python model_annotation.py path_to_original_mlir path_to_config_file
+"""
+
 import json
 import os
 import sys
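
For readers who want to try the first usage pattern end to end, the sketch below expands the docstring into a small driver script. It is illustrative only and not part of the commit: it assumes create_context is importable from shark.model_annotation alongside model_annotation, that input_contents takes the MLIR text itself (as the parameter name suggests), and that the returned module can be serialized with str(); the file paths and the "matmul" search op are placeholder values.

# Hypothetical driver for the "call from another python script" usage above.
# Assumptions (not part of this commit): create_context is exported by
# shark.model_annotation, and the returned module stringifies to MLIR text.
from shark.model_annotation import create_context, model_annotation

with open("path_to_original_mlir", "r") as f:
    mlir_text = f.read()

with create_context() as ctx:
    module = model_annotation(
        ctx,
        input_contents=mlir_text,            # MLIR text of the model
        config_path="path_to_config_file",   # tuned config file (placeholder path)
        search_op="matmul",                  # placeholder: op type to annotate
    )

with open("annotated_model.mlir", "w") as f:
    f.write(str(module))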
@@ -105,32 +118,41 @@ def add_attributes(op: ir.Operation, config: Dict):


 def parse_config(config: Dict):
-    if config["pipeline"] == "GPU" or config["pipeline"] == "GPU_TENSORCORE":
+    split_k = None
+    pipeline_depth = None
+    if "GPU" in config["pipeline"]:
         pipeline = (
             "LLVMGPUMatmulSimt"
             if config["pipeline"] == "GPU"
             else "LLVMGPUMatmulTensorCore"
         )
         tile_sizes = [config["work_group_tile_sizes"]]
         workgroup_size = config["work_group_sizes"]
-        try:
+        if "pipeline_depth" in config.keys():
             pipeline_depth = config["pipeline_depth"]
-        except:
-            pipeline_depth = None
-        try:
+        if "split_k" in config.keys():
             split_k = config["split_k"]
-        except:
-            split_k = None
+    elif "SPIRV" in config["pipeline"]:
+        pipeline = config["pipeline"]
+        tile_sizes = [
+            config["work_group_tile_sizes"],
+            config["parallel_tile_sizes"],
+            config["reduction_tile_sizes"],
+        ]
+        if "vector_tile_sizes" in config.keys():
+            tile_sizes += [config["vector_tile_sizes"]]
+        if "window_tile_sizes" in config.keys():
+            tile_sizes += [config["window_tile_sizes"]]
+        workgroup_size = config["work_group_sizes"]
     else:
+        # For IREE CPU pipelines
         pipeline = config["pipeline"]
         tile_sizes = [
             config["work_group_tile_sizes"],
-            config["l1_tile_sizes"],
-            config["vector_tile_sizes"],
+            config["parallel_tile_sizes"],
+            config["reduction_tile_sizes"],
         ]
         workgroup_size = []
-        split_k = None
-        pipeline_depth = None
     return tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth


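
To make the new branching concrete, the snippet below shows the shape of config dicts that the updated parse_config distinguishes. The key names are taken directly from the diff; the pipeline strings and numeric values are made-up examples (the Vulkan branch only requires that the pipeline name contain "SPIRV"), so treat this as a sketch rather than a real tuning result.

from shark.model_annotation import parse_config

# CUDA path: any pipeline name containing "GPU" lands here; the optional
# pipeline_depth / split_k keys are only read when present.
cuda_config = {
    "pipeline": "GPU_TENSORCORE",            # maps to LLVMGPUMatmulTensorCore
    "work_group_tile_sizes": [32, 32, 16],   # illustrative values
    "work_group_sizes": [64, 2, 1],
    "pipeline_depth": 4,
    "split_k": 2,
}

# Vulkan path added by this commit: any pipeline name containing "SPIRV".
# Tile sizes are collected per level, with optional vector/window levels.
vulkan_config = {
    "pipeline": "SPIRVMatmulPromoteVectorize",  # illustrative SPIR-V pipeline name
    "work_group_tile_sizes": [32, 128, 1],
    "parallel_tile_sizes": [16, 4, 1],
    "reduction_tile_sizes": [0, 0, 32],
    "work_group_sizes": [32, 8, 1],
}

tile_sizes, pipeline, workgroup_size, split_k, pipeline_depth = parse_config(
    vulkan_config
)
# tile_sizes == [[32, 128, 1], [16, 4, 1], [0, 0, 32]]
# split_k and pipeline_depth remain None: only the GPU branch reads those keys.

Note that the CPU branch (the final else) now reads parallel_tile_sizes and reduction_tile_sizes in place of the old l1_tile_sizes / vector_tile_sizes keys, so older CPU config files would need to be regenerated or renamed accordingly.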
