Skip to content

Commit 9c6bf05

Browse files
committed
comments from PR review for imagic stable diffusion
1 parent 10349e5 commit 9c6bf05

File tree

1 file changed

+179
-39
lines changed

1 file changed

+179
-39
lines changed

examples/community/imagic_stable_diffusion.py

Lines changed: 179 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
import warnings
66
from typing import List, Optional, Union
7+
import inspect
78

89
import numpy as np
910
import torch
@@ -19,6 +20,9 @@
1920
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
2021
from tqdm.auto import tqdm
2122
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
23+
from diffusers.utils import logging
24+
25+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2226

2327

2428
def freeze_params(params):
@@ -115,18 +119,18 @@ def disable_attention_slicing(self):
115119
# set slice_size = `None` to disable `attention slicing`
116120
self.enable_attention_slicing(None)
117121

118-
# @torch.no_grad()
119-
def __call__(
122+
123+
def train(
120124
self,
121125
prompt: Union[str, List[str]],
122126
init_image: Union[torch.FloatTensor, PIL.Image.Image],
123-
alpha: float = 1.2,
124127
height: Optional[int] = 512,
125128
width: Optional[int] = 512,
126-
num_inference_steps: Optional[int] = 50,
127-
guidance_scale: Optional[float] = 7.5,
128129
generator: Optional[torch.Generator] = None,
129-
return_dict: bool = True,
130+
embedding_learning_rate: float = 0.001,
131+
diffusion_model_learning_rate: float = 2e-6,
132+
text_embedding_optimization_steps: int = 500,
133+
model_fine_tuning_optimization_steps: int = 1000,
130134
**kwargs,
131135
):
132136
r"""
@@ -170,11 +174,6 @@ def __call__(
170174
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
171175
(nsfw) content, according to the `safety_checker`.
172176
"""
173-
embedding_learning_rate = 0.001
174-
diffusion_model_learning_rate = 2e-6
175-
text_embedding_optimization_steps = 500
176-
model_fine_tuning_optimization_steps = 1000
177-
178177
accelerator = Accelerator(
179178
gradient_accumulation_steps=1,
180179
mixed_precision="fp16",
@@ -197,9 +196,9 @@ def __call__(
197196
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
198197

199198
# Freeze vae and unet
200-
freeze_params(self.vae.parameters())
201-
freeze_params(self.unet.parameters())
202-
freeze_params(self.text_encoder.parameters())
199+
self.vae.requires_grad_(False)
200+
self.unet.requires_grad_(False)
201+
self.text_encoder.requires_grad_(False)
203202
self.unet.eval()
204203
self.vae.eval()
205204
self.text_encoder.eval()
@@ -243,23 +242,12 @@ def __call__(
243242
init_image_latents = init_latent_image_dist.sample(generator=generator)
244243
init_image_latents = 0.18215 * init_image_latents
245244

246-
pipeline = StableDiffusionPipeline(
247-
text_encoder=self.text_encoder,
248-
vae=self.vae,
249-
unet=self.unet,
250-
tokenizer=self.tokenizer,
251-
scheduler=self.scheduler,
252-
safety_checker=self.safety_checker,
253-
feature_extractor=self.feature_extractor,
254-
)
255-
pipeline = pipeline.to("cuda")
256-
257245
progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
258246
progress_bar.set_description("Steps")
259247

260248
global_step = 0
261249

262-
print("First optimizing the text embedding to better reconstruct the init image")
250+
logger.info("First optimizing the text embedding to better reconstruct the init image")
263251
for _ in range(text_embedding_optimization_steps):
264252
with accelerator.accumulate(text_embeddings):
265253
# Sample noise that we'll add to the latents
@@ -291,18 +279,17 @@ def __call__(
291279
accelerator.wait_for_everyone()
292280

293281
text_embeddings.requires_grad_(False)
294-
freeze_params(text_embeddings)
295282

296283
# Now we fine tune the unet to better reconstruct the image
297-
unfreeze_params(self.unet.parameters())
284+
self.unet.requires_grad_(True)
298285
self.unet.train()
299286
optimizer = torch.optim.Adam(
300287
self.unet.parameters(), # only optimize unet
301288
lr=diffusion_model_learning_rate,
302289
)
303290
progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
304291

305-
print("Next fine tuning the entire model to better reconstruct the init image")
292+
logger.info("Next fine tuning the entire model to better reconstruct the init image")
306293
for _ in range(model_fine_tuning_optimization_steps):
307294
with accelerator.accumulate(self.unet.parameters()):
308295
# Sample noise that we'll add to the latents
@@ -332,19 +319,172 @@ def __call__(
332319
accelerator.log(logs, step=global_step)
333320

334321
accelerator.wait_for_everyone()
322+
self.text_embeddings_orig = text_embeddings_orig
323+
self.text_embeddings = text_embeddings
335324

336-
new_text_embeddings = alpha * text_embeddings_orig + (1 - alpha) * text_embeddings
337-
image = pipeline(
338-
prompt, text_embeddings=new_text_embeddings, scale=7.5, num_inference_steps=num_inference_steps
339-
).images[
340-
0
341-
] # , latents=noise_latents).images[0]
325+
@torch.no_grad()
326+
def __call__(
327+
self,
328+
alpha: float = 1.2,
329+
height: Optional[int] = 512,
330+
width: Optional[int] = 512,
331+
num_inference_steps: Optional[int] = 50,
332+
generator: Optional[torch.Generator] = None,
333+
output_type: Optional[str] = "pil",
334+
return_dict: bool = True,
335+
guidance_scale: float = 7.5,
336+
eta: float = 0.0,
337+
**kwargs,
338+
):
339+
r"""
340+
Function invoked when calling the pipeline for generation.
341+
Args:
342+
prompt (`str` or `List[str]`):
343+
The prompt or prompts to guide the image generation.
344+
height (`int`, *optional*, defaults to 512):
345+
The height in pixels of the generated image.
346+
width (`int`, *optional*, defaults to 512):
347+
The width in pixels of the generated image.
348+
num_inference_steps (`int`, *optional*, defaults to 50):
349+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
350+
expense of slower inference.
351+
guidance_scale (`float`, *optional*, defaults to 7.5):
352+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
353+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
354+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
355+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
356+
usually at the expense of lower image quality.
357+
eta (`float`, *optional*, defaults to 0.0):
358+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
359+
[`schedulers.DDIMScheduler`], will be ignored for others.
360+
generator (`torch.Generator`, *optional*):
361+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
362+
deterministic.
363+
latents (`torch.FloatTensor`, *optional*):
364+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
365+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
366+
tensor will ge generated by sampling using the supplied random `generator`.
367+
output_type (`str`, *optional*, defaults to `"pil"`):
368+
The output format of the generate image. Choose between
369+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
370+
return_dict (`bool`, *optional*, defaults to `True`):
371+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
372+
plain tuple.
373+
Returns:
374+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
375+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
376+
When returning a tuple, the first element is a list with the generated images, and the second element is a
377+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
378+
(nsfw) content, according to the `safety_checker`.
379+
"""
380+
if height % 8 != 0 or width % 8 != 0:
381+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
382+
if self.text_embeddings is None:
383+
raise ValueError("Please run the pipe.train() before trying to generate an image.")
384+
if self.text_embeddings_orig is None:
385+
raise ValueError("Please run the pipe.train() before trying to generate an image.")
386+
387+
text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
388+
389+
390+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
391+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
392+
# corresponds to doing no classifier free guidance.
393+
do_classifier_free_guidance = guidance_scale > 1.0
394+
# get unconditional embeddings for classifier free guidance
395+
if do_classifier_free_guidance:
396+
uncond_tokens = [""]
397+
max_length = self.tokenizer.model_max_length
398+
uncond_input = self.tokenizer(
399+
uncond_tokens,
400+
padding="max_length",
401+
max_length=max_length,
402+
truncation=True,
403+
return_tensors="pt",
404+
)
405+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
406+
407+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
408+
seq_len = uncond_embeddings.shape[1]
409+
uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
410+
411+
# For classifier free guidance, we need to do two forward passes.
412+
# Here we concatenate the unconditional and text embeddings into a single batch
413+
# to avoid doing two forward passes
414+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
415+
416+
# get the initial random noise unless the user supplied it
417+
418+
# Unlike in other pipelines, latents need to be generated in the target device
419+
# for 1-to-1 results reproducibility with the CompVis implementation.
420+
# However this currently doesn't work in `mps`.
421+
latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
422+
latents_dtype = text_embeddings.dtype
423+
if self.device.type == "mps":
424+
# randn does not exist on mps
425+
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
426+
self.device
427+
)
428+
else:
429+
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
430+
431+
# set timesteps
432+
self.scheduler.set_timesteps(num_inference_steps)
433+
434+
# Some schedulers like PNDM have timesteps as arrays
435+
# It's more optimized to move all timesteps to correct device beforehand
436+
timesteps_tensor = self.scheduler.timesteps.to(self.device)
437+
438+
# scale the initial noise by the standard deviation required by the scheduler
439+
latents = latents * self.scheduler.init_noise_sigma
440+
441+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
442+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
443+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
444+
# and should be between [0, 1]
445+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
446+
extra_step_kwargs = {}
447+
if accepts_eta:
448+
extra_step_kwargs["eta"] = eta
449+
450+
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
451+
# expand the latents if we are doing classifier free guidance
452+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
453+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
454+
455+
# predict the noise residual
456+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
457+
458+
# perform guidance
459+
if do_classifier_free_guidance:
460+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
461+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
462+
463+
# compute the previous noisy sample x_t -> x_t-1
464+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
465+
466+
latents = 1 / 0.18215 * latents
467+
image = self.vae.decode(latents).sample
468+
469+
image = (image / 2 + 0.5).clamp(0, 1)
470+
471+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
472+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
473+
474+
if self.safety_checker is not None:
475+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
476+
self.device
477+
)
478+
image, has_nsfw_concept = self.safety_checker(
479+
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
480+
)
481+
else:
482+
has_nsfw_concept = None
342483

343-
# run safety checker
344-
safety_cheker_input = self.feature_extractor(image, return_tensors="pt").to(self.device)
345-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
484+
if output_type == "pil":
485+
image = self.numpy_to_pil(image)
346486

347487
if not return_dict:
348488
return (image, has_nsfw_concept)
349489

350-
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
490+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

0 commit comments

Comments
 (0)