44"""
55import warnings
66from typing import List , Optional , Union
7+ import inspect
78
89import numpy as np
910import torch
1920from diffusers .schedulers import DDIMScheduler , LMSDiscreteScheduler , PNDMScheduler
2021from tqdm .auto import tqdm
2122from transformers import CLIPFeatureExtractor , CLIPTextModel , CLIPTokenizer
23+ from diffusers .utils import logging
24+
25+ logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
2226
2327
2428def freeze_params (params ):
@@ -115,18 +119,18 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

-    # @torch.no_grad()
-    def __call__(
+
+    def train(
         self,
         prompt: Union[str, List[str]],
         init_image: Union[torch.FloatTensor, PIL.Image.Image],
-        alpha: float = 1.2,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
         generator: Optional[torch.Generator] = None,
-        return_dict: bool = True,
+        embedding_learning_rate: float = 0.001,
+        diffusion_model_learning_rate: float = 2e-6,
+        text_embedding_optimization_steps: int = 500,
+        model_fine_tuning_optimization_steps: int = 1000,
         **kwargs,
     ):
         r"""
@@ -170,11 +174,6 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        embedding_learning_rate = 0.001
-        diffusion_model_learning_rate = 2e-6
-        text_embedding_optimization_steps = 500
-        model_fine_tuning_optimization_steps = 1000
-
         accelerator = Accelerator(
             gradient_accumulation_steps=1,
             mixed_precision="fp16",
@@ -197,9 +196,9 @@ def __call__(
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

         # Freeze vae and unet
-        freeze_params(self.vae.parameters())
-        freeze_params(self.unet.parameters())
-        freeze_params(self.text_encoder.parameters())
+        self.vae.requires_grad_(False)
+        self.unet.requires_grad_(False)
+        self.text_encoder.requires_grad_(False)
         self.unet.eval()
         self.vae.eval()
         self.text_encoder.eval()
@@ -243,23 +242,12 @@ def __call__(
         init_image_latents = init_latent_image_dist.sample(generator=generator)
         init_image_latents = 0.18215 * init_image_latents

-        pipeline = StableDiffusionPipeline(
-            text_encoder=self.text_encoder,
-            vae=self.vae,
-            unet=self.unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-            safety_checker=self.safety_checker,
-            feature_extractor=self.feature_extractor,
-        )
-        pipeline = pipeline.to("cuda")
-
         progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
         progress_bar.set_description("Steps")

         global_step = 0

-        print("First optimizing the text embedding to better reconstruct the init image")
+        logger.info("First optimizing the text embedding to better reconstruct the init image")
         for _ in range(text_embedding_optimization_steps):
             with accelerator.accumulate(text_embeddings):
                 # Sample noise that we'll add to the latents
@@ -291,18 +279,17 @@ def __call__(
         accelerator.wait_for_everyone()

         text_embeddings.requires_grad_(False)
-        freeze_params(text_embeddings)

         # Now we fine tune the unet to better reconstruct the image
-        unfreeze_params(self.unet.parameters())
+        self.unet.requires_grad_(True)
         self.unet.train()
         optimizer = torch.optim.Adam(
             self.unet.parameters(),  # only optimize unet
             lr=diffusion_model_learning_rate,
         )
         progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)

-        print("Next fine tuning the entire model to better reconstruct the init image")
+        logger.info("Next fine tuning the entire model to better reconstruct the init image")
         for _ in range(model_fine_tuning_optimization_steps):
             with accelerator.accumulate(self.unet.parameters()):
                 # Sample noise that we'll add to the latents
@@ -332,19 +319,172 @@ def __call__(
             accelerator.log(logs, step=global_step)

         accelerator.wait_for_everyone()
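+        # Keep both the original target text embedding and the optimized embedding around,
+        # so that `__call__` can interpolate between them via its `alpha` argument.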
+        self.text_embeddings_orig = text_embeddings_orig
+        self.text_embeddings = text_embeddings

-        new_text_embeddings = alpha * text_embeddings_orig + (1 - alpha) * text_embeddings
-        image = pipeline(
-            prompt, text_embeddings=new_text_embeddings, scale=7.5, num_inference_steps=num_inference_steps
-        ).images[
-            0
-        ]  # , latents=noise_latents).images[0]
+    @torch.no_grad()
+    def __call__(
+        self,
+        alpha: float = 1.2,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+        Args:
+            alpha (`float`, *optional*, defaults to 1.2):
+                Interpolation weight between the optimized embedding and the original target text embedding.
+                `alpha = 0` reproduces the embedding optimized during training, `alpha = 1` the target text
+                embedding; values above 1 extrapolate past the target text embedding.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely
+                linked to the target text, usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+            element is a list of `bool`s denoting whether the corresponding generated image likely represents
+            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+        """
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if self.text_embeddings is None:
+            raise ValueError("Please run `pipe.train()` before trying to generate an image.")
+        if self.text_embeddings_orig is None:
+            raise ValueError("Please run `pipe.train()` before trying to generate an image.")
+
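+        # Imagic-style interpolation: alpha = 0 uses only the embedding optimized to
+        # reconstruct the init image, alpha = 1 only the original target text embedding;
+        # the default of 1.2 extrapolates slightly past the text embedding for a stronger edit.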
+        text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            uncond_tokens = [""]
+            max_length = self.tokenizer.model_max_length
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        # get the initial random noise
+
+        # Unlike in other pipelines, latents need to be generated on the target device
+        # for 1-to-1 reproducibility with the CompVis implementation.
+        # However, this currently doesn't work on `mps`.
+        latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
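+        # `in_channels` is the UNet's latent channel count (4 for Stable Diffusion); the
+        # VAE downsamples images by a factor of 8, hence `height // 8` and `width // 8`.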
+        latents_dtype = text_embeddings.dtype
+        if self.device.type == "mps":
+            # randn does not exist on mps
+            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                self.device
+            )
+        else:
+            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to the correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
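+        # undo the 0.18215 latent scaling that was applied at encoding time, before decoding with the VAE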
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+            )
+        else:
+            has_nsfw_concept = None

-        # run safety checker
-        safety_cheker_input = self.feature_extractor(image, return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)

         if not return_dict:
             return (image, has_nsfw_concept)

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)