44 changes: 42 additions & 2 deletions scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,8 +39,8 @@

import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.models.attention import Transformer2DModel
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

@@ -826,6 +826,20 @@ def read_config_file(filename):
transformer_model, checkpoint
)

# classifier free sampling embeddings interlude

# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
# model, so we pull them off the checkpoint before the checkpoint is deleted.

learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

if learnable_classifier_free_sampling_embeddings:
    learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
else:
    learned_classifier_free_sampling_embeddings_embeddings = None

# done classifier free sampling embeddings interlude

with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
    torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
    del diffusers_transformer_checkpoint
@@ -871,13 +885,39 @@ def read_config_file(filename):

# done scheduler

# learned classifier free sampling embeddings

with init_empty_weights():
    learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
        learnable_classifier_free_sampling_embeddings,
        hidden_size=text_encoder_model.config.hidden_size,
        length=tokenizer_model.model_max_length,
    )

learned_classifier_free_sampling_checkpoint = {
    "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
}

with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
    torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
    del learned_classifier_free_sampling_checkpoint
    del learned_classifier_free_sampling_embeddings_embeddings
    load_checkpoint_and_dispatch(
        learned_classifier_free_sampling_embeddings_model,
        learned_classifier_free_sampling_checkpoint_file.name,
        device_map="auto",
    )

# done learned classifier free sampling embeddings

print(f"saving VQ diffusion model, path: {args.dump_path}")

pipe = VQDiffusionPipeline(
    vqvae=vqvae_model,
    transformer=transformer_model,
    tokenizer=tokenizer_model,
    text_encoder=text_encoder_model,
    learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
    scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)
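For context, a minimal usage sketch of the converted pipeline (the output path is illustrative; assumes the conversion script above has been run, and `guidance_scale > 1` enables the new classifier free guidance path):

```python
import torch
from diffusers import VQDiffusionPipeline

# load the pipeline saved by the conversion script (path is hypothetical)
pipe = VQDiffusionPipeline.from_pretrained("path/to/converted/vq-diffusion")
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# guidance_scale > 1.0 turns on classifier free guidance using the learned embeddings
image = pipe("teddy bear playing in the pool", guidance_scale=5.0, num_inference_steps=100).images[0]
image.save("teddy_bear.png")
```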
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/vq_diffusion/__init__.py
@@ -1 +1,5 @@
from .pipeline_vq_diffusion import VQDiffusionPipeline
from ...utils import is_torch_available, is_transformers_available


if is_transformers_available() and is_torch_available():
    from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline
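With the conditional export above, `LearnedClassifierFreeSamplingEmbeddings` becomes importable alongside the pipeline. A small sketch of constructing it standalone (the class itself is added in `pipeline_vq_diffusion.py` below; the `hidden_size`/`length` values are illustrative, matching CLIP's text hidden size and max sequence length):

```python
from diffusers.pipelines.vq_diffusion import LearnedClassifierFreeSamplingEmbeddings

# learnable=True allocates a zero-initialised (length, hidden_size) embedding table as a trainable parameter
learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=512, length=77)
print(learned.embeddings.shape)  # torch.Size([77, 512])

# learnable=False stores no embeddings; the pipeline then encodes an empty prompt "" for the unconditional branch
not_learned = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
```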
142 changes: 112 additions & 30 deletions src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -20,13 +20,37 @@
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...utils import logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
    """
    Utility class for storing learned text embeddings for classifier free sampling
    """

    @register_to_config
    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
        super().__init__()

        self.learnable = learnable

        if self.learnable:
            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
            assert length is not None, "learnable=True requires `length` to be set"

            embeddings = torch.zeros(length, hidden_size)
        else:
            embeddings = None

        self.embeddings = torch.nn.Parameter(embeddings)


class VQDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using VQ Diffusion
@@ -55,6 +79,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
    text_encoder: CLIPTextModel
    tokenizer: CLIPTokenizer
    transformer: Transformer2DModel
    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
    scheduler: VQDiffusionScheduler

    def __init__(
@@ -64,6 +89,7 @@ def __init__(
        tokenizer: CLIPTokenizer,
        transformer: Transformer2DModel,
        scheduler: VQDiffusionScheduler,
        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
Contributor comment: That's definitely the right way to do it - it's quite specific to vq-diffusion IMO though, so will move it here :-)

    ):
        super().__init__()

@@ -73,13 +99,78 @@ def __init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
        )

    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
        # While CLIP does normalize the pooled output of the text transformer when combining
        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
        #
        # CLIP normalizing the pooled output.
        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            if self.learned_classifier_free_sampling_embeddings.learnable:
                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
            else:
                uncond_tokens = [""] * batch_size

                max_length = text_input_ids.shape[-1]
                uncond_input = self.tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
                # See comment for normalizing text embeddings
                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)

                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
                seq_len = uncond_embeddings.shape[1]
                uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
                uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
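As a side note, a tiny sketch of the batch layout `_encode_prompt` produces when guidance is enabled (toy tensors, not the real embeddings): the unconditional half comes first, which is the order the `chunk(2)` in `__call__` below relies on.

```python
import torch

batch, seq_len, hidden = 2, 77, 512
cond = torch.randn(batch, seq_len, hidden)    # prompt embeddings
uncond = torch.randn(batch, seq_len, hidden)  # unconditional / learned embeddings

text_embeddings = torch.cat([uncond, cond])         # shape: (2 * batch, seq_len, hidden)
uncond_back, cond_back = text_embeddings.chunk(2)   # same split order as the guidance step
assert torch.equal(uncond_back, uncond) and torch.equal(cond_back, cond)
```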

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 100,
        guidance_scale: float = 5.0,
        truncation_rate: float = 1.0,
        num_images_per_prompt: int = 1,
        generator: Optional[torch.Generator] = None,
@@ -98,6 +189,12 @@ def __call__(
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
                Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
                most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +234,10 @@ def __call__(

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
@@ -145,35 +246,6 @@ def __call__(
                f" {type(callback_steps)}."
            )

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# get the initial completely masked latents unless the user supplied it

latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +270,19 @@ def __call__(
sample = latents

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the sample if we are doing classifier free guidance
latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample

# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
model_output = self.transformer(
latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
).sample

if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

model_output = self.truncate(model_output, truncation_rate)
