import torch

import PIL
-from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from .safety_checker import StableDiffusionSafetyChecker


-logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def preprocess_image(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
+def prepare_mask_and_masked_image(image, mask):
+    image = np.array(image.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
-
-
-def preprocess_mask(mask):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
+    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
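+    # image is now a 1x3xHxW float32 tensor rescaled from [0, 255] to [-1, 1]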
+
+    mask = np.array(mask.convert("L"))
+    mask = mask.astype(np.float32) / 255.0
+    mask = mask[None, None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
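+    # the mask is now binary: 1.0 marks pixels to repaint, 0.0 marks pixels to keep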
    mask = torch.from_numpy(mask)
-    return mask
+
+    masked_image = image * (mask < 0.5)
+
+    return mask, masked_image
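+
+# Rough usage sketch (illustrative only; `init_image` and `mask_pil` are hypothetical 512x512 PIL inputs):
+#     mask, masked_image = prepare_mask_and_masked_image(init_image, mask_pil)
+# `mask` comes back as a (1, 1, 512, 512) tensor with values in {0, 1} (1 = repaint), and
+# `masked_image` as a (1, 3, 512, 512) tensor in [-1, 1] with the to-be-repainted pixels zeroed out.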


class StableDiffusionInpaintPipeline(DiffusionPipeline):
@@ -82,6 +75,7 @@ def __init__(
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
+
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -140,22 +134,24 @@ def disable_attention_slicing(self):
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
-        # set slice_size = `None` to disable `set_attention_slice`
+        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
-        strength: float = 0.8,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
-        eta: Optional[float] = 0.0,
+        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -168,22 +164,21 @@ def __call__(
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
-            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process. This is the image whose masked region will be inpainted.
-            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
-                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
-                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
-                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
-                is 1, the denoising process will be run on the masked area for the full number of iterations specified
-                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
-                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
-                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
-                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -201,6 +196,10 @@ def __call__(
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -221,7 +220,6 @@ def __call__(
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
-        # TODO(Suraj) - adapt to your use case

        if isinstance(prompt, str):
            batch_size = 1
@@ -230,8 +228,8 @@ def __call__(
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -241,9 +239,6 @@ def __call__(
241239 f" { type (callback_steps )} ."
242240 )
243241
244- # set timesteps
245- self .scheduler .set_timesteps (num_inference_steps )
246-
247242 # get prompt text embeddings
248243 text_inputs = self .tokenizer (
249244 prompt ,
@@ -262,8 +257,10 @@ def __call__(
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
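+        # the repeat + view above produces the same ordering as
+        # text_embeddings.repeat_interleave(num_images_per_prompt, dim=0), but stays mps friendly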

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -300,50 +297,78 @@ def __call__(
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

-            # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-
-        # encode the init image into latents and scale the latents
+        # get the initial random noise unless the user supplied it
+        # Unlike in other pipelines, latents need to be generated in the target device
+        # for 1-to-1 results reproducibility with the CompVis implementation.
+        # However this currently doesn't work in `mps`.
+        num_channels_latents = self.vae.config.latent_channels
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
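+        # e.g. with the 512x512 defaults and the standard 4-channel Stable Diffusion VAE this is
+        # (batch_size * num_images_per_prompt, 4, 64, 64)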
        latents_dtype = text_embeddings.dtype
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        if latents is None:
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+            latents = latents.to(self.device)
+
+        # prepare mask and masked_image
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)

-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+
+        # encode the mask image into latents space so we can concatenate it to the latents
+        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+        masked_image_latents = 0.18215 * masked_image_latents
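+        # 0.18215 is the fixed latent scaling factor used by the Stable Diffusion VAE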
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
+        masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+        masked_image_latents = (
+            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+        )

-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+        num_channels_mask = mask.shape[1]
+        num_channels_masked_image = masked_image_latents.shape[1]

-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
+        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+            raise ValueError(
+                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+                " `pipeline.unet` or your `mask_image` or `image` input."
+            )
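+        # for the dedicated Stable Diffusion inpainting UNet this check expects
+        # 4 (latents) + 1 (mask) + 4 (masked image latents) = 9 input channels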

-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)

-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)

-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
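+        # this is a no-op (init_noise_sigma == 1.0) for DDIM/PNDM-style schedulers; sigma-based
+        # schedulers such as Euler use a larger value derived from their sigma schedule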

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -354,17 +379,13 @@ def __call__(
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

-        latents = init_latents
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        for i, t in tqdm(enumerate(timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            # concat latents, mask, masked_image_latents in the channel dimension
+            latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
@@ -377,10 +398,6 @@ def __call__(

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-            # masking
-            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
@@ -390,13 +407,17 @@ def __call__(
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
-            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+            )
        else:
            has_nsfw_concept = None

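A minimal end-to-end sketch of the new call signature (the checkpoint name and file paths below are
illustrative assumptions, not taken from this diff):

    import PIL.Image
    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
    pipe = pipe.to("cuda")  # assumes a CUDA GPU is available

    init_image = PIL.Image.open("dog_on_bench.png").convert("RGB").resize((512, 512))
    mask_image = PIL.Image.open("bench_mask.png").convert("L").resize((512, 512))  # white = repaint

    result = pipe(
        prompt="a fluffy cat sitting on a park bench",
        image=init_image,
        mask_image=mask_image,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    result.images[0].save("inpainted.png")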