@@ -298,6 +298,73 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
298298
299299 return text_embeddings
300300
def run_safety_checker(self, image, device, dtype):
    """Run `image` through the configured safety checker, if there is one.

    Args:
        image: Images passed to `self.numpy_to_pil` (to build the feature
            extractor input) and to the safety checker as `images`.
        device: Device the feature extractor output is moved to.
        dtype: Dtype the extracted `pixel_values` are cast to before they
            are handed to the safety checker as `clip_input`.

    Returns:
        A tuple `(image, has_nsfw_concept)`. When `self.safety_checker` is
        `None`, `image` is returned untouched and `has_nsfw_concept` is
        `None`; otherwise both values are whatever the safety checker
        returns.
    """
    checker = self.safety_checker
    if checker is None:
        # Nothing to check against: hand the images back as they came in.
        return image, None

    pil_images = self.numpy_to_pil(image)
    checker_input = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_input.pixel_values.to(dtype)
    image, has_nsfw_concept = checker(images=image, clip_input=clip_input)
    return image, has_nsfw_concept
310+
def decode_latents(self, latents):
    """Decode `latents` with the VAE and return the images as a NumPy array.

    The latents are multiplied by `1 / 0.18215`, decoded through
    `self.vae.decode(...).sample`, mapped with `x / 2 + 0.5` and clamped to
    `[0, 1]`. The result is moved to the CPU, its axes are permuted with
    `(0, 2, 3, 1)` (presumably channels-first to channels-last; confirm
    against the VAE output layout), and it is cast with `.float()` before
    the NumPy conversion.

    Args:
        latents: Latent tensor to decode.

    Returns:
        The decoded images as a NumPy array.
    """
    unscaled = 1 / 0.18215 * latents
    decoded = self.vae.decode(unscaled).sample
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    permuted = normalized.cpu().permute(0, 2, 3, 1)
    # Always cast to float32: the overhead is insignificant and it stays
    # compatible with bfloat16 inputs.
    return permuted.float().numpy()
318+
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so only the
    arguments that the current scheduler's `step` actually declares are
    forwarded.

    Args:
        generator: Forwarded as `generator` when `step` accepts it.
        eta: eta (η) from the DDIM paper (https://arxiv.org/abs/2010.02502);
            it should be between [0, 1]. It is only used with the
            DDIMScheduler and is dropped for schedulers whose `step` does
            not accept it.

    Returns:
        A dict holding `"eta"` and/or `"generator"`, restricted to the
        names present in the signature of `self.scheduler.step`.
    """
    # Inspect the signature once; the parameters mapping supports `in`
    # directly, so no intermediate set of keys is needed.
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_params}
335+
def check_inputs(self, prompt, height, width, callback_steps):
    """Validate the call arguments and raise on the first invalid one.

    Args:
        prompt: Must be a `str` or a `list`.
        height: Must be divisible by 8.
        width: Must be divisible by 8.
        callback_steps: Must be an `int` greater than zero; `None` is
            rejected.

    Raises:
        ValueError: If any of the conditions above is violated.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)} ")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width} .")

    # `None` is not an `int`, so the type test alone already rejects a
    # missing value; no separate `is None` branch is required.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)} ."
        )
350+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: First dimension of the latent shape.
        num_channels_latents: Second dimension of the latent shape.
        height: Image height; the latent height is `height // 8`.
        width: Image width; the latent width is `width // 8`.
        dtype: Dtype used when fresh noise is sampled.
        device: Target device; its `.type` attribute is read to detect
            `"mps"`.
        generator: Passed to `torch.randn` when fresh noise is sampled.
        latents: Optional caller-supplied latents. They must already have
            the expected shape and are only moved to `device`.

    Returns:
        The sampled or supplied latents multiplied by
        `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `latents` is given with a shape other than
            `(batch_size, num_channels_latents, height // 8, width // 8)`.
    """
    expected_shape = (batch_size, num_channels_latents, height // 8, width // 8)

    if latents is not None:
        if latents.shape != expected_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape} , expected {expected_shape} ")
        latents = latents.to(device)
    elif device.type == "mps":
        # randn does not work reproducibly on mps, so sample on the CPU
        # and move the result over afterwards.
        cpu_noise = torch.randn(expected_shape, generator=generator, device="cpu", dtype=dtype)
        latents = cpu_noise.to(device)
    else:
        latents = torch.randn(expected_shape, generator=generator, device=device, dtype=dtype)

    # Scale the initial noise by the standard deviation required by the scheduler.
    return latents * self.scheduler.init_noise_sigma
367+
301368 @torch .no_grad ()
302369 def __call__ (
303370 self ,
@@ -371,75 +438,45 @@ def __call__(
371438 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
372439 (nsfw) content, according to the `safety_checker`.
373440 """
374- if isinstance (prompt , str ):
375- batch_size = 1
376- elif isinstance (prompt , list ):
377- batch_size = len (prompt )
378- else :
379- raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
380441
381- if height % 8 != 0 or width % 8 != 0 :
382- raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
383-
384- if (callback_steps is None ) or (
385- callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 )
386- ):
387- raise ValueError (
388- f"`callback_steps` has to be a positive integer but is { callback_steps } of type"
389- f" { type (callback_steps )} ."
390- )
442+ # 1. Check inputs. Raise error if not correct
443+ self .check_inputs (prompt , height , width , callback_steps )
391444
445+ # 2. Define call parameters
446+ batch_size = 1 if isinstance (prompt , str ) else len (prompt )
392447 device = self ._execution_device
393-
394448 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
395449 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
396450 # corresponds to doing no classifier free guidance.
397451 do_classifier_free_guidance = guidance_scale > 1.0
398452
453+ # 3. Encode input prompt
399454 text_embeddings = self ._encode_prompt (
400455 prompt , device , num_images_per_prompt , do_classifier_free_guidance , negative_prompt
401456 )
402457
403- # Unlike in other pipelines, latents need to be generated in the target device
404- # for 1-to-1 results reproducibility with the CompVis implementation.
405- # However this currently doesn't work in `mps`.
406-
407- # get the initial random noise unless the user supplied it
408- latents_shape = (batch_size * num_images_per_prompt , self .unet .in_channels , height // 8 , width // 8 )
409- latents_dtype = text_embeddings .dtype
410- if latents is None :
411- if device .type == "mps" :
412- # randn does not work reproducibly on mps
413- latents = torch .randn (latents_shape , generator = generator , device = "cpu" , dtype = latents_dtype ).to (device )
414- else :
415- latents = torch .randn (latents_shape , generator = generator , device = device , dtype = latents_dtype )
416- else :
417- if latents .shape != latents_shape :
418- raise ValueError (f"Unexpected latents shape, got { latents .shape } , expected { latents_shape } " )
419- latents = latents .to (device )
420-
421- # set timesteps and move to the correct device
458+ # 4. Prepare timesteps
422459 self .scheduler .set_timesteps (num_inference_steps , device = device )
423- timesteps_tensor = self .scheduler .timesteps
424-
425- # scale the initial noise by the standard deviation required by the scheduler
426- latents = latents * self .scheduler .init_noise_sigma
427-
428- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
429- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
430- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
431- # and should be between [0, 1]
432- accepts_eta = "eta" in set (inspect .signature (self .scheduler .step ).parameters .keys ())
433- extra_step_kwargs = {}
434- if accepts_eta :
435- extra_step_kwargs ["eta" ] = eta
460+ timesteps = self .scheduler .timesteps
461+
462+ # 5. Prepare latent variables
463+ num_channels_latents = self .unet .in_channels
464+ latents = self .prepare_latents (
465+ batch_size * num_images_per_prompt ,
466+ num_channels_latents ,
467+ height ,
468+ width ,
469+ text_embeddings .dtype ,
470+ device ,
471+ generator ,
472+ latents ,
473+ )
436474
437- # check if the scheduler accepts generator
438- accepts_generator = "generator" in set (inspect .signature (self .scheduler .step ).parameters .keys ())
439- if accepts_generator :
440- extra_step_kwargs ["generator" ] = generator
475+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
476+ extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
441477
442- for i , t in enumerate (self .progress_bar (timesteps_tensor )):
478+ # 7. Denoising loop
479+ for i , t in enumerate (self .progress_bar (timesteps )):
443480 # expand the latents if we are doing classifier free guidance
444481 latent_model_input = torch .cat ([latents ] * 2 ) if do_classifier_free_guidance else latents
445482 latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
@@ -459,22 +496,13 @@ def __call__(
459496 if callback is not None and i % callback_steps == 0 :
460497 callback (i , t , latents )
461498
462- latents = 1 / 0.18215 * latents
463- image = self .vae .decode (latents ).sample
464-
465- image = (image / 2 + 0.5 ).clamp (0 , 1 )
466-
467- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
468- image = image .cpu ().permute (0 , 2 , 3 , 1 ).float ().numpy ()
499+ # 8. Post-processing
500+ image = self .decode_latents (latents )
469501
470- if self .safety_checker is not None :
471- safety_checker_input = self .feature_extractor (self .numpy_to_pil (image ), return_tensors = "pt" ).to (device )
472- image , has_nsfw_concept = self .safety_checker (
473- images = image , clip_input = safety_checker_input .pixel_values .to (text_embeddings .dtype )
474- )
475- else :
476- has_nsfw_concept = None
502+ # 9. Run safety checker
503+ image , has_nsfw_concept = self .run_safety_checker (image , device , text_embeddings .dtype )
477504
505+ # 10. Convert to PIL
478506 if output_type == "pil" :
479507 image = self .numpy_to_pil (image )
480508
0 commit comments