2222
2323IREE_EXTRA_ARGS = []
2424args = None
25- DEBUG = False
2625
2726
2827class Arguments :
@@ -39,8 +38,7 @@ def __init__(
3938 seed : int ,
4039 precision : str ,
4140 device : str ,
42- load_vmfb : bool ,
43- save_vmfb : bool ,
41+ cache : bool ,
4442 iree_vulkan_target_triple : str ,
4543 live_preview : bool ,
4644 import_mlir : bool = False ,
@@ -57,8 +55,7 @@ def __init__(
5755 self .seed = seed
5856 self .precision = precision
5957 self .device = device
60- self .load_vmfb = load_vmfb
61- self .save_vmfb = save_vmfb
58+ self .cache = cache
6259 self .iree_vulkan_target_triple = iree_vulkan_target_triple
6360 self .live_preview = live_preview
6461 self .import_mlir = import_mlir
@@ -101,6 +98,37 @@ def get_models():
10198 return None , None
10299
103100
101+ schedulers = dict ()
102+ # set scheduler value
103+ schedulers ["PNDM" ] = PNDMScheduler (
104+ beta_start = 0.00085 ,
105+ beta_end = 0.012 ,
106+ beta_schedule = "scaled_linear" ,
107+ num_train_timesteps = 1000 ,
108+ )
109+ schedulers ["LMS" ] = LMSDiscreteScheduler (
110+ beta_start = 0.00085 ,
111+ beta_end = 0.012 ,
112+ beta_schedule = "scaled_linear" ,
113+ num_train_timesteps = 1000 ,
114+ )
115+ schedulers ["DDIM" ] = DDIMScheduler (
116+ beta_start = 0.00085 ,
117+ beta_end = 0.012 ,
118+ beta_schedule = "scaled_linear" ,
119+ clip_sample = False ,
120+ set_alpha_to_one = False ,
121+ )
122+
123+ cache_obj = dict ()
124+ cache_obj ["tokenizer" ] = CLIPTokenizer .from_pretrained (
125+ "openai/clip-vit-large-patch14"
126+ )
127+ cache_obj ["text_encoder" ] = CLIPTextModel .from_pretrained (
128+ "openai/clip-vit-large-patch14"
129+ )
130+
131+
104132def stable_diff_inf (
105133 prompt : str ,
106134 scheduler : str ,
@@ -113,21 +141,17 @@ def stable_diff_inf(
113141 seed : str ,
114142 precision : str ,
115143 device : str ,
116- load_vmfb : bool ,
117- save_vmfb : bool ,
144+ cache : bool ,
118145 iree_vulkan_target_triple : str ,
119146 live_preview : bool ,
120147):
121148
122149 global IREE_EXTRA_ARGS
123150 global args
124- global DEBUG
151+ global schedulers
152+ global cache_obj
125153
126- output_loc = f"stored_results/stable_diffusion/{ prompt } _{ int (steps )} _{ precision } _{ device } .jpg"
127- DEBUG = False
128- log_write = open (r"logs/stable_diffusion_log.txt" , "w" )
129- if log_write :
130- DEBUG = True
154+ output_loc = f"stored_results/stable_diffusion/{ time .time ()} _{ int (steps )} _{ precision } _{ device } .jpg"
131155
132156 # set seed value
133157 if seed == "" :
@@ -138,34 +162,7 @@ def stable_diff_inf(
138162 except ValueError :
139163 seed = hash (seed )
140164
141- # set scheduler value
142- if scheduler == "PNDM" :
143- scheduler = PNDMScheduler (
144- beta_start = 0.00085 ,
145- beta_end = 0.012 ,
146- beta_schedule = "scaled_linear" ,
147- num_train_timesteps = 1000 ,
148- )
149- elif scheduler == "LMS" :
150- scheduler = LMSDiscreteScheduler (
151- beta_start = 0.00085 ,
152- beta_end = 0.012 ,
153- beta_schedule = "scaled_linear" ,
154- num_train_timesteps = 1000 ,
155- )
156- elif scheduler == "DDIM" :
157- scheduler = DDIMScheduler (
158- beta_start = 0.00085 ,
159- beta_end = 0.012 ,
160- beta_schedule = "scaled_linear" ,
161- clip_sample = False ,
162- set_alpha_to_one = False ,
163- )
164- else :
165- raise Exception (
166- f"Does not support scheduler with name { args .scheduler } ."
167- )
168-
165+ scheduler = schedulers [scheduler ]
169166 args = Arguments (
170167 prompt ,
171168 scheduler ,
@@ -178,8 +175,7 @@ def stable_diff_inf(
178175 seed ,
179176 precision ,
180177 device ,
181- load_vmfb ,
182- save_vmfb ,
178+ cache ,
183179 iree_vulkan_target_triple ,
184180 live_preview ,
185181 )
@@ -194,11 +190,8 @@ def stable_diff_inf(
194190 ) # Seed generator to create the inital latent noise
195191
196192 vae , unet = get_models ()
197- tokenizer = CLIPTokenizer .from_pretrained ("openai/clip-vit-large-patch14" )
198- text_encoder = CLIPTextModel .from_pretrained (
199- "openai/clip-vit-large-patch14"
200- )
201-
193+ tokenizer = cache_obj ["tokenizer" ]
194+ text_encoder = cache_obj ["text_encoder" ]
202195 text_input = tokenizer (
203196 [args .prompt ] * batch_size ,
204197 padding = "max_length" ,
@@ -233,10 +226,10 @@ def stable_diff_inf(
233226
234227 avg_ms = 0
235228 out_img = None
229+ text_output = ""
236230 for i , t in tqdm (enumerate (scheduler .timesteps )):
237231
238- if DEBUG :
239- log_write .write (f"\n i = { i } t = { t } " )
232+ text_output = text_output + f"\n i = { i } t = { t } "
240233 step_start = time .time ()
241234 timestep = torch .tensor ([t ]).to (dtype ).detach ().numpy ()
242235 latents_numpy = latents .detach ().numpy ()
@@ -249,8 +242,7 @@ def stable_diff_inf(
249242 step_time = time .time () - step_start
250243 avg_ms += step_time
251244 step_ms = int ((step_time ) * 1000 )
252- if DEBUG :
253- log_write .write (f"time={ step_ms } ms" )
245+ text_output = text_output + f"time={ step_ms } ms"
254246 latents = scheduler .step (noise_pred , i , latents )["prev_sample" ]
255247
256248 if live_preview :
@@ -263,7 +255,7 @@ def stable_diff_inf(
263255 images = (image * 255 ).round ().astype ("uint8" )
264256 pil_images = [Image .fromarray (image ) for image in images ]
265257 out_img = pil_images [0 ]
266- yield out_img , ""
258+ yield out_img , text_output
267259
268260 # scale and decode the image latents with vae
269261 if not live_preview :
@@ -277,14 +269,8 @@ def stable_diff_inf(
277269 out_img = pil_images [0 ]
278270
279271 avg_ms = 1000 * avg_ms / args .steps
280- if DEBUG :
281- log_write .write (f"\n Average step time: { avg_ms } ms/it" )
272+ text_output = text_output + f"\n Average step time: { avg_ms } ms/it"
282273
283274 # save the output image with the prompt name.
284275 out_img .save (os .path .join (output_loc ))
285- log_write .close ()
286-
287- std_output = ""
288- with open (r"logs/stable_diffusion_log.txt" , "r" ) as log_read :
289- std_output = log_read .read ()
290- yield out_img , std_output
276+ yield out_img , text_output
0 commit comments