1+ import argparse
2+
3+ import torch
4+ from torch import nn
5+
6+ from transformers import CLIPTextConfig , CLIPTextModel , GPT2Tokenizer
7+
8+ # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
9+ state_dict = torch .load ("base.pt" , map_location = "cpu" )
10+ state_dict = {k : nn .Parameter (v ) for k , v in state_dict .items ()}
11+ config = CLIPTextConfig (
12+ hidden_size = 512 ,
13+ intermediate_size = 2048 ,
14+ num_hidden_layers = 16 ,
15+ num_attention_heads = 8 ,
16+ max_position_embeddings = 128
17+ )
18+ model = CLIPTextModel (config ).eval ()
19+ tokenizer = GPT2Tokenizer ("./glide-base/vocab.json" , "./glide-base/merges.txt" , pad_token = "<|endoftext|>" )
20+ tokenizer .save_pretrained ("./glide-base" )
21+
22+ hf_encoder = model .text_model
23+
24+ hf_encoder .embeddings .token_embedding .weight = state_dict ["token_embedding.weight" ]
25+ hf_encoder .embeddings .position_embedding .weight .data = state_dict ["positional_embedding" ]
26+ hf_encoder .embeddings .padding_embedding .weight .data = state_dict ["padding_embedding" ]
27+
28+ hf_encoder .final_layer_norm .weight = state_dict ["final_ln.weight" ]
29+ hf_encoder .final_layer_norm .bias = state_dict ["final_ln.bias" ]
30+
31+ for layer_idx in range (config .num_hidden_layers ):
32+ hf_layer = hf_encoder .encoder .layers [layer_idx ]
33+ q_proj , k_proj , v_proj = state_dict [f"transformer.resblocks.{ layer_idx } .attn.c_qkv.weight" ].chunk (3 , dim = 0 )
34+ q_proj_bias , k_proj_bias , v_proj_bias = state_dict [f"transformer.resblocks.{ layer_idx } .attn.c_qkv.bias" ].chunk (3 , dim = 0 )
35+
36+ hf_layer .self_attn .q_proj .weight .data = q_proj
37+ hf_layer .self_attn .q_proj .bias .data = q_proj_bias
38+ hf_layer .self_attn .k_proj .weight .data = k_proj
39+ hf_layer .self_attn .k_proj .bias .data = k_proj_bias
40+ hf_layer .self_attn .v_proj .weight .data = v_proj
41+ hf_layer .self_attn .v_proj .bias .data = v_proj_bias
42+
43+ hf_layer .self_attn .out_proj .weight = state_dict [f"transformer.resblocks.{ layer_idx } .attn.c_proj.weight" ]
44+ hf_layer .self_attn .out_proj .bias = state_dict [f"transformer.resblocks.{ layer_idx } .attn.c_proj.bias" ]
45+
46+ hf_layer .layer_norm1 .weight = state_dict [f"transformer.resblocks.{ layer_idx } .ln_1.weight" ]
47+ hf_layer .layer_norm1 .bias = state_dict [f"transformer.resblocks.{ layer_idx } .ln_1.bias" ]
48+ hf_layer .layer_norm2 .weight = state_dict [f"transformer.resblocks.{ layer_idx } .ln_2.weight" ]
49+ hf_layer .layer_norm2 .bias = state_dict [f"transformer.resblocks.{ layer_idx } .ln_2.bias" ]
50+
51+ hf_layer .mlp .fc1 .weight = state_dict [f"transformer.resblocks.{ layer_idx } .mlp.c_fc.weight" ]
52+ hf_layer .mlp .fc1 .bias = state_dict [f"transformer.resblocks.{ layer_idx } .mlp.c_fc.bias" ]
53+ hf_layer .mlp .fc2 .weight = state_dict [f"transformer.resblocks.{ layer_idx } .mlp.c_proj.weight" ]
54+ hf_layer .mlp .fc2 .bias = state_dict [f"transformer.resblocks.{ layer_idx } .mlp.c_proj.bias" ]
55+
56+ inputs = tokenizer (["an oil painting of a corgi" , "" ], padding = "max_length" , max_length = 128 , return_tensors = "pt" )
57+ with torch .no_grad ():
58+ outputs = model (** inputs )
59+
60+ model .save_pretrained ("./glide-base" )
0 commit comments