
Commit 6b2bc69

[WIP] scheduler
1 parent 1b1ee17 commit 6b2bc69

7 files changed: +532 -61 lines changed

scripts/convert_vq_diffusion_to_diffusers.py

Lines changed: 20 additions & 3 deletions

@@ -32,9 +32,8 @@
 
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQModel
+from diffusers import VQModel, VQDiffusionPipeline, VQDiffusionScheduler
 from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
-from diffusers.pipelines import VQDiffusionPipeline
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 
@@ -492,7 +491,12 @@ def transformer_model_from_original_config(
 
     depth = original_transformer_config["n_layer"]
     context_dim = original_transformer_config["condition_dim"]
+
     num_embed = original_content_embedding_config["num_embed"]
+
+    # the number of embeddings in the transformer includes the mask embedding.
+    # the content embedding (the vqvae) does not include the mask embedding.
+    num_embed = num_embed + 1
+
     height = original_transformer_config["content_spatial_size"][0]
     width = original_transformer_config["content_spatial_size"][1]
     dropout = original_transformer_config["resid_pdrop"]
@@ -846,10 +850,23 @@ def read_config_file(filename):
 
     # done text encoder
 
+    # scheduler
+
+    scheduler_model = VQDiffusionScheduler(
+        # the scheduler has the same number of embeddings as the transformer
+        num_embed=transformer_model.num_embed
+    )
+
+    # done scheduler
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")
 
     pipe = VQDiffusionPipeline(
-        vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
+        vqvae=vqvae_model,
+        transformer=transformer_model,
+        tokenizer=tokenizer_model,
+        text_encoder=text_encoder_model,
+        scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
 
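Note (not part of the commit): a minimal sketch of loading the checkpoint the script saves above. The path is a hypothetical stand-in for whatever `--dump_path` was passed, and since the pipeline is still WIP its eventual `__call__` signature may differ.

from diffusers import VQDiffusionPipeline

# load the converted checkpoint; assuming the WIP pipeline registers the new
# scheduler alongside the vqvae, transformer, tokenizer, and text encoder,
# it is restored from the same directory
pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-dump")  # hypothetical dump path
print(pipe.scheduler.config)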

src/diffusers/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -29,14 +29,15 @@
         get_scheduler,
     )
     from .pipeline_utils import DiffusionPipeline
-    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline, VQDiffusionPipeline
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,
         SchedulerMixin,
         ScoreSdeVeScheduler,
+        VQDiffusionScheduler
     )
     from .training_utils import EMAModel
 else:
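Note (not part of the commit): with the re-exports above, both new classes resolve from the package root. A small sketch assuming a checkout that includes this commit; `num_embed=4097` is only an illustrative value (codebook size plus one mask embedding, mirroring the convert script).

from diffusers import VQDiffusionPipeline, VQDiffusionScheduler

# hypothetical value: codebook size + 1 mask embedding, as in the convert script
scheduler = VQDiffusionScheduler(num_embed=4097)
print(type(scheduler).__name__, VQDiffusionPipeline.__name__)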

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 49 deletions

@@ -115,52 +115,3 @@ def forward(self, x):
         x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
         out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
         return out
-
-
-# TODO(will) - document this. check if throwing errors internally is appropriate
-class DalleMaskImageEmbedding(nn.Module):
-    def __init__(
-        self,
-        num_embed,
-        height,
-        width,
-        embed_dim,
-    ):
-        super().__init__()
-
-        self.height = height
-        self.width = width
-        # TODO(will) add docs on why this is incremented by 1. (Has to do with mask?)
-        self.num_embed = num_embed + 1
-        self.embed_dim = embed_dim
-
-        self.emb = nn.Embedding(self.num_embed, embed_dim)
-        self.height_emb = nn.Embedding(self.height, embed_dim)
-        self.width_emb = nn.Embedding(self.width, embed_dim)
-
-    def forward(self, index):
-        assert index.dim() == 2  # B x L
-        try:
-            index[index < 0] = 0
-            emb = self.emb(index)
-        except:
-            raise RuntimeError(
-                "IndexError: index out of range in self, max index {}, num embed {}".format(
-                    index.max(), self.num_embed
-                )
-            )
-
-        # add col and row embedding
-        if emb.shape[1] > 0:
-            height_emb = self.height_emb(
-                torch.arange(self.height, device=index.device).view(1, self.height)
-            ).unsqueeze(
-                2
-            )  # 1 x H x D -> 1 x H x 1 x D
-            width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)).unsqueeze(
-                1
-            )  # 1 x W x D -> 1 x 1 x W x D
-            pos_emb = (height_emb + width_emb).view(1, self.height * self.width, -1)  # 1 x H x W x D -> 1 x L x D
-            emb = emb + pos_emb[:, : emb.shape[1], :]
-
-        return emb

src/diffusers/models/vq_diffusion_attention.py

Lines changed: 74 additions & 8 deletions

@@ -2,10 +2,10 @@
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.modeling_utils import ModelMixin
-from diffusers.models.embeddings import DalleMaskImageEmbedding
 
 from .attention import CrossAttention
 
@@ -23,20 +23,27 @@ def __init__(
         width: int,
         diffusion_steps: int,
         dropout: float = 0.0,
+        min_logged_value: float = -70.0
     ):
         super().__init__()
+
         self.n_heads = n_heads
         self.d_head = d_head
-        inner_dim = n_heads * d_head
+        self.inner_dim = n_heads * d_head
+        self.min_logged_value = min_logged_value
+        self.num_embed = num_embed
+        self.height = height
+        self.width = width
+        self.num_latent_pixels = self.height * self.width
 
         self.latent_image_embedding = DalleMaskImageEmbedding(
-            num_embed=num_embed, embed_dim=inner_dim, height=height, width=width
+            num_embed=self.num_embed, embed_dim=self.inner_dim, height=height, width=width
        )
 
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
+                    self.inner_dim,
                     n_heads,
                     d_head,
                     dropout=dropout,
@@ -48,21 +55,80 @@ def __init__(
             ]
         )
 
-        self.norm_out = nn.LayerNorm(inner_dim)
-        self.out = nn.Linear(inner_dim, num_embed)
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+
+        # The output from the transformer is the embedding indices for the
+        # quantized codebook. the output dimension is `num_embed - 1` because
+        # it does not include additional index for the masked value since
+        # the transformer predicts the unnoised image which has no masks
+        self.out = nn.Linear(self.inner_dim, self.num_embed - 1)
 
     def forward(self, latent_images, cond_emb, t):
+        # bsz = latent_images.shape[0]
+
         embedded_latent_images = self.latent_image_embedding(latent_images)
         hidden_states = embedded_latent_images
 
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, cond_emb, t)
 
         logits = self.out(self.norm_out(hidden_states))
-        out = logits.permute(0, 2, 1)
+        # (batch, self.num_embed - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+
+        # TODO remove
+        torch.save(logits, f"/content/diffusers-out/transformer_logits_diffusers-{t[0]}.pt")
+
+        # TODO document why we append the zero vector
+        # equivalent to `torch.zeros((bsz, 1, self.num_latent_pixels)).log().clamp(self.min_logged_value)`
+        # log_zero_vector = torch.full((bsz, 1, self.num_latent_pixels), self.min_logged_value, device=logits.device)
+
+        log_p_x_0 = F.log_softmax(logits.double(), dim=1).float().clamp(self.min_logged_value)
+
+        # (batch, self.num_embed, self.inner_dim)
+        # log_p_x_0 = torch.cat((log_p_x_0, log_zero_vector), dim=1)
+
+        return log_p_x_0
+
+
+# TODO(will) - document this
+class DalleMaskImageEmbedding(nn.Module):
+    def __init__(
+        self,
+        num_embed,
+        height,
+        width,
+        embed_dim,
+    ):
+        super().__init__()
+
+        self.height = height
+        self.width = width
+        self.num_embed = num_embed
+        self.embed_dim = embed_dim
+
+        self.emb = nn.Embedding(self.num_embed, embed_dim)
+        self.height_emb = nn.Embedding(self.height, embed_dim)
+        self.width_emb = nn.Embedding(self.width, embed_dim)
+
+    def forward(self, index):
+        emb = self.emb(index)
+
+        height_emb = self.height_emb(
+            torch.arange(self.height, device=index.device).view(1, self.height)
+        ).unsqueeze(
+            2
+        )  # 1 x H x D -> 1 x H x 1 x D
+
+        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)).unsqueeze(
+            1
+        )  # 1 x W x D -> 1 x 1 x W x D
+
+        pos_emb = (height_emb + width_emb).view(1, self.height * self.width, -1)  # 1 x H x W x D -> 1 x L x D
 
-        return out
+        emb = emb + pos_emb[:, : emb.shape[1], :]
 
+        return emb
 
 class BasicTransformerBlock(nn.Module):
     def __init__(
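Note (not part of the commit): a small self-contained sketch of the log-space clamp that `forward` now applies. The commented-out "log zero vector" above is exactly log(0) = -inf clamped to `min_logged_value`; the tensor sizes below are made up for illustration.

import torch
import torch.nn.functional as F

min_logged_value = -70.0

# log(0) is -inf; the clamp replaces it with a large-but-finite negative value
log_zero = torch.zeros(1).log().clamp(min_logged_value)
print(log_zero)  # tensor([-70.])

# the same clamp bounds the per-class log-probabilities the transformer returns
logits = torch.randn(2, 5, 32)  # (batch, num_embed - 1, num latent pixels), made-up sizes
log_p_x_0 = F.log_softmax(logits.double(), dim=1).float().clamp(min_logged_value)
print(log_p_x_0.exp().sum(dim=1))  # each latent pixel's probabilities still sum to ~1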
