
Commit fee73b0

Add SD model annotation on fly (huggingface#869)
* Add SD model annotation on fly
* Move tuned_compile_through_fx to utils
* Fix SD compilation flags
1 parent 9bbffa5 commit fee73b0

8 files changed: +169 −70 lines changed

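At a glance, the new path works like this: when `use_tuned` is set, `compile_through_fx` (now in utils.py) imports the model via FX, runs `sd_model_annotation` from the new sd_annotation.py module to apply tuned configs on the fly, and only then hands the annotated MLIR to `SharkInference`. A minimal sketch of a call, assuming the stable_diffusion example directory is on the path; `model` and `inputs` are placeholders, the other names appear in the diffs below:

from utils import compile_through_fx, get_opt_flags

shark_unet = compile_through_fx(
    model,                       # torch.nn.Module wrapping the UNet (placeholder)
    inputs,                      # example input tensors (placeholder)
    model_name="unet_v1_4fp16",  # illustrative name
    is_f16=True,
    use_tuned=True,              # triggers sd_model_annotation before compiling
    extra_args=get_opt_flags("unet", precision="fp16"),
)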

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 1 addition & 0 deletions
@@ -125,6 +125,7 @@ def end_profiling(device):
     height=height,
     width=width,
     use_base_vae=args.use_base_vae,
+    use_tuned=args.use_tuned,
 )
 clip, unet, vae = mlir_import()

shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,7 @@ def __init__(
         height: int = 512,
         batch_size: int = 1,
         use_base_vae: bool = False,
+        use_tuned: bool = False,
     ):
         self.check_params(max_len, width, height)
         self.max_len = max_len
@@ -82,6 +83,7 @@ def __init__(
             + "_"
             + precision
         )
+        self.use_tuned = use_tuned
         # We need a better naming convention for the .vmfbs because despite
         # using the custom model variant the .vmfb names remain the same and
         # it'll always pick up the compiled .vmfb instead of compiling the
@@ -133,6 +135,7 @@ def forward(self, input):
             inputs,
             is_f16=is_f16,
             model_name=vae_name + self.model_name,
+            use_tuned=self.use_tuned,
             extra_args=get_opt_flags("vae", precision=self.precision),
         )
         return shark_vae
@@ -172,6 +175,7 @@ def forward(
             model_name="unet" + self.model_name,
             is_f16=is_f16,
             f16_input_mask=input_mask,
+            use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
         )
         return shark_unet
shark/examples/shark_inference/stable_diffusion/sd_annotation.py

Lines changed: 104 additions & 40 deletions
@@ -1,81 +1,102 @@
 import os
 from shark.model_annotation import model_annotation, create_context
-from shark.iree_utils._common import run_cmd, iree_target_map
+from shark.iree_utils._common import iree_target_map, run_cmd
 from shark.shark_downloader import (
     download_model,
     download_public_file,
     WORKDIR,
 )
 from shark.parser import shark_args
 from stable_args import args
-from opt_params import get_params
-from utils import set_init_device_flags


-set_init_device_flags()
 device = (
     args.device if "://" not in args.device else args.device.split("://")[0]
 )

-# Downloads the model (Unet or VAE fp16) from shark_tank
-shark_args.local_tank_cache = args.local_tank_cache
-bucket_key = f"{args.variant}/untuned"
-if args.annotation_model == "unet":
-    model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
-elif args.annotation_model == "vae":
-    is_base = "/base" if args.use_base_vae else ""
-    model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
-
-bucket, model_name, iree_flags = get_params(
-    bucket_key, model_key, args.annotation_model, "untuned", args.precision
-)
-mlir_model, func_name, inputs, golden_out = download_model(
-    model_name,
-    tank_url=bucket,
-    frontend="torch",
-)

-# Downloads the tuned config files from shark_tank
-config_bucket = "gs://shark_tank/sd_tuned/configs/"
-if args.use_winograd:
+# Download the model (Unet or VAE fp16) from shark_tank
+def load_model_from_tank():
+    from opt_params import get_params, version, variant
+
+    shark_args.local_tank_cache = args.local_tank_cache
+    bucket_key = f"{variant}/untuned"
+    if args.annotation_model == "unet":
+        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
+    elif args.annotation_model == "vae":
+        is_base = "/base" if args.use_base_vae else ""
+        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
+
+    bucket, model_name, iree_flags = get_params(
+        bucket_key, model_key, args.annotation_model, "untuned", args.precision
+    )
+    mlir_model, func_name, inputs, golden_out = download_model(
+        model_name,
+        tank_url=bucket,
+        frontend="torch",
+    )
+    return mlir_model, model_name
+
+
+# Download the tuned config files from shark_tank
+def load_winograd_configs():
+    config_bucket = "gs://shark_tank/sd_tuned/configs/"
     config_name = f"{args.annotation_model}_winograd_{device}.json"
     full_gs_url = config_bucket + config_name
     winograd_config_dir = f"{WORKDIR}configs/" + config_name
+    print("Loading Winograd config file from ", winograd_config_dir)
     download_public_file(full_gs_url, winograd_config_dir, True)
+    return winograd_config_dir

-if args.annotation_model == "unet" or device == "cuda":
-    if args.variant in ["anythingv3", "analogdiffusion"]:
+
+def load_lower_configs():
+    from opt_params import version, variant
+
+    config_bucket = "gs://shark_tank/sd_tuned/configs/"
+    config_version = version
+    if variant in ["anythingv3", "analogdiffusion"]:
         args.max_length = 77
-        args.version = "v1_4"
+        config_version = "v1_4"
     if args.annotation_model == "vae":
         args.max_length = 77
-    config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
+    config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
     full_gs_url = config_bucket + config_name
     lowering_config_dir = f"{WORKDIR}configs/" + config_name
+    print("Loading lowering config file from ", lowering_config_dir)
     download_public_file(full_gs_url, lowering_config_dir, True)
+    return lowering_config_dir
+

 # Annotate the model with Winograd attribute on selected conv ops
-if args.use_winograd:
+def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
+    if model_name.split("_")[-1] != "tuned":
+        out_file_path = (
+            f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
+        )
+    else:
+        out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
+
     with create_context() as ctx:
         winograd_model = model_annotation(
             ctx,
-            input_contents=mlir_model,
+            input_contents=input_mlir,
             config_path=winograd_config_dir,
             search_op="conv",
-            winograd=args.use_winograd,
+            winograd=True,
         )
-        with open(
-            f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
-        ) as f:
+        with open(out_file_path, "w") as f:
             f.write(str(winograd_model))
+            f.close()
+    return winograd_model, out_file_path
+

 # For Unet annotate the model with tuned lowering configs
-if args.annotation_model == "unet" or device == "cuda":
-    if args.use_winograd:
-        input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
+def annotate_with_lower_configs(
+    input_mlir, lowering_config_dir, model_name, use_winograd
+):
+    if use_winograd:
         dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
     else:
-        input_mlir = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
         dump_after = "iree-flow-pad-linalg-ops"

     # Dump IR after padding/img2col/winograd passes
@@ -90,6 +111,8 @@
     device_spec_args = (
         f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
     )
+    print("Applying tuned configs on", model_name)
+
     run_cmd(
         f"iree-compile {input_mlir} "
         "--iree-input-type=tm_tensor "
@@ -116,7 +139,48 @@

     # Remove the intermediate mlir and save the final annotated model
     os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
-    output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
-    with open(output_path, "w") as f:
+    if model_name.split("_")[-1] != "tuned":
+        out_file_path = (
+            f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
+        )
+    else:
+        out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
+    with open(out_file_path, "w") as f:
         f.write(str(tuned_model))
+        f.close()
+    return tuned_model, out_file_path
+
+
+def sd_model_annotation(mlir_model, model_name):
+    if args.annotation_model == "unet" and device == "vulkan":
+        use_winograd = True
+        winograd_config_dir = load_winograd_configs()
+        winograd_model, model_path = annotate_with_winograd(
+            mlir_model, winograd_config_dir, model_name
+        )
+        lowering_config_dir = load_lower_configs()
+        tuned_model, output_path = annotate_with_lower_configs(
+            model_path, lowering_config_dir, model_name, use_winograd
+        )
+    elif args.annotation_model == "vae" and device == "vulkan":
+        use_winograd = True
+        winograd_config_dir = load_winograd_configs()
+        tuned_model, output_path = annotate_with_winograd(
+            mlir_model, winograd_config_dir, model_name
+        )
+    else:
+        use_winograd = False
+        lowering_config_dir = load_lower_configs()
+        tuned_model, output_path = annotate_with_lower_configs(
+            mlir_model, lowering_config_dir, model_name, use_winograd
+        )
     print(f"Saved the annotated mlir in {output_path}.")
+    return tuned_model, output_path
+
+
+if __name__ == "__main__":
+    mlir_model, model_name = load_model_from_tank()
+    if device == "cuda":
+        mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
+
+    sd_model_annotation(mlir_model, model_name)
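The new `__main__` block also makes the module usable on its own. A minimal sketch of driving it programmatically, assuming the stable_diffusion example directory is on `sys.path`; the flag values are illustrative, and they must be set before import because stable_args parses them at import time:

import sys

sys.argv = [
    "sd_annotation.py",
    "--annotation_model=unet",
    "--precision=fp16",
    "--device=vulkan",
    "--annotation_output=/tmp/annotations",
]

from sd_annotation import load_model_from_tank, sd_model_annotation

mlir_model, model_name = load_model_from_tank()
tuned_model, output_path = sd_model_annotation(mlir_model, model_name)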

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 1 addition & 8 deletions
@@ -93,7 +93,7 @@ def path_expand(s):

 p.add_argument(
     "--import_mlir",
-    default=True,
+    default=False,
     action=argparse.BooleanOptionalAction,
     help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
 )
@@ -299,11 +299,4 @@ def path_expand(s):
     help="Options are unet and vae.",
 )

-p.add_argument(
-    "--use_winograd",
-    default=False,
-    action=argparse.BooleanOptionalAction,
-    help="Apply Winograd on selected conv ops.",
-)
-
 args = p.parse_args()
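For reference, `argparse.BooleanOptionalAction` (Python 3.9+) generates a `--no-` counterpart automatically, so flipping the default to `False` makes the shark_tank download path the default while `--import_mlir` opts back into the torch import. A small self-contained demonstration (not from the repo):

import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--import_mlir",
    default=False,
    action=argparse.BooleanOptionalAction,
)

print(p.parse_args([]))                    # Namespace(import_mlir=False)
print(p.parse_args(["--import_mlir"]))     # Namespace(import_mlir=True)
print(p.parse_args(["--no-import_mlir"]))  # Namespace(import_mlir=False)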

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 47 additions & 19 deletions
@@ -1,4 +1,5 @@
 import os
+import gc
 import torch
 from shark.shark_inference import SharkInference
 from stable_args import args
@@ -9,6 +10,7 @@
 )
 from shark.iree_utils.gpu_utils import get_cuda_sm_cc
 from resources import opt_flags
+from sd_annotation import sd_model_annotation
 import sys


@@ -70,12 +72,40 @@ def compile_through_fx(
     model_name,
     is_f16=False,
     f16_input_mask=None,
+    use_tuned=False,
     extra_args=[],
 ):

     mlir_module, func_name = import_with_fx(
-        model, inputs, is_f16, f16_input_mask
+        model, inputs, is_f16, f16_input_mask, return_str=use_tuned
     )
+
+    if use_tuned:
+        model_name = model_name + "_tuned"
+        tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
+        if not os.path.exists(tuned_model_path):
+            if "vae" in model_name.split("_")[0]:
+                args.annotation_model = "vae"
+
+            if "cuda" in args.device:
+                output_path = (
+                    f"{args.annotation_output}/{model_name}_orig.mlir"
+                )
+                with open(output_path, "w") as f:
+                    f.write(mlir_module)
+                    f.close()
+                mlir_module = output_path
+
+            tuned_model, tuned_model_path = sd_model_annotation(
+                mlir_module, model_name
+            )
+            del mlir_module, tuned_model
+            gc.collect()
+
+        with open(tuned_model_path, "rb") as f:
+            mlir_module = f.read()
+            f.close()
+
     shark_module = SharkInference(
         mlir_module,
         device=args.device,
@@ -202,36 +232,30 @@ def set_init_device_flags():
     elif args.hf_model_id == "prompthero/openjourney":
         args.max_length = 64

-    # Use tuned models in the case of stablediffusion/fp16 and rdna3 cards.
+    # Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
     if (
         args.hf_model_id
         in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
         or args.precision != "fp16"
-        or "vulkan" not in args.device
-        or "rdna3" not in args.iree_vulkan_target_triple
+        or ("vulkan" not in args.device and "cuda" not in args.device)
     ):
         args.use_tuned = False

+    elif (
+        "vulkan" in args.device
+        and "rdna3" not in args.iree_vulkan_target_triple
+    ):
+        args.use_tuned = False
+
+    elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
+        args.use_tuned = False
+
     elif args.use_base_vae and args.hf_model_id not in [
         "stabilityai/stable-diffusion-2-1-base",
         "CompVis/stable-diffusion-v1-4",
     ]:
         args.use_tuned = False

-    # Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
-    if (
-        args.hf_model_id
-        in [
-            "stabilityai/stable-diffusion-2-1-base",
-            "Linaqruf/anything-v3.0",
-            "wavymulder/Analog-Diffusion",
-        ]
-        and args.precision == "fp16"
-        and "cuda" in args.device
-        and get_cuda_sm_cc() in ["sm_80", "sm_89"]
-    ):
-        args.use_tuned = True
-
     if args.use_tuned:
         print(f"Using {args.device} tuned models for stablediffusion/fp16.")
     else:
@@ -287,6 +311,11 @@ def get_opt_flags(model, precision="fp16"):
     if sys.platform == "darwin":
         iree_flags.append("-iree-stream-fuse-binding=false")

+    if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
+        iree_flags += opt_flags[model][is_tuned][precision][
+            "default_compilation_flags"
+        ]
+
     if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
         device = (
             args.device
@@ -303,7 +332,6 @@ def get_opt_flags(model, precision="fp16"):
         iree_flags += opt_flags[model][is_tuned][precision][
             "specified_compilation_flags"
         ][device]
-
     return iree_flags
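For reference, the new `default_compilation_flags` lookup implies an `opt_flags` structure in resources.py along these lines; only the nesting and the two key names come from this diff, while the flag strings and the `"tuned"` value of `is_tuned` are assumptions:

# Inferred shape of the opt_flags dict consumed by get_opt_flags;
# flag strings below are placeholders, not real IREE flags.
opt_flags = {
    "unet": {
        "tuned": {  # assumed value taken by is_tuned
            "fp16": {
                # appended for every device
                "default_compilation_flags": ["--example-default-flag"],
                # appended only for the matching device key
                "specified_compilation_flags": {
                    "vulkan": ["--example-vulkan-flag"],
                },
            },
        },
    },
}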
0 commit comments