Skip to content

Commit bb41c2d

Browse files
authored
Add VAE cuda tuned model (huggingface#796)
1 parent eba138e commit bb41c2d

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

shark/examples/shark_inference/stable_diffusion/opt_params.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,15 @@ def get_unet():
7979

8080
def get_vae():
8181
# Tuned model is present only for `fp16` precision.
82-
is_tuned = (
83-
"tuned" if (args.use_tuned and "vulkan" in args.device) else "untuned"
84-
)
82+
is_tuned = "tuned" if args.use_tuned else "untuned"
8583
is_base = "/base" if args.use_base_vae else ""
86-
bucket_key = f"{args.variant}/{is_tuned}"
87-
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
84+
if "vulkan" not in args.device and args.use_tuned:
85+
bucket_key = f"{args.variant}/{is_tuned}/{args.device}"
86+
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"
87+
else:
88+
bucket_key = f"{args.variant}/{is_tuned}"
89+
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
90+
8891
bucket, model_name, iree_flags = get_params(
8992
bucket_key, model_key, "vae", is_tuned, args.precision
9093
)

shark/examples/shark_inference/stable_diffusion/resources/model_db.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
2828
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
2929
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
30+
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
3031
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
3132
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
33+
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
3234
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip2base_18dec_fp32",
3335
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_19dec_v2p1base_fp32_64",
3436
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet2_14dec_fp16",

shark/examples/shark_inference/stable_diffusion/sd_annotation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
winograd_config_dir = f"{WORKDIR}configs/" + config_name
4444
download_public_file(full_gs_url, winograd_config_dir, True)
4545

46-
if args.annotation_model == "unet":
47-
if args.variant in ["anythingv3", "analogdiffusion"]:
46+
if args.annotation_model == "unet" or device == "cuda":
47+
if args.variant in ["anythingv3", "analogdiffusion"] or args.annotation_model == "vae":
4848
args.max_length = 77
4949
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
5050
full_gs_url = config_bucket + config_name
@@ -67,7 +67,7 @@
6767
f.write(str(winograd_model))
6868

6969
# For Unet annotate the model with tuned lowering configs
70-
if args.annotation_model == "unet":
70+
if args.annotation_model == "unet" or device == "cuda":
7171
if args.use_winograd:
7272
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
7373
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"

0 commit comments

Comments
 (0)