
Commit 9d33a33

[WIP] scheduler
1 parent 1b1ee17 commit 9d33a33

8 files changed: +687 additions, -61 deletions

orig_scheduler.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
import torch
import torch.nn.functional as F  # needed for F.one_hot in index_to_log_onehot
import numpy as np


def log_1_min_a(a):
    return torch.log(1 - a.exp() + 1e-40)


def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)


def index_to_log_onehot(x, num_classes):
    assert x.max().item() < num_classes, \
        f'Error: {x.max().item()} >= {num_classes}'
    x_onehot = F.one_hot(x, num_classes)
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x


def log_onehot_to_index(log_x):
    return log_x.argmax(1)


def alpha_schedule(time_step, N=100, att_1=0.99999, att_T=0.000009, ctt_1=0.000009, ctt_T=0.99999):
    att = np.arange(0, time_step) / (time_step - 1) * (att_T - att_1) + att_1
    att = np.concatenate(([1], att))
    at = att[1:] / att[:-1]
    ctt = np.arange(0, time_step) / (time_step - 1) * (ctt_T - ctt_1) + ctt_1
    ctt = np.concatenate(([0], ctt))
    one_minus_ctt = 1 - ctt
    one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
    ct = 1 - one_minus_ct
    bt = (1 - at - ct) / N
    att = np.concatenate((att[1:], [1]))
    ctt = np.concatenate((ctt[1:], [0]))
    btt = (1 - att - ctt) / N
    return at, bt, ct, att, btt, ctt


class OrigScheduler:
    def __init__(self, *, num_classes, content_seq_len, num_timesteps=100):
        self.num_timesteps = num_timesteps
        self.num_classes = num_classes
        self.content_seq_len = content_seq_len

        at, bt, ct, att, btt, ctt = alpha_schedule(self.num_timesteps, N=self.num_classes - 1)

        at = torch.tensor(at.astype('float64'))
        bt = torch.tensor(bt.astype('float64'))
        ct = torch.tensor(ct.astype('float64'))
        log_at = torch.log(at)
        log_bt = torch.log(bt)
        log_ct = torch.log(ct)
        att = torch.tensor(att.astype('float64'))
        btt = torch.tensor(btt.astype('float64'))
        ctt = torch.tensor(ctt.astype('float64'))
        log_cumprod_at = torch.log(att)
        log_cumprod_bt = torch.log(btt)
        log_cumprod_ct = torch.log(ctt)

        log_1_min_ct = log_1_min_a(log_ct)
        log_1_min_cumprod_ct = log_1_min_a(log_cumprod_ct)

        assert log_add_exp(log_ct, log_1_min_ct).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_ct, log_1_min_cumprod_ct).abs().sum().item() < 1.e-5

        # Convert to float32 and store on the scheduler.
        self.log_at = log_at.float()
        self.log_bt = log_bt.float()
        self.log_ct = log_ct.float()
        self.log_cumprod_at = log_cumprod_at.float()
        self.log_cumprod_bt = log_cumprod_bt.float()
        self.log_cumprod_ct = log_cumprod_ct.float()
        self.log_1_min_ct = log_1_min_ct.float()
        self.log_1_min_cumprod_ct = log_1_min_cumprod_ct.float()

    def q_posterior(self, log_x_start, log_x_t, t):  # p_theta(xt_1|xt) = sum(q(xt-1|xt,x0')*p(x0'))
        # notice that log_x_t is onehot
        assert t.min().item() >= 0 and t.max().item() < self.num_timesteps
        batch_size = log_x_start.size()[0]
        onehot_x_t = log_onehot_to_index(log_x_t)
        mask = (onehot_x_t == self.num_classes - 1).unsqueeze(1)
        log_one_vector = torch.zeros(batch_size, 1, 1).type_as(log_x_t)
        log_zero_vector = torch.log(log_one_vector + 1.0e-30).expand(-1, -1, self.content_seq_len)

        log_qt = self.q_pred(log_x_t, t)  # q(xt|x0)
        # log_qt = torch.cat((log_qt[:,:-1,:], log_zero_vector), dim=1)
        log_qt = log_qt[:, :-1, :]
        log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape)  # ct~
        ct_cumprod_vector = log_cumprod_ct.expand(-1, self.num_classes - 1, -1)
        # ct_cumprod_vector = torch.cat((ct_cumprod_vector, log_one_vector), dim=1)
        log_qt = (~mask) * log_qt + mask * ct_cumprod_vector

        log_qt_one_timestep = self.q_pred_one_timestep(log_x_t, t)  # q(xt|xt_1)
        log_qt_one_timestep = torch.cat((log_qt_one_timestep[:, :-1, :], log_zero_vector), dim=1)
        log_ct = extract(self.log_ct, t, log_x_start.shape)  # ct
        ct_vector = log_ct.expand(-1, self.num_classes - 1, -1)
        ct_vector = torch.cat((ct_vector, log_one_vector), dim=1)
        log_qt_one_timestep = (~mask) * log_qt_one_timestep + mask * ct_vector

        # log_x_start = torch.cat((log_x_start, log_zero_vector), dim=1)
        # q = log_x_start - log_qt
        q = log_x_start[:, :-1, :] - log_qt
        q = torch.cat((q, log_zero_vector), dim=1)
        q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
        q = q - q_log_sum_exp
        log_EV_xtmin_given_xt_given_xstart = self.q_pred(q, t - 1) + log_qt_one_timestep + q_log_sum_exp
        return torch.clamp(log_EV_xtmin_given_xt_given_xstart, -70, 0)

    def q_pred_one_timestep(self, log_x_t, t):  # q(xt|xt_1)
        log_at = extract(self.log_at, t, log_x_t.shape)  # at
        log_bt = extract(self.log_bt, t, log_x_t.shape)  # bt
        log_ct = extract(self.log_ct, t, log_x_t.shape)  # ct
        log_1_min_ct = extract(self.log_1_min_ct, t, log_x_t.shape)  # 1-ct

        log_probs = torch.cat(
            [
                log_add_exp(log_x_t[:, :-1, :] + log_at, log_bt),
                log_add_exp(log_x_t[:, -1:, :] + log_1_min_ct, log_ct)
            ],
            dim=1
        )

        return log_probs

    def q_pred(self, log_x_start, t):  # q(xt|x0)
        # log_x_start can be onehot or not
        t = (t + (self.num_timesteps + 1)) % (self.num_timesteps + 1)
        log_cumprod_at = extract(self.log_cumprod_at, t, log_x_start.shape)  # at~
        log_cumprod_bt = extract(self.log_cumprod_bt, t, log_x_start.shape)  # bt~
        log_cumprod_ct = extract(self.log_cumprod_ct, t, log_x_start.shape)  # ct~
        log_1_min_cumprod_ct = extract(self.log_1_min_cumprod_ct, t, log_x_start.shape)  # 1-ct~

        log_probs = torch.cat(
            [
                log_add_exp(log_x_start[:, :-1, :] + log_cumprod_at, log_cumprod_bt),
                log_add_exp(log_x_start[:, -1:, :] + log_1_min_cumprod_ct, log_cumprod_ct)
            ],
            dim=1
        )

        return log_probs
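
A minimal smoke test of the reference file above, assuming it is importable as `orig_scheduler`; the module name, codebook size, and sequence length are illustrative, not taken from any VQ-Diffusion config. It checks that the shifted schedules returned by alpha_schedule are the cumulative products of the per-step coefficients, then runs q_pred on a dummy batch.

# Illustrative check of orig_scheduler.py; sizes below are made up for the example.
import numpy as np
import torch

from orig_scheduler import OrigScheduler, alpha_schedule, index_to_log_onehot

# The "tt" arrays are cumulative products of the per-step "t" coefficients.
at, bt, ct, att, btt, ctt = alpha_schedule(100, N=2886)
assert np.allclose(np.cumprod(at), att[:-1])
assert np.allclose(np.cumprod(1 - ct), 1 - ctt[:-1])

num_classes = 2887        # hypothetical: 2886 codebook entries + 1 mask class
content_seq_len = 1024    # hypothetical: 32 x 32 latent grid
scheduler = OrigScheduler(num_classes=num_classes, content_seq_len=content_seq_len)

x_start = torch.randint(0, num_classes, (4, content_seq_len))     # token indices
log_x_start = index_to_log_onehot(x_start, num_classes)           # (4, num_classes, 1024)
t = torch.randint(0, scheduler.num_timesteps, (4,))

log_q_xt = scheduler.q_pred(log_x_start, t)                       # log q(x_t | x_0)
print(log_q_xt.shape)                                             # torch.Size([4, 2887, 1024])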

scripts/convert_vq_diffusion_to_diffusers.py

Lines changed: 20 additions & 3 deletions
@@ -32,9 +32,8 @@
 
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQModel
+from diffusers import VQModel, VQDiffusionPipeline, VQDiffusionScheduler
 from diffusers.models.vq_diffusion_attention import VQDiffusionTransformer
-from diffusers.pipelines import VQDiffusionPipeline
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 

@@ -492,7 +491,12 @@ def transformer_model_from_original_config(
 
     depth = original_transformer_config["n_layer"]
     context_dim = original_transformer_config["condition_dim"]
+
     num_embed = original_content_embedding_config["num_embed"]
+    # the number of embeddings in the transformer includes the mask embedding.
+    # the content embedding (the vqvae) does not include the mask embedding.
+    num_embed = num_embed + 1
+
     height = original_transformer_config["content_spatial_size"][0]
     width = original_transformer_config["content_spatial_size"][1]
     dropout = original_transformer_config["resid_pdrop"]

@@ -846,10 +850,23 @@ def read_config_file(filename):
 
     # done text encoder
 
+    # scheduler
+
+    scheduler_model = VQDiffusionScheduler(
+        # the scheduler has the same number of embeddings as the transformer
+        num_embed=transformer_model.num_embed
+    )
+
+    # done scheduler
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")
 
     pipe = VQDiffusionPipeline(
-        vqvae=vqvae_model, transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model
+        vqvae=vqvae_model,
+        transformer=transformer_model,
+        tokenizer=tokenizer_model,
+        text_encoder=text_encoder_model,
+        scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
 
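
The converter now wires all five components, with the scheduler sized to the transformer's vocabulary (VQ codebook plus the mask embedding). A converted dump should then round-trip through the usual diffusers loading path; a minimal sketch, assuming the WIP pipeline exposes components under the constructor keyword names above, with "path/to/dump" standing in for args.dump_path:

# Hypothetical load of a converted checkpoint; "path/to/dump" is a placeholder.
from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("path/to/dump")

# Component attribute names assumed to match the converter's constructor kwargs.
print(type(pipe.vqvae).__name__, type(pipe.transformer).__name__, type(pipe.scheduler).__name__)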

src/diffusers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -29,14 +29,15 @@
         get_scheduler,
     )
     from .pipeline_utils import DiffusionPipeline
-    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+    from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline, VQDiffusionPipeline
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,
         SchedulerMixin,
         ScoreSdeVeScheduler,
+        VQDiffusionScheduler
     )
     from .training_utils import EMAModel
 else:
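
With these exports in place, both new classes are importable from the package root. A minimal sketch of what the change enables; the num_embed value is illustrative (VQ codebook size plus one mask class), and the keyword is the one used by the conversion script above:

from diffusers import VQDiffusionPipeline, VQDiffusionScheduler

# num_embed = VQ codebook size + 1 mask class; 2887 is an illustrative value.
scheduler = VQDiffusionScheduler(num_embed=2887)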

src/diffusers/models/embeddings.py

Lines changed: 0 additions & 49 deletions
@@ -115,52 +115,3 @@ def forward(self, x):
         x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
         out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
         return out
-
-
-# TODO(will) - document this. check if throwing errors internally is appropriate
-class DalleMaskImageEmbedding(nn.Module):
-    def __init__(
-        self,
-        num_embed,
-        height,
-        width,
-        embed_dim,
-    ):
-        super().__init__()
-
-        self.height = height
-        self.width = width
-        # TODO(will) add docs on why this is incremented by 1. (Has to do with mask?)
-        self.num_embed = num_embed + 1
-        self.embed_dim = embed_dim
-
-        self.emb = nn.Embedding(self.num_embed, embed_dim)
-        self.height_emb = nn.Embedding(self.height, embed_dim)
-        self.width_emb = nn.Embedding(self.width, embed_dim)
-
-    def forward(self, index):
-        assert index.dim() == 2  # B x L
-        try:
-            index[index < 0] = 0
-            emb = self.emb(index)
-        except:
-            raise RuntimeError(
-                "IndexError: index out of range in self, max index {}, num embed {}".format(
-                    index.max(), self.num_embed
-                )
-            )
-
-        # add col and row embedding
-        if emb.shape[1] > 0:
-            height_emb = self.height_emb(
-                torch.arange(self.height, device=index.device).view(1, self.height)
-            ).unsqueeze(
-                2
-            )  # 1 x H x D -> 1 x H x 1 x D
-            width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)).unsqueeze(
-                1
-            )  # 1 x W x D -> 1 x 1 x W x D
-            pos_emb = (height_emb + width_emb).view(1, self.height * self.width, -1)  # 1 x H x W x D -> 1 x L x D
-            emb = emb + pos_emb[:, : emb.shape[1], :]
-
-        return emb
