Skip to content

Commit 2005bce

Browse files
authored
Fix flags for untuned Stable Diffusion FP16 model (huggingface#478)
1 parent 8a02d77 commit 2005bce

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

shark/examples/shark_inference/stable_diffusion/opt_params.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ def get_unet():
3737
else:
3838
bucket = "gs://shark_tank/prashant_nod"
3939
model_name = "unet_fp16_v2"
40-
iree_flags += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
40+
iree_flags += [
41+
"--iree-flow-enable-padding-linalg-ops",
42+
"--iree-flow-linalg-ops-padding-size=32",
43+
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
44+
]
4145
if args.import_mlir:
4246
return get_unet16_wrapped(model_name=model_name)
4347
return get_shark_model(bucket, model_name, iree_flags)

0 commit comments

Comments
 (0)