Commit d14dff7 (1 parent: 1b1ee17)

[WIP] scheduler scaffolding

imports init scheduler hacking more work fixes and docs more cleaning

7 files changed: +520 -61 lines

scripts/convert_vq_diffusion_to_diffusers.py

Lines changed: 8 additions & 3 deletions
@@ -32,9 +32,8 @@

 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQModel
+from diffusers import VQModel, VQDiffusionPipeline, VQDiffusionScheduler
 from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
-from diffusers.pipelines import VQDiffusionPipeline
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader

@@ -846,10 +845,16 @@ def read_config_file(filename):

     # done text encoder

+    scheduler_model = VQDiffusionScheduler()
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")

     pipe = VQDiffusionPipeline(
-        vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
+        vqvae=vqvae_model,
+        transformer=transformer_model,
+        tokenizer=tokenizer_model,
+        text_encoder=text_encoder_model,
+        scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)

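For orientation, a rough sketch of producing and reloading the converted checkpoint; the dump path is illustrative and the script's other CLI flags are omitted:

    # hypothetical invocation; only --dump_path is taken from the diff above
    # python scripts/convert_vq_diffusion_to_diffusers.py ... --dump_path ./vq-diffusion-dump

    from diffusers import VQDiffusionPipeline

    # save_pretrained() writes each registered module (vqvae, transformer, tokenizer,
    # text_encoder and now scheduler) to its own subfolder, so the whole pipeline can
    # be restored in one call
    pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-dump")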

src/diffusers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -29,14 +29,15 @@
         get_scheduler,
     )
     from .pipeline_utils import DiffusionPipeline
-    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline, VQDiffusionPipeline
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,
         SchedulerMixin,
         ScoreSdeVeScheduler,
+        VQDiffusionScheduler
     )
     from .training_utils import EMAModel
 else:
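With these exports in place, both new classes are importable from the package root (assuming torch is installed, since the imports sit inside the torch-available branch):

    from diffusers import VQDiffusionPipeline, VQDiffusionScheduler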

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 49 deletions
@@ -115,52 +115,3 @@ def forward(self, x):
         x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
         out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
         return out
-
-
-# TODO(will) - document this. check if throwing errors internally is appropriate
-class DalleMaskImageEmbedding(nn.Module):
-    def __init__(
-        self,
-        num_embed,
-        height,
-        width,
-        embed_dim,
-    ):
-        super().__init__()
-
-        self.height = height
-        self.width = width
-        # TODO(will) add docs on why this is incremented by 1. (Has to do with mask?)
-        self.num_embed = num_embed + 1
-        self.embed_dim = embed_dim
-
-        self.emb = nn.Embedding(self.num_embed, embed_dim)
-        self.height_emb = nn.Embedding(self.height, embed_dim)
-        self.width_emb = nn.Embedding(self.width, embed_dim)
-
-    def forward(self, index):
-        assert index.dim() == 2  # B x L
-        try:
-            index[index < 0] = 0
-            emb = self.emb(index)
-        except:
-            raise RuntimeError(
-                "IndexError: index out of range in self, max index {}, num embed {}".format(
-                    index.max(), self.num_embed
-                )
-            )
-
-        # add col and row embedding
-        if emb.shape[1] > 0:
-            height_emb = self.height_emb(
-                torch.arange(self.height, device=index.device).view(1, self.height)
-            ).unsqueeze(
-                2
-            )  # 1 x H x D -> 1 x H x 1 x D
-            width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)).unsqueeze(
-                1
-            )  # 1 x W x D -> 1 x 1 x W x D
-            pos_emb = (height_emb + width_emb).view(1, self.height * self.width, -1)  # 1 x H x W x D -> 1 x L x D
-            emb = emb + pos_emb[:, : emb.shape[1], :]
-
-        return emb
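The class deleted here reappears in slightly simplified form in vq_diffusion_attention.py below. For reference, a standalone sketch of the row/column position-embedding scheme it implements (dimensions are made up for the example):

    import torch
    from torch import nn

    height, width, embed_dim = 3, 4, 8
    height_emb = nn.Embedding(height, embed_dim)
    width_emb = nn.Embedding(width, embed_dim)

    h = height_emb(torch.arange(height)).unsqueeze(1)  # H x 1 x D
    w = width_emb(torch.arange(width)).unsqueeze(0)    # 1 x W x D
    pos_emb = (h + w).view(height * width, embed_dim)  # broadcasts to H x W x D, flattened to L x D
    print(pos_emb.shape)                               # torch.Size([12, 8])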

src/diffusers/models/vq_diffusion_attention.py

Lines changed: 69 additions & 8 deletions
@@ -2,10 +2,10 @@

 import torch
 from torch import nn
+import torch.nn.functional as F

 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.modeling_utils import ModelMixin
-from diffusers.models.embeddings import DalleMaskImageEmbedding

 from .attention import CrossAttention

@@ -23,20 +23,27 @@ def __init__(
         width: int,
         diffusion_steps: int,
         dropout: float = 0.0,
+        min_logged_value: float = -70.0
     ):
         super().__init__()
+
         self.n_heads = n_heads
         self.d_head = d_head
-        inner_dim = n_heads * d_head
+        self.inner_dim = n_heads * d_head
+        self.min_logged_value = min_logged_value

+        # The input to the `DalleMaskImageEmbedding` layer is the
+        # embedding indices from the quantized codebook with an additional
+        # index for the masked value.
+        num_embed_with_mask = num_embed + 1
         self.latent_image_embedding = DalleMaskImageEmbedding(
-            num_embed=num_embed, embed_dim=inner_dim, height=height, width=width
+            num_embed=num_embed_with_mask, embed_dim=self.inner_dim, height=height, width=width
         )

         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
+                    self.inner_dim,
                     n_heads,
                     d_head,
                     dropout=dropout,
@@ -48,21 +55,75 @@ def __init__(
4855
]
4956
)
5057

51-
self.norm_out = nn.LayerNorm(inner_dim)
52-
self.out = nn.Linear(inner_dim, num_embed)
58+
self.norm_out = nn.LayerNorm(self.inner_dim)
59+
60+
# The output from the transformer is the embedding indices for the
61+
# quantized codebook. It does not include additional index for the
62+
# masked value because the transformer predicts the unnoised image
63+
# which has no masks
64+
self.out = nn.Linear(self.inner_dim, num_embed)
5365

5466
def forward(self, latent_images, cond_emb, t):
67+
bsz = latent_images.shape[0]
68+
5569
embedded_latent_images = self.latent_image_embedding(latent_images)
5670
hidden_states = embedded_latent_images
5771

5872
for block in self.transformer_blocks:
5973
hidden_states = block(hidden_states, cond_emb, t)
6074

6175
logits = self.out(self.norm_out(hidden_states))
62-
out = logits.permute(0, 2, 1)
6376

64-
return out
77+
# equivalent to `torch.zeros((bsz, self.inner_dim, 1)).log().clamp(self.min_logged_value)`
78+
log_zero_vector = torch.full((bsz, self.inner_dim, 1), self.min_logged_value, device=logits.device)
79+
80+
log_x_0 = F.log_softmax(logits.double(), dim=-1).float().clamp(self.min_logged_value)
81+
log_x_0 = torch.cat((log_x_0, log_zero_vector), dim=-1)
82+
83+
# TODO(will) can remove?
84+
log_x_0 = log_x_0.permute(0, 2, 1)
85+
86+
return log_x_0
87+
88+
89+
# TODO(will) - document this
90+
class DalleMaskImageEmbedding(nn.Module):
91+
def __init__(
92+
self,
93+
num_embed,
94+
height,
95+
width,
96+
embed_dim,
97+
):
98+
super().__init__()
99+
100+
self.height = height
101+
self.width = width
102+
self.num_embed = num_embed
103+
self.embed_dim = embed_dim
104+
105+
self.emb = nn.Embedding(self.num_embed, embed_dim)
106+
self.height_emb = nn.Embedding(self.height, embed_dim)
107+
self.width_emb = nn.Embedding(self.width, embed_dim)
108+
109+
def forward(self, index):
110+
emb = self.emb(index)
111+
112+
height_emb = self.height_emb(
113+
torch.arange(self.height, device=index.device).view(1, self.height)
114+
).unsqueeze(
115+
2
116+
) # 1 x H x D -> 1 x H x 1 x D
117+
118+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)).unsqueeze(
119+
1
120+
) # 1 x W x D -> 1 x 1 x W x D
121+
122+
pos_emb = (height_emb + width_emb).view(1, self.height * self.width, -1) # 1 x H x W x D -> 1 x L xD
123+
124+
emb = emb + pos_emb[:, : emb.shape[1], :]
65125

126+
return emb
66127

67128
class BasicTransformerBlock(nn.Module):
68129
def __init__(
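The reworked forward now returns per-token log-probabilities over the codebook classes plus one extra class for the masked token, rather than raw logits. A standalone sketch of that construction (shapes are chosen for illustration; the committed code sizes the log-zero column with self.inner_dim):

    import torch
    import torch.nn.functional as F

    bsz, seq_len, num_embed = 2, 4, 8
    min_logged_value = -70.0

    logits = torch.randn(bsz, seq_len, num_embed)             # scores per codebook entry, per token
    log_x_0 = F.log_softmax(logits.double(), dim=-1).float()  # log p(x_0 = k)
    log_x_0 = log_x_0.clamp(min_logged_value)                 # keep values finite, roughly log(0)

    # the predicted unnoised image never contains the mask token,
    # so its class gets a ~log(0) column appended
    log_zero = torch.full((bsz, seq_len, 1), min_logged_value)
    log_x_0 = torch.cat((log_x_0, log_zero), dim=-1)          # (bsz, seq_len, num_embed + 1)
    print(log_x_0.permute(0, 2, 1).shape)                     # (bsz, num_embed + 1, seq_len)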

src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py

Lines changed: 134 additions & 0 deletions
@@ -1,7 +1,32 @@
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
 from diffusers import VQDiffusionTransformer, VQModel
+from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
 from transformers import CLIPTextModel, CLIPTokenizer

 from ...pipeline_utils import DiffusionPipeline
+from ...utils import BaseOutput, logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class VQDiffusionPipelineOutput(BaseOutput):
+    """
+    Args:
+    Output class for VQ Diffusion pipelines.
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]


 # This class is a placeholder and does not have the full VQ-diffusion pipeline built out yet
@@ -14,18 +39,127 @@
 class VQDiffusionPipeline(DiffusionPipeline):
     vqvae: VQModel
     transformer: VQDiffusionTransformer
+    text_encoder: CLIPTextModel
+    tokenizer: CLIPTokenizer
+    scheduler: VQDiffusionScheduler

     def __init__(
         self,
         vqvae: VQModel,
         transformer: VQDiffusionTransformer,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
+        scheduler: VQDiffusionScheduler,
     ):
         super().__init__()
+
         self.register_modules(
             vqvae=vqvae,
             transformer=transformer,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
+            scheduler=scheduler,
         )
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        height: int = 256,
+        width: int = 256,
+        num_inference_steps: int = 100,
+        num_images_per_prompt: int = 1,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+    ):
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        # get the initial random noise unless the user supplied it
+
+        # TODO I believe the latents are the indices of the of the vectors
+
+        # TODO HERE - what's the input shape?
+        latents_shape = TODO  # (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
+        if latents is None:
+            # all masked?
+            latents = TODO  # torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(self.device)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+            latents = latents.to(self.device)
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # predict the un-noised image
+            log_x_start = TODO  # self.transformer(latents, t, encoder_hidden_states=text_embeddings).sample
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = TODO  # self.scheduler.step(x0_pred, t, latents).prev_sample
+
+            # call the callback, if provided
+            if callback is not None and i % callback_steps == 0:
+                callback(i, t, latents)
+
+        image = self.vqvae.decode(latents).sample
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        if not return_dict:
+            return image
+
+        return VQDiffusionPipelineOutput(images=image)
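Once the remaining TODOs are resolved, the call pattern implied by this signature is the standard diffusers one. A usage sketch (checkpoint path and prompt are illustrative):

    from diffusers import VQDiffusionPipeline

    pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-dump")  # hypothetical local checkpoint
    output = pipe("a painting of a corgi", num_inference_steps=100, num_images_per_prompt=2)
    output.images[0].save("corgi.png")  # with the default output_type="pil", images are PIL images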

src/diffusers/schedulers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
     from .scheduling_utils import SchedulerMixin
+    from .scheduling_vq_diffusion import VQDiffusionScheduler
 else:
     from ..utils.dummy_pt_objects import *  # noqa F403

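The seventh changed file, src/diffusers/schedulers/scheduling_vq_diffusion.py, is not shown in this view. A minimal scaffold consistent with how the pipeline calls it (set_timesteps plus a step method) could look roughly like the following; this is a sketch modeled on other diffusers schedulers, not the committed file:

    import torch

    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.schedulers.scheduling_utils import SchedulerMixin


    class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
        @register_to_config
        def __init__(self, num_train_timesteps: int = 100):
            # the reverse process walks from t = T - 1 down to 0
            self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)

        def set_timesteps(self, num_inference_steps: int):
            self.num_inference_steps = num_inference_steps
            self.timesteps = torch.arange(num_inference_steps - 1, -1, -1)

        def step(self, model_output, timestep, sample):
            # sampling x_{t-1} from the posterior given the predicted x_0 is the part
            # still to be built out
            raise NotImplementedError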
