Skip to content

Commit 9c4cd06

Browse files
authored
Merge pull request #4 from huggingface/add-glide
Convert glide weights
2 parents f39020b + d04051e commit 9c4cd06

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
"""Convert the GLIDE base text-encoder weights into a Hugging Face CLIPTextModel.

The script reads the original ``base.pt`` checkpoint from the current
directory, copies the text-transformer weights into a freshly built
``CLIPTextModel``, runs one forward pass as a smoke test, and writes both the
tokenizer and the model to ``./glide-base``.
"""
import argparse

import torch
from torch import nn

from transformers import CLIPTextConfig, CLIPTextModel, GPT2Tokenizer

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
# Wrap every tensor so it can be assigned directly as a module parameter below.
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}

config = CLIPTextConfig(
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    max_position_embeddings=128,
)
model = CLIPTextModel(config).eval()

# The tokenizer files are expected to already exist in ./glide-base;
# save_pretrained then writes the remaining tokenizer config next to them.
tokenizer = GPT2Tokenizer(
    "./glide-base/vocab.json",
    "./glide-base/merges.txt",
    pad_token="<|endoftext|>",
)
tokenizer.save_pretrained("./glide-base")

hf_encoder = model.text_model

# --- Embeddings -------------------------------------------------------------
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
# NOTE(review): `padding_embedding` is not defined anywhere in this file;
# presumably it is provided by the transformers build this script targets --
# confirm before running against a stock release.
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]

# --- Final layer norm -------------------------------------------------------
hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]

# --- Transformer blocks -----------------------------------------------------
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    block = f"transformer.resblocks.{layer_idx}"

    # The checkpoint stores one fused QKV projection; split it into three
    # equal parts along the output dimension.
    # NOTE(review): this assumes the fused rows are laid out as
    # [all Q | all K | all V]. If the original attention interleaves q/k/v
    # per head, a per-head reshape is required instead -- confirm against
    # the GLIDE attention implementation.
    q_proj, k_proj, v_proj = state_dict[f"{block}.attn.c_qkv.weight"].chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = state_dict[f"{block}.attn.c_qkv.bias"].chunk(3, dim=0)

    # chunk() yields plain tensors rather than nn.Parameter objects, so they
    # are written through `.data` instead of replacing the parameter itself.
    hf_layer.self_attn.q_proj.weight.data = q_proj
    hf_layer.self_attn.q_proj.bias.data = q_proj_bias
    hf_layer.self_attn.k_proj.weight.data = k_proj
    hf_layer.self_attn.k_proj.bias.data = k_proj_bias
    hf_layer.self_attn.v_proj.weight.data = v_proj
    hf_layer.self_attn.v_proj.bias.data = v_proj_bias

    hf_layer.self_attn.out_proj.weight = state_dict[f"{block}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"{block}.attn.c_proj.bias"]

    hf_layer.layer_norm1.weight = state_dict[f"{block}.ln_1.weight"]
    hf_layer.layer_norm1.bias = state_dict[f"{block}.ln_1.bias"]
    hf_layer.layer_norm2.weight = state_dict[f"{block}.ln_2.weight"]
    hf_layer.layer_norm2.bias = state_dict[f"{block}.ln_2.bias"]

    hf_layer.mlp.fc1.weight = state_dict[f"{block}.mlp.c_fc.weight"]
    hf_layer.mlp.fc1.bias = state_dict[f"{block}.mlp.c_fc.bias"]
    hf_layer.mlp.fc2.weight = state_dict[f"{block}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"{block}.mlp.c_proj.bias"]

# Smoke test: run one forward pass on a prompt and an empty string. The result
# is not inspected; the call only verifies that the converted model executes.
inputs = tokenizer(
    ["an oil painting of a corgi", ""],
    padding="max_length",
    max_length=128,
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model(**inputs)

model.save_pretrained("./glide-base")

models/vision/glide/run_glide.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
generator = generator.manual_seed(0)
77

88
# 1. Load models
9+
910
scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
1011
model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
1112

0 commit comments

Comments
 (0)