@@ -86,16 +86,6 @@ def end_profiling(device):
     # Scale for classifier-free guidance
     guidance_scale = torch.tensor(args.guidance_scale).to(torch.float32)

-    # Handle out of range seeds.
-    uint32_info = np.iinfo(np.uint32)
-    uint32_min, uint32_max = uint32_info.min, uint32_info.max
-    seed = args.seed
-    if seed < uint32_min or seed >= uint32_max:
-        seed = randint(uint32_min, uint32_max)
-    generator = torch.manual_seed(
-        seed
-    )  # Seed generator to create the inital latent noise
-
     # TODO: Add support for batch_size > 1.
     batch_size = len(prompt)
     if batch_size != 1:
@@ -144,139 +134,157 @@ def end_profiling(device):
144134 "stabilityai/stable-diffusion-2-1-base" ,
145135 subfolder = "scheduler" ,
146136 )
+    for run in range(args.runs):
+        # Handle out of range seeds.
+        uint32_info = np.iinfo(np.uint32)
+        uint32_min, uint32_max = uint32_info.min, uint32_info.max
+        seed = args.seed
+        if run >= 1 or seed < uint32_min or seed >= uint32_max:
+            seed = randint(uint32_min, uint32_max)
+        generator = torch.manual_seed(
+            seed
+        )  # Seed generator to create the inital latent noise
+
+        # create a random initial latent.
+        latents = torch.randn(
+            (batch_size, 4, height // 8, width // 8),
+            generator=generator,
+            dtype=torch.float32,
+        ).to(dtype)
+        if run == 0:
+            # Warmup phase to improve performance.
+            if args.warmup_count >= 1:
+                vae_warmup_input = torch.clone(latents).detach().numpy()
+                clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
+                for i in range(args.warmup_count):
+                    vae("forward", (vae_warmup_input,))
+                    clip("forward", (clip_warmup_input,))
+
+        start = time.time()
+        if run == 0:
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=args.max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = tokenizer(
+                neg_prompt,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input = torch.cat(
+                [uncond_input.input_ids, text_input.input_ids]
+            )

-    # create a random initial latent.
-    latents = torch.randn(
-        (batch_size, 4, height // 8, width // 8),
-        generator=generator,
-        dtype=torch.float32,
-    ).to(dtype)
-    # Warmup phase to improve performance.
-    if args.warmup_count >= 1:
-        vae_warmup_input = torch.clone(latents).detach().numpy()
-        clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
-        for i in range(args.warmup_count):
-            vae("forward", (vae_warmup_input,))
-            clip("forward", (clip_warmup_input,))
-
-    start = time.time()
-
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=args.max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    max_length = text_input.input_ids.shape[-1]
-    uncond_input = tokenizer(
-        neg_prompt,
-        padding="max_length",
-        max_length=max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
-
-    clip_inf_start = time.time()
-    text_embeddings = clip("forward", (text_input,))
-    clip_inf_end = time.time()
-    text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
-    text_embeddings_numpy = text_embeddings.detach().numpy()
-
-    scheduler.set_timesteps(num_inference_steps)
-    scheduler.is_scale_input_called = True
-
-    latents = latents * scheduler.init_noise_sigma
-
-    avg_ms = 0
-    for i, t in tqdm(enumerate(scheduler.timesteps), disable=args.hide_steps):
-        step_start = time.time()
-        if not args.hide_steps:
-            print(f"i = {i} t = {t} ", end="")
-        timestep = torch.tensor([t]).to(dtype).detach().numpy()
-        latent_model_input = scheduler.scale_model_input(latents, t)
-        if cpu_scheduling:
-            latent_model_input = latent_model_input.detach().numpy()
-
-        profile_device = start_profiling(file_path="unet.rdc")
-
-        noise_pred = unet(
-            "forward",
-            (
-                latent_model_input,
-                timestep,
-                text_embeddings_numpy,
-                guidance_scale,
-            ),
-            send_to_host=False,
-        )
-
-        end_profiling(profile_device)
+        clip_inf_start = time.time()
+        text_embeddings = clip("forward", (text_input,))
+        clip_inf_end = time.time()
+        text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
+        text_embeddings_numpy = text_embeddings.detach().numpy()
+
+        scheduler.set_timesteps(num_inference_steps)
+        scheduler.is_scale_input_called = True
+
+        latents = latents * scheduler.init_noise_sigma
+
+        avg_ms = 0
+        for i, t in tqdm(
+            enumerate(scheduler.timesteps), disable=args.hide_steps
+        ):
+            step_start = time.time()
+            if not args.hide_steps:
+                print(f"i = {i} t = {t} ", end="")
+            timestep = torch.tensor([t]).to(dtype).detach().numpy()
+            latent_model_input = scheduler.scale_model_input(latents, t)
+            if cpu_scheduling:
+                latent_model_input = latent_model_input.detach().numpy()
+
+            profile_device = start_profiling(file_path="unet.rdc")
+
+            noise_pred = unet(
+                "forward",
+                (
+                    latent_model_input,
+                    timestep,
+                    text_embeddings_numpy,
+                    guidance_scale,
+                ),
+                send_to_host=False,
+            )

+            end_profiling(profile_device)
+
+            if cpu_scheduling:
+                noise_pred = torch.from_numpy(noise_pred.to_host())
+                latents = scheduler.step(noise_pred, t, latents).prev_sample
+            else:
+                latents = scheduler.step(noise_pred, t, latents)
+            step_time = time.time() - step_start
+            avg_ms += step_time
+            step_ms = int((step_time) * 1000)
+            if not args.hide_steps:
+                print(f" ({step_ms}ms)")
+
+        # scale and decode the image latents with vae
+        if args.use_base_vae:
+            latents = 1 / 0.18215 * latents
+            latents_numpy = latents
         if cpu_scheduling:
-            noise_pred = torch.from_numpy(noise_pred.to_host())
-            latents = scheduler.step(noise_pred, t, latents).prev_sample
+            latents_numpy = latents.detach().numpy()
+        profile_device = start_profiling(file_path="vae.rdc")
+        vae_start = time.time()
+        images = vae("forward", (latents_numpy,))
+        vae_end = time.time()
+        end_profiling(profile_device)
+        if args.use_base_vae:
+            image = torch.from_numpy(images)
+            image = (image.detach().cpu() * 255.0).numpy()
+            images = image.round()
+        end_time = time.time()
+
+        avg_ms = 1000 * avg_ms / args.steps
+        clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
+        vae_inf_time = (vae_end - vae_start) * 1000
+        total_time = end_time - start
+
+        print(f"\nStats for run {run}:")
+        print(f"Average step time: {avg_ms}ms/it")
+        print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
+        print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
+        print(f"\nTotal image generation time: {total_time}sec")
+
+        transform = T.ToPILImage()
+        pil_images = [
+            transform(image)
+            for image in torch.from_numpy(images).to(torch.uint8)
+        ]
+
+        if args.output_dir is not None:
+            output_path = Path(args.output_dir)
+            output_path.mkdir(parents=True, exist_ok=True)
         else:
-            latents = scheduler.step(noise_pred, t, latents)
-        step_time = time.time() - step_start
-        avg_ms += step_time
-        step_ms = int((step_time) * 1000)
-        if not args.hide_steps:
-            print(f" ({step_ms}ms)")
-
-    # scale and decode the image latents with vae
-    if args.use_base_vae:
-        latents = 1 / 0.18215 * latents
-        latents_numpy = latents
-    if cpu_scheduling:
-        latents_numpy = latents.detach().numpy()
-    profile_device = start_profiling(file_path="vae.rdc")
-    vae_start = time.time()
-    images = vae("forward", (latents_numpy,))
-    vae_end = time.time()
-    end_profiling(profile_device)
-    if args.use_base_vae:
-        image = torch.from_numpy(images)
-        image = (image.detach().cpu() * 255.0).numpy()
-        images = image.round()
-    end_time = time.time()
-
-    avg_ms = 1000 * avg_ms / args.steps
-    clip_inf_time = (clip_inf_end - clip_inf_start) * 1000
-    vae_inf_time = (vae_end - vae_start) * 1000
-    total_time = end_time - start
-    print(f"\nAverage step time: {avg_ms}ms/it")
-    print(f"Clip Inference time (ms) = {clip_inf_time:.3f}")
-    print(f"VAE Inference time (ms): {vae_inf_time:.3f}")
-    print(f"\nTotal image generation time: {total_time}sec")
-
-    transform = T.ToPILImage()
-    pil_images = [
-        transform(image) for image in torch.from_numpy(images).to(torch.uint8)
-    ]
-
-    if args.output_dir is not None:
-        output_path = Path(args.output_dir)
-        output_path.mkdir(parents=True, exist_ok=True)
-    else:
-        output_path = Path.cwd()
-    disk_space_check(output_path, lim=5)
-    for i in range(batch_size):
-        json_store = {
-            "prompt": args.prompts[i],
-            "negative prompt": args.negative_prompts[i],
-            "seed": args.seed,
-            "variant": args.variant,
-            "precision": args.precision,
-            "steps": args.steps,
-            "guidance_scale": args.guidance_scale,
-            "scheduler": args.scheduler,
-        }
-        prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
-        img_name = f"{prompt_slice}_{args.seed}_{i}_{dt.now().strftime('%y%m%d_%H%M%S')}"
-        pil_images[i].save(
-            output_path / f"{img_name}.jpg", quality=95, subsampling=0
-        )
-        with open(output_path / f"{img_name}.json", "w") as f:
-            f.write(json.dumps(json_store, indent=4))
+            output_path = Path.cwd()
+        disk_space_check(output_path, lim=5)
+        for i in range(batch_size):
+            json_store = {
+                "prompt": args.prompts[i],
+                "negative prompt": args.negative_prompts[i],
+                "seed": args.seed,
+                "variant": args.variant,
+                "precision": args.precision,
+                "steps": args.steps,
+                "guidance_scale": args.guidance_scale,
+                "scheduler": args.scheduler,
+            }
+            prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[i][:15])
+            img_name = f"{prompt_slice}_{args.seed}_{run}_{dt.now().strftime('%y%m%d_%H%M%S')}"
+            pil_images[i].save(
+                output_path / f"{img_name}.jpg", quality=95, subsampling=0
+            )
+            with open(output_path / f"{img_name}.json", "w") as f:
+                f.write(json.dumps(json_store, indent=4))
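
For reference, below is a minimal standalone sketch of the per-run seed policy this diff introduces: run 0 keeps a valid user-supplied --seed, while every later run (or an out-of-range seed) falls back to a fresh random uint32 seed, so repeated runs produce different images. The helper name select_seed is hypothetical and only illustrates the condition used in the loop; the bounds mirror np.iinfo(np.uint32) from the script.

import numpy as np
from random import randint

# uint32 bounds, as in the script's np.iinfo(np.uint32) check.
UINT32_MIN, UINT32_MAX = np.iinfo(np.uint32).min, np.iinfo(np.uint32).max


def select_seed(requested_seed: int, run: int) -> int:
    # Run 0 reuses a valid user-supplied seed; later runs (or invalid seeds)
    # draw a random one so each generated image differs.
    if run >= 1 or requested_seed < UINT32_MIN or requested_seed >= UINT32_MAX:
        return randint(UINT32_MIN, UINT32_MAX)
    return requested_seed


# Example: run 0 reproduces seed 42, runs 1 and 2 draw new seeds.
for run in range(3):
    print(run, select_seed(42, run))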