Commit b35d88c

patil-suraj, patrickvonplaten, and anton-l authored
Stable diffusion inpainting. (#904)
* begin pipe
* add new pipeline
* add tests
* correct fast test
* up
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
* Update tests/test_pipelines.py
* up
* up
* make style
* add fp16 test
* doc, comments
* up

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 83b696e commit b35d88c
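For context, a minimal usage sketch of the pipeline this commit introduces (not part of the diff below). The checkpoint id and file names are assumptions for illustration; any inpainting-finetuned Stable Diffusion checkpoint whose UNet accepts 9 input channels should work the same way.

```python
# Hedged usage sketch of the new StableDiffusionInpaintPipeline.
# "runwayml/stable-diffusion-inpainting", "photo.png", and "mask.png" are assumptions.
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("photo.png").convert("RGB").resize((512, 512))       # image to edit
mask_image = Image.open("mask.png").convert("L").resize((512, 512))     # white = repaint, black = keep

result = pipe(
    prompt="a vase of flowers on a wooden table",
    image=image,
    mask_image=mask_image,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=7.5,
)
result.images[0].save("inpainted.png")
```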

File tree

2 files changed: +333, -111 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 117 additions & 96 deletions
@@ -5,7 +5,6 @@
 import torch
 
 import PIL
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -17,30 +16,24 @@
 from .safety_checker import StableDiffusionSafetyChecker
 
 
-logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def preprocess_image(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
+def prepare_mask_and_masked_image(image, mask):
+    image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
-
-
-def preprocess_mask(mask):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
-    mask = 1 - mask  # repaint white, keep black
+    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+    mask = np.array(mask.convert("L"))
+    mask = mask.astype(np.float32) / 255.0
+    mask = mask[None, None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
     mask = torch.from_numpy(mask)
-    return mask
+
+    masked_image = image * (mask < 0.5)
+
+    return mask, masked_image
 
 
 class StableDiffusionInpaintPipeline(DiffusionPipeline):
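As a sanity check on the new helper above, here is a standalone sketch that mirrors its steps on dummy 512x512 PIL inputs and prints the shapes it produces; the variable names are illustrative only.

```python
# Standalone sketch of what prepare_mask_and_masked_image produces (assumed 512x512 inputs).
import numpy as np
import torch
from PIL import Image

image = Image.new("RGB", (512, 512), color=(128, 128, 128))
mask = Image.new("L", (512, 512), color=0)

img = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)        # (1, 3, H, W), uint8
img = torch.from_numpy(img).to(dtype=torch.float32) / 127.5 - 1.0       # values in [-1, 1]

m = np.array(mask.convert("L")).astype(np.float32) / 255.0
m = m[None, None]                                                        # (1, 1, H, W)
m[m < 0.5] = 0                                                           # binarize: black = keep
m[m >= 0.5] = 1                                                          # white = repaint
m = torch.from_numpy(m)

masked = img * (m < 0.5)                                                 # zero out the region to repaint

print(img.shape, m.shape, masked.shape)  # (1, 3, 512, 512) (1, 1, 512, 512) (1, 3, 512, 512)
```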
@@ -82,6 +75,7 @@ def __init__(
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
+
         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -140,22 +134,24 @@ def disable_attention_slicing(self):
         Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
         back to computing attention in one step.
         """
-        # set slice_size = `None` to disable `set_attention_slice`
+        # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.FloatTensor, PIL.Image.Image],
         mask_image: Union[torch.FloatTensor, PIL.Image.Image],
-        strength: float = 0.8,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
-        eta: Optional[float] = 0.0,
+        eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
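The signature change above replaces the img2img-style `init_image`/`strength` interface with explicit `image`, `height`, and `width` arguments plus an optional `latents` input. A before/after sketch, where `pipe`, `image`, and `mask_image` are assumed to be defined as in the usage sketch near the top:

```python
# before this commit (strength-driven, starting from the encoded init image):
# result = pipe("a red brick wall", init_image=image, mask_image=mask_image, strength=0.8)

# after this commit (fixed-size generation conditioned on the mask and the masked image):
result = pipe("a red brick wall", image=image, mask_image=mask_image, height=512, width=512)
```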
@@ -168,22 +164,21 @@ def __call__(
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process. This is the image whose masked region will be inpainted.
-            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
-                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
-                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
-                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
-                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            strength (`float`, *optional*, defaults to 0.8):
-                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
-                is 1, the denoising process will be run on the masked area for the full number of iterations specified
-                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
-                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
-                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
-                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
             guidance_scale (`float`, *optional*, defaults to 7.5):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
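A small sketch of building a mask that follows the convention documented above (white pixels are repainted, black pixels are preserved); the coordinates are arbitrary.

```python
# Illustrative mask construction with PIL: repaint only a square in the middle.
from PIL import Image, ImageDraw

mask_image = Image.new("L", (512, 512), color=0)        # keep everything by default (black)
draw = ImageDraw.Draw(mask_image)
draw.rectangle((160, 160, 352, 352), fill=255)          # white region to be repainted
```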
@@ -201,6 +196,10 @@ def __call__(
             generator (`torch.Generator`, *optional*):
                 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                 deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
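A sketch of how the new `latents` argument can be used: pre-generate the starting noise once and reuse it across prompts so the results stay comparable. `pipe`, `image`, and `mask_image` are assumed to be defined as in the usage sketch near the top; the shape follows the check inside `__call__`.

```python
# Reuse the same starting noise for two different prompts (illustrative sketch).
import torch

height = width = 512
shape = (1, pipe.vae.config.latent_channels, height // 8, width // 8)   # (1, 4, 64, 64) for SD
latents = torch.randn(shape, generator=torch.Generator().manual_seed(0))

out_a = pipe("a ginger cat", image=image, mask_image=mask_image, latents=latents).images[0]
out_b = pipe("a small dog", image=image, mask_image=mask_image, latents=latents).images[0]
```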
@@ -221,7 +220,6 @@
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # TODO(Suraj) - adapt to your use case
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -230,8 +228,8 @@ def __call__(
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
@@ -241,9 +239,6 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -262,8 +257,10 @@ def __call__(
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
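The hunk above swaps `repeat_interleave` for a `repeat` + `view` pair that behaves the same but works on `mps`. A quick standalone check of the equivalence (dimensions are the usual CLIP text-embedding sizes, chosen here for illustration):

```python
# repeat + view produces the same ordering as repeat_interleave along the batch dim.
import torch

text_embeddings = torch.randn(2, 77, 768)   # (batch, seq_len, dim)
num_images_per_prompt = 3

a = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

bs_embed, seq_len, _ = text_embeddings.shape
b = text_embeddings.repeat(1, num_images_per_prompt, 1).view(bs_embed * num_images_per_prompt, seq_len, -1)

assert torch.equal(a, b)   # each prompt's embedding is repeated consecutively in both cases
```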
@@ -300,50 +297,78 @@ def __call__(
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
-            # duplicate unconditional embeddings for each generation per prompt
-            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-
-        # encode the init image into latents and scale the latents
+        # get the initial random noise unless the user supplied it
+        # Unlike in other pipelines, latents need to be generated in the target device
+        # for 1-to-1 results reproducibility with the CompVis implementation.
+        # However this currently doesn't work in `mps`.
+        num_channels_latents = self.vae.config.latent_channels
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
-        init_image = init_image.to(device=self.device, dtype=latents_dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        if latents is None:
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+            latents = latents.to(self.device)
+
+        # prepare mask and masked_image
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
 
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+
+        # encode the mask image into latents space so we can concatenate it to the latents
+        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+        masked_image_latents = 0.18215 * masked_image_latents
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
+        masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+        masked_image_latents = (
+            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+        )
 
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+        num_channels_mask = mask.shape[1]
+        num_channels_masked_image = masked_image_latents.shape[1]
 
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
+        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+            raise ValueError(
+                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                " `pipeline.unet` or your `mask_image` or `image` input."
+            )
 
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
 
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
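The channel check added in this hunk is the core of the new approach: the UNet no longer sees only the noisy latents but also the downsampled mask and the VAE encoding of the masked image, stacked along the channel axis. A bookkeeping sketch, assuming the standard Stable Diffusion VAE with 4 latent channels:

```python
# Why the inpainting checkpoint's UNet needs in_channels == 9 (standard text-to-image UNets use 4).
num_channels_latents = 4        # vae.config.latent_channels for Stable Diffusion
num_channels_mask = 1           # single-channel mask, resized to (height // 8, width // 8)
num_channels_masked_image = 4   # VAE latents of the masked image

assert num_channels_latents + num_channels_mask + num_channels_masked_image == 9
```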
@@ -354,17 +379,13 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        latents = init_latents
-
-        t_start = max(num_inference_steps - init_timestep + offset, 0)
-
-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimized to move all timesteps to correct device beforehand
-        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
-        for i, t in tqdm(enumerate(timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            # concat latents, mask, masked_image_latents in the channel dimension
+            latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
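A shape sketch for one denoising step in the loop above, assuming 512x512 generation, batch size 1, and classifier-free guidance; the random tensors stand in for the values computed earlier in `__call__`.

```python
# Per-step input assembly: latents + mask + masked-image latents along the channel axis.
import torch

latents = torch.randn(1, 4, 64, 64)
mask = torch.randn(2, 1, 64, 64)                   # already duplicated for CFG
masked_image_latents = torch.randn(2, 4, 64, 64)   # already duplicated for CFG

latent_model_input = torch.cat([latents] * 2)                                        # (2, 4, 64, 64)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
print(latent_model_input.shape)   # torch.Size([2, 9, 64, 64]) -> matches unet.config.in_channels
```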
@@ -377,10 +398,6 @@ def __call__(
 
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-            # masking
-            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-
-            latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
390407
image = self.vae.decode(latents).sample
391408

392409
image = (image / 2 + 0.5).clamp(0, 1)
393-
image = image.cpu().permute(0, 2, 3, 1).numpy()
410+
411+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
412+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
394413

395414
if self.safety_checker is not None:
396415
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
397416
self.device
398417
)
399-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
418+
image, has_nsfw_concept = self.safety_checker(
419+
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
420+
)
400421
else:
401422
has_nsfw_concept = None
402423

0 commit comments

Comments
 (0)