From 08984abc7886b64bcb3f489d20552b2d63e5fa84 Mon Sep 17 00:00:00 2001
From: William Berman
Date: Mon, 14 Nov 2022 22:49:02 -0800
Subject: [PATCH 1/3] vq diffusion classifier free sampling

---
 scripts/convert_vq_diffusion_to_diffusers.py  |  49 ++++++-
 src/diffusers/__init__.py                     |  10 +-
 src/diffusers/models/__init__.py              |   1 +
 src/diffusers/models/embeddings.py            |  26 ++++
 .../vq_diffusion/pipeline_vq_diffusion.py     | 120 +++++++++++++-----
 src/diffusers/utils/dummy_pt_objects.py       |  15 +++
 .../vq_diffusion/test_vq_diffusion.py         |  91 ++++++++++++-
 7 files changed, 275 insertions(+), 37 deletions(-)

diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index ae105e30362e..33198845202e 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,8 +39,13 @@
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
-from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.models.attention import Transformer2DModel
+from diffusers import (
+    LearnedClassifierFreeSamplingEmbeddings,
+    Transformer2DModel,
+    VQDiffusionPipeline,
+    VQDiffusionScheduler,
+    VQModel,
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 
@@ -826,6 +831,20 @@ def read_config_file(filename):
         transformer_model, checkpoint
     )
 
+    # classifier free sampling embeddings interlude
+
+    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
+    # model, so we pull them off the checkpoint before the checkpoint is deleted.
+
+    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
+
+    if learnable_classifier_free_sampling_embeddings:
+        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
+    else:
+        learned_classifier_free_sampling_embeddings_embeddings = None
+
+    # done classifier free sampling embeddings interlude
+
     with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
         torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
         del diffusers_transformer_checkpoint
@@ -871,6 +890,31 @@ def read_config_file(filename):
 
     # done scheduler
 
+    # learned classifier free sampling embeddings
+
+    with init_empty_weights():
+        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
+            learnable_classifier_free_sampling_embeddings,
+            hidden_size=text_encoder_model.config.hidden_size,
+            length=tokenizer_model.model_max_length,
+        )
+
+    learned_classifier_free_sampling_checkpoint = {
+        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
+    }
+
+    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
+        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
+        del learned_classifier_free_sampling_checkpoint
+        del learned_classifier_free_sampling_embeddings_embeddings
+        load_checkpoint_and_dispatch(
+            learned_classifier_free_sampling_embeddings_model,
+            learned_classifier_free_sampling_checkpoint_file.name,
+            device_map="auto",
+        )
+
+    # done learned classifier free sampling embeddings
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")
 
     pipe = VQDiffusionPipeline(
@@ -878,6 +922,7 @@ def read_config_file(filename):
         transformer=transformer_model,
         tokenizer=tokenizer_model,
         text_encoder=text_encoder_model,
+        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
        scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
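
For reference, a quick way to sanity-check the converted artifact. This is only a sketch: it assumes the script above was run with --dump_path ./vq-diffusion-ithq (a hypothetical local path), and the printed shape is just an example (tokenizer max length x text encoder hidden size):

    # Hypothetical post-conversion check; not part of the patch itself.
    from diffusers import VQDiffusionPipeline

    pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-ithq")
    cf = pipe.learned_classifier_free_sampling_embeddings
    print(cf.learnable)         # True when the original config set learnable_cf
    print(cf.embeddings.shape)  # e.g. torch.Size([77, 512])
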
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 42cb2cb585c8..62a0c917f205 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -18,7 +18,15 @@
 
 if is_torch_available():
     from .modeling_utils import ModelMixin
-    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
+    from .models import (
+        AutoencoderKL,
+        LearnedClassifierFreeSamplingEmbeddings,
+        Transformer2DModel,
+        UNet1DModel,
+        UNet2DConditionModel,
+        UNet2DModel,
+        VQModel,
+    )
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 5b101d169148..3dffd67fc351 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -17,6 +17,7 @@
 
 if is_torch_available():
     from .attention import Transformer2DModel
+    from .embeddings import LearnedClassifierFreeSamplingEmbeddings
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 0221d891f171..c5d5ec4cfd9e 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from typing import Optional
 
 import numpy as np
 import torch
 from torch import nn
 
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.modeling_utils import ModelMixin
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -198,3 +202,25 @@ def forward(self, index):
         emb = emb + pos_emb[:, : emb.shape[1], :]
 
         return emb
+
+
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    """
+    Utility class for storing learned text embeddings for classifier free sampling
+    """
+
+    @register_to_config
+    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+        super().__init__()
+
+        self.learnable = learnable
+
+        if self.learnable:
+            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+            assert length is not None, "learnable=True requires `length` to be set"
+
+            embeddings = torch.zeros(length, hidden_size)
+        else:
+            embeddings = None
+
+        self.embeddings = nn.Parameter(embeddings)
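
The class added above has two modes, both exercised by the new tests later in this patch. A minimal sketch; the hidden_size/length values are illustrative and would normally come from the text encoder config and tokenizer:

    from diffusers.models.embeddings import LearnedClassifierFreeSamplingEmbeddings  # import path as of this patch

    # learnable=True allocates a zero-initialized (length, hidden_size) parameter;
    # real weights (the checkpoint's `empty_text_embed`) are loaded over it later.
    learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=512, length=77)
    assert tuple(learned.embeddings.shape) == (77, 512)

    # learnable=False keeps an empty parameter (nn.Parameter(None)); the pipeline
    # then falls back to encoding the empty string for the unconditional branch.
    unlearned = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
    assert unlearned.embeddings.numel() == 0
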
diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
index 6e5325ba7ef5..74cab38eb91f 100644
--- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
+++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -16,7 +16,7 @@
 
 import torch
 
-from diffusers import Transformer2DModel, VQModel
+from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 
@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler
 
     def __init__(
@@ -64,6 +65,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()
 
@@ -73,13 +75,78 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )
 
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.learnable:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+            else:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+                # See comment for normalizing text embeddings
+                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
+
+                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                seq_len = uncond_embeddings.shape[1]
+                uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+                uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
@@ -98,6 +165,12 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 100):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
             truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                 Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is
                 at most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +210,10 @@ def __call__(
 
         batch_size = batch_size * num_images_per_prompt
 
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -145,35 +222,6 @@ def __call__(
                 f" {type(callback_steps)}."
             )
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it
         latents_shape = (batch_size, self.transformer.num_latent_pixels)
 
@@ -198,9 +246,19 @@ def __call__(
         sample = latents
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(
+                latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
+            ).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
 
             model_output = self.truncate(model_output, truncation_rate)
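
End to end, the new argument is exercised as in the slow test below. A usage sketch (CUDA device placement is assumed; guidance_scale=5.0 matches the new default):

    import torch
    from diffusers import VQDiffusionPipeline

    pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq").to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(0)

    image = pipe(
        "teddy bear playing in the pool",
        guidance_scale=5.0,   # classifier free sampling is active whenever this is > 1.0
        truncation_rate=0.86,
        generator=generator,
    ).images[0]
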
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index af2e0c7c61d6..82302d1a010e 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class LearnedClassifierFreeSamplingEmbeddings(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class Transformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py
index 5eb32d40d4f3..967be7f70cc3 100644
--- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py
+++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py
@@ -19,7 +19,13 @@
 import numpy as np
 import torch
 
-from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
+from diffusers import (
+    LearnedClassifierFreeSamplingEmbeddings,
+    Transformer2DModel,
+    VQDiffusionPipeline,
+    VQDiffusionScheduler,
+    VQModel,
+)
 from diffusers.utils import load_image, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -45,6 +51,10 @@ def num_embed(self):
     def num_embeds_ada_norm(self):
         return 12
 
+    @property
+    def text_embedder_hidden_size(self):
+        return 32
+
     @property
     def dummy_vqvae(self):
         torch.manual_seed(0)
@@ -71,7 +81,7 @@ def dummy_text_encoder(self):
         config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
-            hidden_size=32,
+            hidden_size=self.text_embedder_hidden_size,
             intermediate_size=37,
             layer_norm_eps=1e-05,
             num_attention_heads=4,
@@ -111,9 +121,15 @@ def test_vq_diffusion(self):
         tokenizer = self.dummy_tokenizer
         transformer = self.dummy_transformer
         scheduler = VQDiffusionScheduler(self.num_embed)
+        learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
 
         pipe = VQDiffusionPipeline(
-            vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
+            vqvae=vqvae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )
         pipe = pipe.to(device)
         pipe.set_progress_bar_config(disable=None)
@@ -139,6 +155,50 @@ def test_vq_diffusion(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_vq_diffusion_classifier_free_sampling(self):
+        device = "cpu"
+
+        vqvae = self.dummy_vqvae
+        text_encoder = self.dummy_text_encoder
+        tokenizer = self.dummy_tokenizer
+        transformer = self.dummy_transformer
+        scheduler = VQDiffusionScheduler(self.num_embed)
+        learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(
+            learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length
+        )
+
+        pipe = VQDiffusionPipeline(
+            vqvae=vqvae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
+        )
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "teddy bear playing in the pool"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = pipe(
+            [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
+        )[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 24, 24, 3)
+
+        expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
 
 @slow
 @require_torch_gpu
@@ -163,6 +223,7 @@ def test_vq_diffusion(self):
         generator = torch.Generator(device=torch_device).manual_seed(0)
         output = pipeline(
             "teddy bear playing in the pool",
+            guidance_scale=1.0,
             truncation_rate=0.86,
             num_images_per_prompt=1,
             generator=generator,
@@ -173,3 +234,27 @@ def test_vq_diffusion(self):
 
         assert image.shape == (256, 256, 3)
         assert np.abs(expected_image - image).max() < 1e-2
+
+    def test_vq_diffusion_classifier_free_sampling(self):
+        expected_image = load_image(
+            "https://huggingface.co/datasets/williamberman/misc/resolve/main"
+            "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
+        pipeline = pipeline.to(torch_device)
+        pipeline.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipeline(
+            "teddy bear playing in the pool",
+            num_images_per_prompt=1,
+            generator=generator,
+            output_type="np",
+        )
+
+        image = output.images[0]
+
+        assert image.shape == (256, 256, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
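
One detail in the sampling loop above is easy to miss: the transformer outputs log-probabilities over the codebook (`log_p_x_0`), so the usual guidance combination leaves an unnormalized distribution, and the `logsumexp` subtraction renormalizes it per latent pixel. A standalone sketch with made-up shapes:

    import torch

    guidance_scale = 5.0
    batch, num_classes, num_pixels = 1, 10, 4  # illustrative; dim=1 indexes codebook classes

    # stand-ins for the uncond/text halves of the chunked model output
    log_p_uncond = torch.log_softmax(torch.randn(batch, num_classes, num_pixels), dim=1)
    log_p_text = torch.log_softmax(torch.randn(batch, num_classes, num_pixels), dim=1)

    # the linear combination in log space is no longer normalized ...
    log_p = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
    # ... so subtract the log partition sum per pixel, as the patch does
    log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)

    assert torch.allclose(log_p.exp().sum(dim=1), torch.ones(batch, num_pixels))
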
From d1577c263f08a36bc35f6e08fbec1273c4b97358 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 16 Nov 2022 16:10:27 +0000
Subject: [PATCH 2/3] correct

---
 scripts/convert_vq_diffusion_to_diffusers.py |  9 +---
 src/diffusers/__init__.py                    | 10 +----
 src/diffusers/models/__init__.py             |  1 -
 src/diffusers/models/embeddings.py           | 26 -----------
 .../pipelines/vq_diffusion/__init__.py       |  6 ++-
 .../vq_diffusion/pipeline_vq_diffusion.py    | 26 ++++++++++-
 .../vq_diffusion/test_vq_diffusion.py        | 44 +++----------------
 7 files changed, 39 insertions(+), 83 deletions(-)

diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index 33198845202e..85db67844acd 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,13 +39,8 @@
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
-from diffusers import (
-    LearnedClassifierFreeSamplingEmbeddings,
-    Transformer2DModel,
-    VQDiffusionPipeline,
-    VQDiffusionScheduler,
-    VQModel,
-)
+from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
+from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 62a0c917f205..42cb2cb585c8 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -18,15 +18,7 @@
 
 if is_torch_available():
     from .modeling_utils import ModelMixin
-    from .models import (
-        AutoencoderKL,
-        LearnedClassifierFreeSamplingEmbeddings,
-        Transformer2DModel,
-        UNet1DModel,
-        UNet2DConditionModel,
-        UNet2DModel,
-        VQModel,
-    )
+    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 3dffd67fc351..5b101d169148 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -17,7 +17,6 @@
 
 if is_torch_available():
     from .attention import Transformer2DModel
-    from .embeddings import LearnedClassifierFreeSamplingEmbeddings
     from .unet_1d import UNet1DModel
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c5d5ec4cfd9e..0221d891f171 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -12,15 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
 
 import numpy as np
 import torch
 from torch import nn
 
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.modeling_utils import ModelMixin
-
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -202,25 +198,3 @@ def forward(self, index):
         emb = emb + pos_emb[:, : emb.shape[1], :]
 
         return emb
-
-
-class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
-    """
-    Utility class for storing learned text embeddings for classifier free sampling
-    """
-
-    @register_to_config
-    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
-        super().__init__()
-
-        self.learnable = learnable
-
-        if self.learnable:
-            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
-            assert length is not None, "learnable=True requires `length` to be set"
-
-            embeddings = torch.zeros(length, hidden_size)
-        else:
-            embeddings = None
-
-        self.embeddings = nn.Parameter(embeddings)
diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py
index edf6f570f5bf..8c9f14f00064 100644
--- a/src/diffusers/pipelines/vq_diffusion/__init__.py
+++ b/src/diffusers/pipelines/vq_diffusion/__init__.py
@@ -1 +1,5 @@
-from .pipeline_vq_diffusion import VQDiffusionPipeline
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+    from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline
diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
index 74cab38eb91f..333599d7ecf8 100644
--- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
+++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -16,10 +16,12 @@
 
 import torch
 
-from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
+from diffusers import Transformer2DModel, VQModel
 from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...modeling_utils import ModelMixin
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import logging
 
@@ -27,6 +29,28 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    """
+    Utility class for storing learned text embeddings for classifier free sampling
+    """
+
+    @register_to_config
+    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
+        super().__init__()
+
+        self.learnable = learnable
+
+        if self.learnable:
+            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
+            assert length is not None, "learnable=True requires `length` to be set"
+
+            embeddings = torch.zeros(length, hidden_size)
+        else:
+            embeddings = None
+
+        self.embeddings = torch.nn.Parameter(embeddings)
+
+
 class VQDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using VQ Diffusion
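
After this move the class is intentionally no longer part of the top-level diffusers API; the conversion script and tests in this patch import it from the pipeline module instead:

    # Import path after the refactor (mirrors the script/test changes in this patch).
    from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings

    embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
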
diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py
index 967be7f70cc3..87e29cbc97de 100644
--- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py
+++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py
@@ -19,14 +19,9 @@
 import numpy as np
 import torch
 
-from diffusers import (
-    LearnedClassifierFreeSamplingEmbeddings,
-    Transformer2DModel,
-    VQDiffusionPipeline,
-    VQDiffusionScheduler,
-    VQModel,
-)
-from diffusers.utils import load_image, slow, torch_device
+from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
+from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
+from diffusers.utils import load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -209,38 +204,11 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_vq_diffusion(self):
-        expected_image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/vq_diffusion/teddy_bear_pool.png"
-        )
-        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
-
-        pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
-        pipeline = pipeline.to(torch_device)
-        pipeline.set_progress_bar_config(disable=None)
-
-        generator = torch.Generator(device=torch_device).manual_seed(0)
-        output = pipeline(
-            "teddy bear playing in the pool",
-            guidance_scale=1.0,
-            truncation_rate=0.86,
-            num_images_per_prompt=1,
-            generator=generator,
-            output_type="np",
-        )
-
-        image = output.images[0]
-
-        assert image.shape == (256, 256, 3)
-        assert np.abs(expected_image - image).max() < 1e-2
-
     def test_vq_diffusion_classifier_free_sampling(self):
-        expected_image = load_image(
-            "https://huggingface.co/datasets/williamberman/misc/resolve/main"
-            "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.png"
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy"
         )
-        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
 
         pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
         pipeline = pipeline.to(torch_device)

From dac2ece596ba4d89805d172b2f2cf09c3efa7232 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 16 Nov 2022 16:25:05 +0000
Subject: [PATCH 3/3] uP

---
 src/diffusers/utils/dummy_pt_objects.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 82302d1a010e..af2e0c7c61d6 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -34,21 +34,6 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
-class LearnedClassifierFreeSamplingEmbeddings(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
 class Transformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]