Skip to content

Commit 239c19e

Browse files
authored
Update Stable diffusion script to enable use of tuned models (huggingface#443)
1 parent 7f37599 commit 239c19e

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
UNET_FP32 = "unet_fp32"
2222
IREE_EXTRA_ARGS = []
2323

24+
TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
25+
UNET_FP16_TUNED = "unet_fp16_tuned"
26+
2427
# Helper function to profile the vulkan device.
2528
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
2629
if args.vulkan_debug_utils and "vulkan" in args.device:
@@ -42,24 +45,41 @@ def get_models():
4245
global IREE_EXTRA_ARGS
4346
if args.precision == "fp16":
4447
IREE_EXTRA_ARGS += [
45-
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
4648
"--iree-flow-enable-padding-linalg-ops",
4749
"--iree-flow-linalg-ops-padding-size=32",
48-
"--iree-spirv-unify-aliased-resources=false",
4950
]
51+
if args.use_tuned:
52+
unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
53+
vae_gcloud_bucket = GCLOUD_BUCKET
54+
unet_args = IREE_EXTRA_ARGS
55+
vae_args = IREE_EXTRA_ARGS + [
56+
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
57+
]
58+
unet_name = UNET_FP16_TUNED
59+
vae_name = VAE_FP16
60+
else:
61+
unet_gcloud_bucket = GCLOUD_BUCKET
62+
vae_gcloud_bucket = GCLOUD_BUCKET
63+
IREE_EXTRA_ARGS += [
64+
"--iree-flow-enable-conv-nchw-to-nhwc-transform"
65+
]
66+
unet_args = IREE_EXTRA_ARGS
67+
vae_args = IREE_EXTRA_ARGS
68+
unet_name = UNET_FP16
69+
vae_name = VAE_FP16
5070
if args.import_mlir == True:
5171
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
5272
model_name=UNET_FP16
5373
)
5474
else:
5575
return get_shark_model(
56-
GCLOUD_BUCKET,
57-
VAE_FP16,
58-
IREE_EXTRA_ARGS,
76+
vae_gcloud_bucket,
77+
vae_name,
78+
vae_args,
5979
), get_shark_model(
60-
GCLOUD_BUCKET,
61-
UNET_FP16,
62-
IREE_EXTRA_ARGS,
80+
unet_gcloud_bucket,
81+
unet_name,
82+
unet_args,
6383
)
6484

6585
elif args.precision == "fp32":

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,11 @@
7878
help="Profiles vulkan device and collects the .rdc info",
7979
)
8080

81+
p.add_argument(
82+
"--use_tuned",
83+
default=True,
84+
action=argparse.BooleanOptionalAction,
85+
help="Download and use the tuned version of the model if available",
86+
)
87+
8188
args = p.parse_args()

0 commit comments

Comments
 (0)