55from tqdm .auto import tqdm
66import numpy as np
77from stable_args import args
8- from model_wrappers import (
9- get_vae32 ,
10- get_vae16 ,
11- get_unet16_wrapped ,
12- get_unet32_wrapped ,
13- get_clipped_text ,
14- )
158from utils import get_shark_model
9+ from opt_params import get_unet , get_vae , get_clip
1610import time
1711
18- GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
19- VAE_FP16 = "vae_fp16"
20- VAE_FP32 = "vae_fp32"
21- UNET_FP16 = "unet_fp16"
22- UNET_FP32 = "unet_fp32"
23- IREE_EXTRA_ARGS = []
24-
25- TUNED_GCLOUD_BUCKET = "gs://shark_tank/quinn"
26- UNET_FP16_TUNED = "unet_fp16_tunedv2"
27-
28- BATCH_SIZE = len (args .prompts )
29-
30- if BATCH_SIZE not in [1 , 2 ]:
31- import sys
32-
33- sys .exit ("Only batch size 1 and 2 are supported." )
34-
35- if BATCH_SIZE > 1 and args .precision != "fp16" :
36- sys .exit ("batch size > 1 is supported for fp16 model." )
37-
38-
39- if BATCH_SIZE != 1 :
40- TUNED_GCLOUD_BUCKET = "gs://shark_tank/prashant_nod"
41- UNET_FP16_TUNED = f"unet_fp16_{ BATCH_SIZE } "
42- VAE_FP16 = f"vae_fp16_{ BATCH_SIZE } "
4312
4413# Helper function to profile the vulkan device.
4514def start_profiling (file_path = "foo.rdc" , profiling_mode = "queue" ):
@@ -58,87 +27,9 @@ def end_profiling(device):
5827 return device .end_profiling ()
5928
6029
61- def get_models ():
62- global IREE_EXTRA_ARGS
63- if args .precision == "fp16" :
64- IREE_EXTRA_ARGS += [
65- "--iree-flow-enable-padding-linalg-ops" ,
66- "--iree-flow-linalg-ops-padding-size=32" ,
67- ]
68- if args .use_tuned :
69- unet_gcloud_bucket = TUNED_GCLOUD_BUCKET
70- vae_gcloud_bucket = GCLOUD_BUCKET
71- unet_args = IREE_EXTRA_ARGS
72- vae_args = IREE_EXTRA_ARGS + [
73- "--iree-flow-enable-conv-nchw-to-nhwc-transform"
74- ]
75- unet_name = UNET_FP16_TUNED
76- vae_name = VAE_FP16
77- else :
78- unet_gcloud_bucket = GCLOUD_BUCKET
79- vae_gcloud_bucket = GCLOUD_BUCKET
80- IREE_EXTRA_ARGS += [
81- "--iree-flow-enable-conv-nchw-to-nhwc-transform"
82- ]
83- unet_args = IREE_EXTRA_ARGS
84- vae_args = IREE_EXTRA_ARGS
85- unet_name = UNET_FP16
86- vae_name = VAE_FP16
87-
88- if batch_size > 1 :
89- vae_args = []
90-
91- if args .import_mlir == True :
92- return get_vae16 (model_name = VAE_FP16 ), get_unet16_wrapped (
93- model_name = UNET_FP16
94- )
95- else :
96- return get_shark_model (
97- vae_gcloud_bucket ,
98- vae_name ,
99- vae_args ,
100- ), get_shark_model (
101- unet_gcloud_bucket ,
102- unet_name ,
103- unet_args ,
104- )
105-
106- elif args .precision == "fp32" :
107- IREE_EXTRA_ARGS += [
108- "--iree-flow-enable-conv-nchw-to-nhwc-transform" ,
109- "--iree-flow-enable-padding-linalg-ops" ,
110- "--iree-flow-linalg-ops-padding-size=16" ,
111- ]
112- if args .import_mlir == True :
113- return get_vae32 (model_name = VAE_FP32 ), get_unet32_wrapped (
114- model_name = UNET_FP32
115- )
116- else :
117- return get_shark_model (
118- GCLOUD_BUCKET ,
119- VAE_FP32 ,
120- IREE_EXTRA_ARGS ,
121- ), get_shark_model (
122- GCLOUD_BUCKET ,
123- UNET_FP32 ,
124- IREE_EXTRA_ARGS ,
125- )
126-
127-
12830if __name__ == "__main__" :
12931
13032 dtype = torch .float32 if args .precision == "fp32" else torch .half
131- if len (args .iree_vulkan_target_triple ) > 0 :
132- IREE_EXTRA_ARGS .append (
133- f"-iree-vulkan-target-triple={ args .iree_vulkan_target_triple } "
134- )
135-
136- clip_model = "clip_text"
137- clip_extra_args = [
138- "--iree-flow-linalg-ops-padding-size=16" ,
139- "--iree-flow-enable-padding-linalg-ops" ,
140- ]
141- clip = get_shark_model (GCLOUD_BUCKET , clip_model , clip_extra_args )
14233
14334 prompt = args .prompts
14435 height = 512 # default height of Stable Diffusion
@@ -154,7 +45,7 @@ def get_models():
15445
15546 batch_size = len (prompt )
15647
157- vae , unet = get_models ()
48+ unet , vae , clip = get_unet (), get_vae (), get_clip ()
15849
15950 tokenizer = CLIPTokenizer .from_pretrained ("openai/clip-vit-large-patch14" )
16051
0 commit comments