diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py index ae105e30362e..85db67844acd 100644 --- a/scripts/convert_vq_diffusion_to_diffusers.py +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -39,8 +39,8 @@ import yaml from accelerate import init_empty_weights, load_checkpoint_and_dispatch -from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.models.attention import Transformer2DModel +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings from transformers import CLIPTextModel, CLIPTokenizer from yaml.loader import FullLoader @@ -826,6 +826,20 @@ def read_config_file(filename): transformer_model, checkpoint ) + # classifier free sampling embeddings interlude + + # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate + # model, so we pull them off the checkpoint before the checkpoint is deleted. + + learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf + + if learnable_classifier_free_sampling_embeddings: + learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] + else: + learned_classifier_free_sampling_embeddings_embeddings = None + + # done classifier free sampling embeddings interlude + with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) del diffusers_transformer_checkpoint @@ -871,6 +885,31 @@ def read_config_file(filename): # done scheduler + # learned classifier free sampling embeddings + + with init_empty_weights(): + learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings( + learnable_classifier_free_sampling_embeddings, + hidden_size=text_encoder_model.config.hidden_size, + length=tokenizer_model.model_max_length, + ) + + learned_classifier_free_sampling_checkpoint = { + "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float() + } + + with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file: + torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name) + del learned_classifier_free_sampling_checkpoint + del learned_classifier_free_sampling_embeddings_embeddings + load_checkpoint_and_dispatch( + learned_classifier_free_sampling_embeddings_model, + learned_classifier_free_sampling_checkpoint_file.name, + device_map="auto", + ) + + # done learned classifier free sampling embeddings + print(f"saving VQ diffusion model, path: {args.dump_path}") pipe = VQDiffusionPipeline( @@ -878,6 +917,7 @@ def read_config_file(filename): transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model, scheduler=scheduler_model, ) pipe.save_pretrained(args.dump_path) diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py index edf6f570f5bf..8c9f14f00064 100644 --- a/src/diffusers/pipelines/vq_diffusion/__init__.py +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -1 +1,5 @@ -from .pipeline_vq_diffusion import VQDiffusionPipeline +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 6e5325ba7ef5..333599d7ecf8 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -20,6 +20,8 @@ from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler from transformers import CLIPTextModel, CLIPTokenizer +from ...configuration_utils import ConfigMixin, register_to_config +from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...utils import logging @@ -27,6 +29,28 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + class VQDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using VQ Diffusion @@ -55,6 +79,7 @@ class VQDiffusionPipeline(DiffusionPipeline): text_encoder: CLIPTextModel tokenizer: CLIPTokenizer transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings scheduler: VQDiffusionScheduler def __init__( @@ -64,6 +89,7 @@ def __init__( tokenizer: CLIPTokenizer, transformer: Transformer2DModel, scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, ): super().__init__() @@ -73,13 +99,78 @@ def __init__( text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings + uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], num_inference_steps: int = 100, + guidance_scale: float = 5.0, truncation_rate: float = 1.0, num_images_per_prompt: int = 1, generator: Optional[torch.Generator] = None, @@ -98,6 +189,12 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above @@ -137,6 +234,10 @@ def __call__( batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -145,35 +246,6 @@ def __call__( f" {type(callback_steps)}." ) - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - - # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. - # While CLIP does normalize the pooled output of the text transformer when combining - # the image and text embeddings, CLIP does not directly normalize the last hidden state. - # - # CLIP normalizing the pooled output. - # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 - text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) - - # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - # get the initial completely masked latents unless the user supplied it latents_shape = (batch_size, self.transformer.num_latent_pixels) @@ -198,9 +270,19 @@ def __call__( sample = latents for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + # predict the un-noised image # model_output == `log_p_x_0` - model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample + model_output = self.transformer( + latent_model_input, encoder_hidden_states=text_embeddings, timestep=t + ).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) model_output = self.truncate(model_output, truncation_rate) diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index 5eb32d40d4f3..87e29cbc97de 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -20,7 +20,8 @@ import torch from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.utils import load_image, slow, torch_device +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings +from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -45,6 +46,10 @@ def num_embed(self): def num_embeds_ada_norm(self): return 12 + @property + def text_embedder_hidden_size(self): + return 32 + @property def dummy_vqvae(self): torch.manual_seed(0) @@ -71,7 +76,7 @@ def dummy_text_encoder(self): config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, + hidden_size=self.text_embedder_hidden_size, intermediate_size=37, layer_norm_eps=1e-05, num_attention_heads=4, @@ -111,9 +116,15 @@ def test_vq_diffusion(self): tokenizer = self.dummy_tokenizer transformer = self.dummy_transformer scheduler = VQDiffusionScheduler(self.num_embed) + learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False) pipe = VQDiffusionPipeline( - vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler + vqvae=vqvae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, ) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -139,6 +150,50 @@ def test_vq_diffusion(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_vq_diffusion_classifier_free_sampling(self): + device = "cpu" + + vqvae = self.dummy_vqvae + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + transformer = self.dummy_transformer + scheduler = VQDiffusionScheduler(self.num_embed) + learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings( + learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length + ) + + pipe = VQDiffusionPipeline( + vqvae=vqvae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prompt = "teddy bear playing in the pool" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2 + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 24, 24, 3) + + expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu @@ -149,12 +204,11 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def test_vq_diffusion(self): - expected_image = load_image( + def test_vq_diffusion_classifier_free_sampling(self): + expected_image = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/vq_diffusion/teddy_bear_pool.png" + "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy" ) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") pipeline = pipeline.to(torch_device) @@ -163,7 +217,6 @@ def test_vq_diffusion(self): generator = torch.Generator(device=torch_device).manual_seed(0) output = pipeline( "teddy bear playing in the pool", - truncation_rate=0.86, num_images_per_prompt=1, generator=generator, output_type="np",