diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 05ea84ae0326..806135f2f131 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -145,8 +145,9 @@ def __call__(
                 process. This is the image whose masked region will be inpainted.
             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
-                replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
-                converted to a single channel (luminance) before use.
+                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                 is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -202,10 +203,12 @@ def __call__(
             self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
         # preprocess image
-        init_image = preprocess_image(init_image).to(self.device)
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+        init_image = init_image.to(self.device)
 
         # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+        init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
@@ -215,8 +218,10 @@ def __call__(
         init_latents_orig = init_latents
 
         # preprocess mask
-        mask = preprocess_mask(mask_image).to(self.device)
-        mask = torch.cat([mask] * batch_size)
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(self.device)
+        mask = torch.cat([mask_image] * batch_size)
 
         # check sizes
         if not mask.shape == init_latents.shape:
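
For reference, here is a minimal usage sketch of the new tensor path for `init_image` (not part of the diff). The checkpoint id, file names, prompt, and output handling are placeholder assumptions; the tensor is built to mirror what `preprocess_image` produces (float32, shape `(1, 3, H, W)`, values in `[-1, 1]`), since a `torch.FloatTensor` now skips that helper and the caller becomes responsible for the layout.

```python
import numpy as np
import PIL.Image
import torch

from diffusers import StableDiffusionInpaintPipeline

# Placeholder checkpoint and device.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
).to("cuda")

# Build init_image the way preprocess_image would: float32, (1, 3, H, W),
# scaled from [0, 1] to [-1, 1]. Passing a FloatTensor skips that helper.
init_pil = PIL.Image.open("photo.png").convert("RGB").resize((512, 512))
init_np = np.array(init_pil).astype(np.float32) / 255.0
init_image = 2.0 * torch.from_numpy(init_np[None].transpose(0, 3, 1, 2)) - 1.0

# A PIL mask still goes through preprocess_mask, as before this change.
mask_pil = PIL.Image.open("mask.png").convert("L").resize((512, 512))

result = pipe(
    prompt="a photograph of a cat",  # placeholder prompt
    init_image=init_image,
    mask_image=mask_pil,
    strength=0.8,
)
image = result.images[0]
```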
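
A tensor `mask_image` likewise skips `preprocess_mask`, and the size check at the end of the last hunk compares the mask against `init_latents`, whose shape is `(B, 4, H // 8, W // 8)` rather than the `(B, H, W, 1)` the new docstring advertises. Below is a hedged sketch of a tensor mask that passes that check, mirroring what `preprocess_mask` produces at this revision (latent resolution, channel dimension repeated to 4, inverted so `1.0` keeps a pixel and `0.0` repaints it); the file name and 512x512 image size are assumptions.

```python
import numpy as np
import PIL.Image
import torch

H = W = 512  # assumed image size; latents are 8x smaller

# Downscale the mask to latent resolution, as preprocess_mask does.
mask_pil = (
    PIL.Image.open("mask.png")
    .convert("L")
    .resize((W // 8, H // 8), resample=PIL.Image.NEAREST)
)
mask_np = np.array(mask_pil).astype(np.float32) / 255.0  # white -> 1.0

# Repeat across the 4 latent channels and add a batch dimension.
mask_np = np.tile(mask_np, (4, 1, 1))[None]  # (1, 4, H // 8, W // 8)

# Invert so 1.0 preserves the original latents and 0.0 lets the area be repainted.
mask_image = torch.from_numpy(1.0 - mask_np)
```

With this layout, `torch.cat([mask_image] * batch_size)` matches `init_latents.shape` and the size check passes; a pixel-space `(B, H, W, 1)` tensor would not, so the docstring and the check may still need to be reconciled.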