From 2d4d23e2ed3bf891f64c7643d912f4dc58c3daa5 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Mon, 6 Feb 2023 15:35:42 +0530 Subject: [PATCH 01/10] UnCLIP Text Interpolation Pipeline --- examples/community/README.md | 38 + .../community/unclip_text_interpolation.py | 670 ++++++++++++++++++ 2 files changed, 708 insertions(+) create mode 100644 examples/community/unclip_text_interpolation.py diff --git a/examples/community/README.md b/examples/community/README.md index 67dc30b7164b..a61d2e1f1ce0 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -27,6 +27,8 @@ If a community doesn't work as expected, please open an issue and ping the autho Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | | Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - |[Ray Wang](https://wrong.wang) | +| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | + @@ -951,3 +953,39 @@ print(pipeline.prior_scheduler) ![shiba-inu](https://user-images.githubusercontent.com/16448529/209185639-6e5ec794-ce9d-4883-aa29-bd6852a2abad.jpg) +### UnCLIP Text Interpolation Pipeline + +This Diffusion Pipeline takes two prompts and interpolates between the two input prompts using spherical interpolation ( slerp ). The input prompts are converted to text embeddings by the pipeline's text_encoder and the interpolation is done on the resulting text_embeddings over the number of steps specified. Defaults to 5 steps. + +```python +import torch +from diffusers import DiffusionPipeline + +device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + +pipe = DiffusionPipeline.from_pretrained( + "kakaobrain/karlo-v1-alpha", + torch_dtype=torch.float16, + custom_pipeline="unclip_text_interpolation" +) +pipe.to(device) + +start_prompt = "A photograph of a frog" +end_prompt = "A painting of a horse" +#For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths. +generator = torch.Generator(device=device).manual_seed(42) + +output = pipe(start_prompt, end_prompt, steps = 6, generator = generator, enable_sequential_cpu_offload=False) + +for i,image in enumerate(output.images): + img.save('result%s.jpg' % i) +``` + +The resulting images in order:- + +![result_0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_0.png) +![result_1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_1.png) +![result_2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_2.png) +![result_3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_3.png) +![result_4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_4.png) +![result_5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_5.png) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py new file mode 100644 index 000000000000..a3c7a5971fcc --- /dev/null +++ b/examples/community/unclip_text_interpolation.py @@ -0,0 +1,670 @@ +import inspect +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.nn import functional as F + +from diffusers import ( + DiffusionPipeline, + ImagePipelineOutput, + PriorTransformer, + UnCLIPScheduler, + UNet2DConditionModel, + UNet2DModel, +) +from diffusers.pipelines.unclip import UnCLIPTextProjModel +from diffusers.utils import is_accelerate_available, logging, randn_tensor +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + + +class UnCLIPTextInterpolationPipeline(DiffusionPipeline): + + """ + Pipeline for prompt-to-prompt interpolation on CLIP text embeddings and using the UnCLIP / Dall-E to decode them to images. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process. Just a modified DDPMScheduler. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + + """ + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__ + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list + models = [ + self.decoder, + self.text_proj, + self.text_encoder, + self.super_res_first, + self.super_res_last, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): + return self.device + for module in self.decoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__call__ + def _generate( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + @staticmethod + def slerp(val, low, high): + """ + Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic. + """ + low_norm = low / torch.norm(low) + high_norm = high / torch.norm(high) + omega = torch.acos((low_norm * high_norm)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high + return res + + @torch.no_grad() + def __call__( + self, + start_prompt: str, + end_prompt: str, + steps: int = 5, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + enable_sequential_cpu_offload=True, + gpu_id=0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + start_prompt (`str`): + The prompt to start the image generation interpolation from. + end_prompt (`str`): + The prompt to end the image generation interpolation at. + steps (`int`, *optional*, defaults to 5): + The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns + the same number of images as this value. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`): + If True, offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + gpu_id (`int`, *optional*, defaults to `0`): + The gpu_id to be passed to enable_sequential_cpu_offload. Only works when enable_sequential_cpu_offload is set to True. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + + if enable_sequential_cpu_offload: + self.enable_sequential_cpu_offload(gpu_id=gpu_id) + + device = self._execution_device + + # Turn the prompts into embeddings. + inputs = self.tokenizer( + [start_prompt, end_prompt], + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + inputs.to(device) + text_model_output = self.text_encoder(**inputs) + + attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1]) + attention_mask = attention_mask.unsqueeze(0).to(device) + + start_text_embeds = text_model_output.text_embeds[0] + end_text_embeds = text_model_output.text_embeds[1] + + start_last_hidden_state = text_model_output.last_hidden_state[0] + end_last_hidden_state = text_model_output.last_hidden_state[1] + + # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline + out = [] + for interp_val in np.linspace(0, 1, steps): + # Use the start and end prompts for 0 and 1 values as slerp results are subjectively worse than slerp results for the same. + if interp_val == 0: + text_embeds = start_text_embeds + last_hidden_state = start_last_hidden_state + elif interp_val == 1: + text_embeds = end_text_embeds + last_hidden_state = end_last_hidden_state + else: + text_embeds = UnCLIPTextInterpolationPipeline.slerp(interp_val, start_text_embeds, end_text_embeds) + last_hidden_state = UnCLIPTextInterpolationPipeline.slerp( + interp_val, start_last_hidden_state, end_last_hidden_state + ) + + text_model_output.text_embeds = text_embeds.unsqueeze(0).to(device) + text_model_output.last_hidden_state = last_hidden_state.unsqueeze(0).to(device) + + res = self._generate( + text_model_output=text_model_output, + text_attention_mask=attention_mask, + generator=generator, + prior_num_inference_steps=prior_num_inference_steps, + decoder_num_inference_steps=decoder_num_inference_steps, + super_res_num_inference_steps=super_res_num_inference_steps, + prior_guidance_scale=prior_guidance_scale, + decoder_guidance_scale=decoder_guidance_scale, + output_type=output_type, + return_dict=return_dict, + ) + + if isinstance(res, ImagePipelineOutput): + out.append(res.images[0]) + elif isinstance(res, tuple): + out.append(res[0][0]) + + return ImagePipelineOutput(images=out) if return_dict else (out,) From 985a53f8243c7b9d0766d624c3c7dbc16c0d8dd6 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Mon, 6 Feb 2023 15:41:13 +0530 Subject: [PATCH 02/10] Formatter fixes --- examples/community/unclip_text_interpolation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index a3c7a5971fcc..5be6983a5a2e 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -1,5 +1,5 @@ import inspect -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -19,6 +19,9 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class UnCLIPTextInterpolationPipeline(DiffusionPipeline): """ From d5b20e73a950291ad7a4f0b7604ecec0a1d5b945 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 8 Feb 2023 18:46:33 +0530 Subject: [PATCH 03/10] Changes based on feedback --- .../community/unclip_text_interpolation.py | 389 +++++++++++++++--- 1 file changed, 334 insertions(+), 55 deletions(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index 5be6983a5a2e..2d33c20340ba 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -22,6 +22,18 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def slerp(val, low, high): + """ + Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic. + """ + low_norm = low / torch.norm(low) + high_norm = high / torch.norm(high) + omega = torch.acos((low_norm * high_norm)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high + return res + + class UnCLIPTextInterpolationPipeline(DiffusionPipeline): """ @@ -531,18 +543,6 @@ def _generate( return ImagePipelineOutput(images=image) - @staticmethod - def slerp(val, low, high): - """ - Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic. - """ - low_norm = low / torch.norm(low) - high_norm = high / torch.norm(high) - omega = torch.acos((low_norm * high_norm)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high - return res - @torch.no_grad() def __call__( self, @@ -608,6 +608,299 @@ def __call__( Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. """ + def _generate( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output, + text_attention_mask, + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk( + 2 + ) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + ) + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + if enable_sequential_cpu_offload: self.enable_sequential_cpu_offload(gpu_id=gpu_id) @@ -624,50 +917,36 @@ def __call__( inputs.to(device) text_model_output = self.text_encoder(**inputs) - attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1]) - attention_mask = attention_mask.unsqueeze(0).to(device) - - start_text_embeds = text_model_output.text_embeds[0] - end_text_embeds = text_model_output.text_embeds[1] - - start_last_hidden_state = text_model_output.last_hidden_state[0] - end_last_hidden_state = text_model_output.last_hidden_state[1] + batch_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1]) + batch_attention_mask = torch.cat([batch_attention_mask.unsqueeze(0)] * steps).to(device) # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline - out = [] - for interp_val in np.linspace(0, 1, steps): - # Use the start and end prompts for 0 and 1 values as slerp results are subjectively worse than slerp results for the same. - if interp_val == 0: - text_embeds = start_text_embeds - last_hidden_state = start_last_hidden_state - elif interp_val == 1: - text_embeds = end_text_embeds - last_hidden_state = end_last_hidden_state - else: - text_embeds = UnCLIPTextInterpolationPipeline.slerp(interp_val, start_text_embeds, end_text_embeds) - last_hidden_state = UnCLIPTextInterpolationPipeline.slerp( - interp_val, start_last_hidden_state, end_last_hidden_state - ) - - text_model_output.text_embeds = text_embeds.unsqueeze(0).to(device) - text_model_output.last_hidden_state = last_hidden_state.unsqueeze(0).to(device) + batch_text_embeds = [] + batch_last_hidden_state = [] - res = self._generate( - text_model_output=text_model_output, - text_attention_mask=attention_mask, - generator=generator, - prior_num_inference_steps=prior_num_inference_steps, - decoder_num_inference_steps=decoder_num_inference_steps, - super_res_num_inference_steps=super_res_num_inference_steps, - prior_guidance_scale=prior_guidance_scale, - decoder_guidance_scale=decoder_guidance_scale, - output_type=output_type, - return_dict=return_dict, + for interp_val in torch.linspace(0, 1, steps): + text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1]) + last_hidden_state = slerp( + interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1] ) - - if isinstance(res, ImagePipelineOutput): - out.append(res.images[0]) - elif isinstance(res, tuple): - out.append(res[0][0]) - - return ImagePipelineOutput(images=out) if return_dict else (out,) + batch_text_embeds.append(text_embeds.unsqueeze(0)) + batch_last_hidden_state.append(last_hidden_state.unsqueeze(0)) + + batch_text_embeds = torch.cat(batch_text_embeds) + batch_last_hidden_state = torch.cat(batch_last_hidden_state) + + res = self._generate( + text_model_output=CLIPTextModelOutput( + text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state + ), + text_attention_mask=batch_attention_mask, + generator=generator, + prior_num_inference_steps=prior_num_inference_steps, + decoder_num_inference_steps=decoder_num_inference_steps, + super_res_num_inference_steps=super_res_num_inference_steps, + prior_guidance_scale=prior_guidance_scale, + decoder_guidance_scale=decoder_guidance_scale, + output_type=output_type, + return_dict=return_dict, + ) + return res From 373ac44762093647fbf9e64b82e282989c4a517e Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 8 Feb 2023 18:51:53 +0530 Subject: [PATCH 04/10] Formatting fix --- examples/community/unclip_text_interpolation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index 2d33c20340ba..d781e9ca94ec 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -1,7 +1,6 @@ import inspect from typing import List, Optional, Tuple, Union -import numpy as np import torch from torch.nn import functional as F From 270e121393ff3f9f01d6ade73d14718c36be8ec2 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 8 Feb 2023 18:54:13 +0530 Subject: [PATCH 05/10] Formatting fix --- examples/community/unclip_text_interpolation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index d781e9ca94ec..3cf4b052c873 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -14,6 +14,7 @@ ) from diffusers.pipelines.unclip import UnCLIPTextProjModel from diffusers.utils import is_accelerate_available, logging, randn_tensor + from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput From 8d571e666d159f8bffa89a3f1a127bc3155c0a2a Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 8 Feb 2023 19:00:46 +0530 Subject: [PATCH 06/10] isort formatting fix(?) --- examples/community/unclip_text_interpolation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index 3cf4b052c873..d781e9ca94ec 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -14,7 +14,6 @@ ) from diffusers.pipelines.unclip import UnCLIPTextProjModel from diffusers.utils import is_accelerate_available, logging, randn_tensor - from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput From 135e49a1f021e2d9bf3fdc58c84a230a77e45351 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 8 Feb 2023 19:17:22 +0530 Subject: [PATCH 07/10] Remove duplicate code --- .../community/unclip_text_interpolation.py | 295 +----------------- 1 file changed, 5 insertions(+), 290 deletions(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index d781e9ca94ec..05adb655a0eb 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -256,292 +256,6 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - @torch.no_grad() - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__call__ - def _generate( - self, - prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - prior_num_inference_steps: int = 25, - decoder_num_inference_steps: int = 25, - super_res_num_inference_steps: int = 7, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prior_latents: Optional[torch.FloatTensor] = None, - decoder_latents: Optional[torch.FloatTensor] = None, - super_res_latents: Optional[torch.FloatTensor] = None, - text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, - text_attention_mask: Optional[torch.Tensor] = None, - prior_guidance_scale: float = 4.0, - decoder_guidance_scale: float = 8.0, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. This can only be left undefined if - `text_model_output` and `text_attention_mask` is passed. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - prior_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the prior. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - decoder_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - super_res_num_inference_steps (`int`, *optional*, defaults to 7): - The number of denoising steps for super resolution. More denoising steps usually lead to a higher - quality image at the expense of slower inference. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): - Pre-generated noisy latents to be used as inputs for the prior. - decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - prior_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - decoder_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - text_model_output (`CLIPTextModelOutput`, *optional*): - Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs - can be passed for tasks like text embedding interpolations. Make sure to also pass - `text_attention_mask` in this case. `prompt` can the be left to `None`. - text_attention_mask (`torch.Tensor`, *optional*): - Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention - masks are necessary when passing `text_model_output`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - """ - if prompt is not None: - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - else: - batch_size = text_model_output[0].shape[0] - - device = self._execution_device - - batch_size = batch_size * num_images_per_prompt - - do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 - - prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask - ) - - # prior - - self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) - prior_timesteps_tensor = self.prior_scheduler.timesteps - - embedding_dim = self.prior.config.embedding_dim - - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - prompt_embeds.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) - - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents - - predicted_image_embedding = self.prior( - latent_model_input, - timestep=t, - proj_embedding=prompt_embeds, - encoder_hidden_states=text_encoder_hidden_states, - attention_mask=text_mask, - ).predicted_image_embedding - - if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - - if i + 1 == prior_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = prior_timesteps_tensor[i + 1] - - prior_latents = self.prior_scheduler.step( - predicted_image_embedding, - timestep=t, - sample=prior_latents, - generator=generator, - prev_timestep=prev_timestep, - ).prev_sample - - prior_latents = self.prior.post_process_latents(prior_latents) - - image_embeddings = prior_latents - - # done prior - - # decoder - - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( - image_embeddings=image_embeddings, - prompt_embeds=prompt_embeds, - text_encoder_hidden_states=text_encoder_hidden_states, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - - if device.type == "mps": - # HACK: MPS: There is a panic when padding bool tensors, - # so cast to int tensor for the pad and back to bool afterwards - text_mask = text_mask.type(torch.int) - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) - decoder_text_mask = decoder_text_mask.type(torch.bool) - else: - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) - - self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) - decoder_timesteps_tensor = self.decoder_scheduler.timesteps - - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size - - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) - - for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents - - noise_pred = self.decoder( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=text_encoder_hidden_states, - class_labels=additive_clip_time_embeddings, - attention_mask=decoder_text_mask, - ).sample - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) - noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) - noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) - noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) - - if i + 1 == decoder_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = decoder_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - decoder_latents = self.decoder_scheduler.step( - noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - decoder_latents = decoder_latents.clamp(-1, 1) - - image_small = decoder_latents - - # done decoder - - # super res - - self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) - super_res_timesteps_tensor = self.super_res_scheduler.timesteps - - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size - - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) - - if device.type == "mps": - # MPS does not support many interpolations - image_upscaled = F.interpolate(image_small, size=[height, width]) - else: - interpolate_antialias = {} - if "antialias" in inspect.signature(F.interpolate).parameters: - interpolate_antialias["antialias"] = True - - image_upscaled = F.interpolate( - image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias - ) - - for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): - # no classifier free guidance - - if i == super_res_timesteps_tensor.shape[0] - 1: - unet = self.super_res_last - else: - unet = self.super_res_first - - latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) - - noise_pred = unet( - sample=latent_model_input, - timestep=t, - ).sample - - if i + 1 == super_res_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = super_res_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - super_res_latents = self.super_res_scheduler.step( - noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - image = super_res_latents - # done super res - - # post processing - - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) - @torch.no_grad() def __call__( self, @@ -606,9 +320,10 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. """ - - def _generate( - self, + + @torch.no_grad() + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__call__ + def generate( prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, prior_num_inference_steps: int = 25, @@ -934,7 +649,7 @@ def _generate( batch_text_embeds = torch.cat(batch_text_embeds) batch_last_hidden_state = torch.cat(batch_last_hidden_state) - res = self._generate( + res = generate( text_model_output=CLIPTextModelOutput( text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state ), From ba6343059581694e90c5588dfc2001e57bba2821 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 8 Feb 2023 19:20:31 +0530 Subject: [PATCH 08/10] Formatting fix --- examples/community/unclip_text_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index 05adb655a0eb..d76455c9d4b4 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -320,7 +320,7 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. """ - + @torch.no_grad() # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__call__ def generate( From ea97db73e5a3a0bd85da738434897cf306a7bb78 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Sat, 11 Feb 2023 15:19:19 +0530 Subject: [PATCH 09/10] Refactor __call__ and change example in readme. --- examples/community/README.md | 16 +- .../community/unclip_text_interpolation.py | 519 +++++++----------- 2 files changed, 221 insertions(+), 314 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index a61d2e1f1ce0..905f7b887b46 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -970,8 +970,8 @@ pipe = DiffusionPipeline.from_pretrained( ) pipe.to(device) -start_prompt = "A photograph of a frog" -end_prompt = "A painting of a horse" +start_prompt = "A photograph of an adult lion" +end_prompt = "A photograph of a lion cub" #For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths. generator = torch.Generator(device=device).manual_seed(42) @@ -983,9 +983,9 @@ for i,image in enumerate(output.images): The resulting images in order:- -![result_0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_0.png) -![result_1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_1.png) -![result_2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_2.png) -![result_3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_3.png) -![result_4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_4.png) -![result_5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/frog_horse_5.png) +![result_0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_0.png) +![result_1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_1.png) +![result_2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_2.png) +![result_3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_3.png) +![result_4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_4.png) +![result_5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_5.png) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index d76455c9d4b4..97797851beb6 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -321,346 +321,253 @@ def __call__( Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. """ - @torch.no_grad() - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__call__ - def generate( - prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - prior_num_inference_steps: int = 25, - decoder_num_inference_steps: int = 25, - super_res_num_inference_steps: int = 7, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prior_latents: Optional[torch.FloatTensor] = None, - decoder_latents: Optional[torch.FloatTensor] = None, - super_res_latents: Optional[torch.FloatTensor] = None, - text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, - text_attention_mask: Optional[torch.Tensor] = None, - prior_guidance_scale: float = 4.0, - decoder_guidance_scale: float = 8.0, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. This can only be left undefined if - `text_model_output` and `text_attention_mask` is passed. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - prior_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the prior. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - decoder_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - super_res_num_inference_steps (`int`, *optional*, defaults to 7): - The number of denoising steps for super resolution. More denoising steps usually lead to a higher - quality image at the expense of slower inference. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): - Pre-generated noisy latents to be used as inputs for the prior. - decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - prior_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - decoder_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - text_model_output (`CLIPTextModelOutput`, *optional*): - Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs - can be passed for tasks like text embedding interpolations. Make sure to also pass - `text_attention_mask` in this case. `prompt` can the be left to `None`. - text_attention_mask (`torch.Tensor`, *optional*): - Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention - masks are necessary when passing `text_model_output`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - """ - if prompt is not None: - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - else: - batch_size = text_model_output[0].shape[0] + if not isinstance(start_prompt, str) or not isinstance(end_prompt, str): + raise ValueError( + f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and" + f" {type(end_prompt)} instead" + ) - device = self._execution_device + if enable_sequential_cpu_offload: + self.enable_sequential_cpu_offload(gpu_id=gpu_id) - batch_size = batch_size * num_images_per_prompt + device = self._execution_device - do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + # Turn the prompts into embeddings. + inputs = self.tokenizer( + [start_prompt, end_prompt], + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + inputs.to(device) + text_model_output = self.text_encoder(**inputs) - prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - text_model_output, - text_attention_mask, + text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1]) + text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device) + + # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline + batch_text_embeds = [] + batch_last_hidden_state = [] + + for interp_val in torch.linspace(0, 1, steps): + text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1]) + last_hidden_state = slerp( + interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1] ) + batch_text_embeds.append(text_embeds.unsqueeze(0)) + batch_last_hidden_state.append(last_hidden_state.unsqueeze(0)) - # prior + batch_text_embeds = torch.cat(batch_text_embeds) + batch_last_hidden_state = torch.cat(batch_last_hidden_state) - self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) - prior_timesteps_tensor = self.prior_scheduler.timesteps + text_model_output = CLIPTextModelOutput( + text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state + ) - embedding_dim = self.prior.config.embedding_dim + batch_size = text_model_output[0].shape[0] - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - prompt_embeds.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents - - predicted_image_embedding = self.prior( - latent_model_input, - timestep=t, - proj_embedding=prompt_embeds, - encoder_hidden_states=text_encoder_hidden_states, - attention_mask=text_mask, - ).predicted_image_embedding - - if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk( - 2 - ) - predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - - if i + 1 == prior_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = prior_timesteps_tensor[i + 1] - - prior_latents = self.prior_scheduler.step( - predicted_image_embedding, - timestep=t, - sample=prior_latents, - generator=generator, - prev_timestep=prev_timestep, - ).prev_sample - - prior_latents = self.prior.post_process_latents(prior_latents) - - image_embeddings = prior_latents - - # done prior - - # decoder - - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( - image_embeddings=image_embeddings, - prompt_embeds=prompt_embeds, - text_encoder_hidden_states=text_encoder_hidden_states, - do_classifier_free_guidance=do_classifier_free_guidance, - ) + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt=None, + device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=do_classifier_free_guidance, + text_model_output=text_model_output, + text_attention_mask=text_attention_mask, + ) - if device.type == "mps": - # HACK: MPS: There is a panic when padding bool tensors, - # so cast to int tensor for the pad and back to bool afterwards - text_mask = text_mask.type(torch.int) - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) - decoder_text_mask = decoder_text_mask.type(torch.bool) - else: - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) - - self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) - decoder_timesteps_tensor = self.decoder_scheduler.timesteps - - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size - - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) + # prior - for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents - ) + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps - noise_pred = self.decoder( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=text_encoder_hidden_states, - class_labels=additive_clip_time_embeddings, - attention_mask=decoder_text_mask, - ).sample - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) - noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) - noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) - noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) - - if i + 1 == decoder_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = decoder_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - decoder_latents = self.decoder_scheduler.step( - noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - decoder_latents = decoder_latents.clamp(-1, 1) - - image_small = decoder_latents - - # done decoder - - # super res - - self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) - super_res_timesteps_tensor = self.super_res_scheduler.timesteps - - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size - - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) + embedding_dim = self.prior.config.embedding_dim - if device.type == "mps": - # MPS does not support many interpolations - image_upscaled = F.interpolate(image_small, size=[height, width]) - else: - interpolate_antialias = {} - if "antialias" in inspect.signature(F.interpolate).parameters: - interpolate_antialias["antialias"] = True + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + None, + self.prior_scheduler, + ) - image_upscaled = F.interpolate( - image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond ) - for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): - # no classifier free guidance + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) - if i == super_res_timesteps_tensor.shape[0] - 1: - unet = self.super_res_last - else: - unet = self.super_res_first + image_embeddings = prior_latents - latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + # done prior - noise_pred = unet( - sample=latent_model_input, - timestep=t, - ).sample + # decoder - if i + 1 == super_res_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = super_res_timesteps_tensor[i + 1] + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) - # compute the previous noisy sample x_t -> x_t-1 - super_res_latents = self.super_res_scheduler.step( - noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + None, + self.decoder_scheduler, + ) - image = super_res_latents - # done super res + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] - # post processing + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() + decoder_latents = decoder_latents.clamp(-1, 1) - if output_type == "pil": - image = self.numpy_to_pil(image) + image_small = decoder_latents - if not return_dict: - return (image,) + # done decoder - return ImagePipelineOutput(images=image) + # super res - if enable_sequential_cpu_offload: - self.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps - device = self._execution_device + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size - # Turn the prompts into embeddings. - inputs = self.tokenizer( - [start_prompt, end_prompt], - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - return_tensors="pt", + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + None, + self.super_res_scheduler, ) - inputs.to(device) - text_model_output = self.text_encoder(**inputs) - batch_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1]) - batch_attention_mask = torch.cat([batch_attention_mask.unsqueeze(0)] * steps).to(device) - - # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline - batch_text_embeds = [] - batch_last_hidden_state = [] + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True - for interp_val in torch.linspace(0, 1, steps): - text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1]) - last_hidden_state = slerp( - interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1] + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias ) - batch_text_embeds.append(text_embeds.unsqueeze(0)) - batch_last_hidden_state.append(last_hidden_state.unsqueeze(0)) - batch_text_embeds = torch.cat(batch_text_embeds) - batch_last_hidden_state = torch.cat(batch_last_hidden_state) + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance - res = generate( - text_model_output=CLIPTextModelOutput( - text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state - ), - text_attention_mask=batch_attention_mask, - generator=generator, - prior_num_inference_steps=prior_num_inference_steps, - decoder_num_inference_steps=decoder_num_inference_steps, - super_res_num_inference_steps=super_res_num_inference_steps, - prior_guidance_scale=prior_guidance_scale, - decoder_guidance_scale=decoder_guidance_scale, - output_type=output_type, - return_dict=return_dict, - ) - return res + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) From 1242e4d97d705e3b2063b587ea7406d0edc6559e Mon Sep 17 00:00:00 2001 From: Naga Sai Abhinay Date: Mon, 13 Feb 2023 09:39:56 +0530 Subject: [PATCH 10/10] Update examples/community/unclip_text_interpolation.py Refactor to linter formatting Co-authored-by: Will Berman --- examples/community/unclip_text_interpolation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index 97797851beb6..ac6b73d974b6 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -3,6 +3,8 @@ import torch from torch.nn import functional as F +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput from diffusers import ( DiffusionPipeline, @@ -14,8 +16,6 @@ ) from diffusers.pipelines.unclip import UnCLIPTextProjModel from diffusers.utils import is_accelerate_available, logging, randn_tensor -from transformers import CLIPTextModelWithProjection, CLIPTokenizer -from transformers.models.clip.modeling_clip import CLIPTextModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name