@@ -85,10 +85,12 @@ class MODEL_ARCH(IntEnum):
     GPTNEOX   : int = auto()
     MPT       : int = auto()
     STARCODER : int = auto()
+    BERT      : int = auto()
 
 
 class MODEL_TENSOR(IntEnum):
     TOKEN_EMBD  : int = auto()
+    TOKEN_TYPES : int = auto()
     POS_EMBD    : int = auto()
     OUTPUT      : int = auto()
     OUTPUT_NORM : int = auto()
@@ -116,10 +118,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GPTNEOX:   "gptneox",
     MODEL_ARCH.MPT:       "mpt",
     MODEL_ARCH.STARCODER: "starcoder",
+    MODEL_ARCH.BERT:      "bert",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.TOKEN_EMBD:  "token_embd",
+    MODEL_TENSOR.TOKEN_TYPES: "token_types",
     MODEL_TENSOR.POS_EMBD:    "position_embd",
     MODEL_TENSOR.OUTPUT_NORM: "output_norm",
     MODEL_TENSOR.OUTPUT:      "output",
@@ -206,6 +210,43 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.MPT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPTJ: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
@@ -229,31 +270,40 @@ class TensorNameMap:
     mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
-            "gpt_neox.embed_in",           # gptneox
-            "transformer.wte",             # gpt2 mpt
-            "transformer.word_embeddings", # falcon
-            "model.embed_tokens",          # llama-hf
-            "tok_embeddings",              # llama-pth
+            "gpt_neox.embed_in",            # gptneox
+            "transformer.wte",              # gpt2 gpt-j mpt
+            "transformer.word_embeddings",  # falcon
+            "model.embed_tokens",           # llama-hf
+            "tok_embeddings",               # llama-pth
+            "embeddings.word_embeddings",   # bert
+        ),
+
+        # Token type embeddings
+        MODEL_TENSOR.TOKEN_TYPES: (
+            "embeddings.token_type_embeddings",  # bert
         ),
 
         # Position embeddings
         MODEL_TENSOR.POS_EMBD: (
-            "transformer.wpe", # gpt2
+            "transformer.wpe",                 # gpt2
+            "embeddings.position_embeddings",  # bert
         ),
 
         # Output
         MODEL_TENSOR.OUTPUT: (
-            "embed_out", # gptneox
-            "lm_head",   # gpt2 mpt falcon llama-hf baichuan
-            "output",    # llama-pth
+            "embed_out",  # gptneox
+            "lm_head",    # gpt2 gpt-j mpt falcon llama-hf baichuan
+            "output",     # llama-pth
         ),
 
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
-            "gpt_neox.final_layer_norm", # gptneox
-            "transformer.ln_f",          # gpt2 falcon
-            "model.norm",                # llama-hf baichuan
-            "norm",                      # llama-pth
+            "gpt_neox.final_layer_norm",  # gptneox
+            "transformer.ln_f",           # gpt2 gpt-j falcon
+            "model.norm",                 # llama-hf baichuan
+            "norm",                       # llama-pth
+            "embeddings.LayerNorm",       # bert
+            "transformer.norm_f",         # mpt
         ),
 
         # Rope frequencies
@@ -265,13 +315,14 @@ class TensorNameMap:
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
-            "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1",              # gpt2
-            "transformer.blocks.{bid}.norm_1",       # mpt
-            "transformer.h.{bid}.input_layernorm",   # falcon7b
-            "transformer.h.{bid}.ln_mlp",            # falcon40b
-            "model.layers.{bid}.input_layernorm",    # llama-hf
-            "layers.{bid}.attention_norm",           # llama-pth
+            "gpt_neox.layers.{bid}.input_layernorm",           # gptneox
+            "transformer.h.{bid}.ln_1",                        # gpt2 gpt-j
+            "transformer.blocks.{bid}.norm_1",                 # mpt
+            "transformer.h.{bid}.input_layernorm",             # falcon7b
+            "transformer.h.{bid}.ln_mlp",                      # falcon40b
+            "model.layers.{bid}.input_layernorm",              # llama-hf
+            "layers.{bid}.attention_norm",                     # llama-pth
+            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
         ),
 
         # Attention norm 2
@@ -281,38 +332,46 @@ class TensorNameMap:
 
         # Attention query-key-value
         MODEL_TENSOR.ATTN_QKV: (
-            "gpt_neox.layers.{bid}.attention.query_key_value",    # gptneox
-            "transformer.h.{bid}.attn.c_attn",                    # gpt2
-            "transformer.blocks.{bid}.attn.Wqkv",                 # mpt
-            "transformer.h.{bid}.self_attention.query_key_value", # falcon
+            "gpt_neox.layers.{bid}.attention.query_key_value",     # gptneox
+            "transformer.h.{bid}.attn.c_attn",                     # gpt2
+            "transformer.blocks.{bid}.attn.Wqkv",                  # mpt
+            "transformer.h.{bid}.self_attention.query_key_value",  # falcon
         ),
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj", # llama-hf
-            "layers.{bid}.attention.wq",           # llama-pth
+            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
+            "layers.{bid}.attention.wq",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",  # bert
+            "transformer.h.{bid}.attn.q_proj",           # gpt-j
         ),
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj", # llama-hf
-            "layers.{bid}.attention.wk",           # llama-pth
+            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
+            "layers.{bid}.attention.wk",               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",  # bert
+            "transformer.h.{bid}.attn.k_proj",         # gpt-j
         ),
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj", # llama-hf
-            "layers.{bid}.attention.wv",           # llama-pth
+            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
+            "layers.{bid}.attention.wv",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",  # bert
+            "transformer.h.{bid}.attn.v_proj",           # gpt-j
         ),
 
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
-            "gpt_neox.layers.{bid}.attention.dense",    # gptneox
-            "transformer.h.{bid}.attn.c_proj",          # gpt2
-            "transformer.blocks.{bid}.attn.out_proj",   # mpt
-            "transformer.h.{bid}.self_attention.dense", # falcon
-            "model.layers.{bid}.self_attn.o_proj",      # llama-hf
-            "layers.{bid}.attention.wo",                # llama-pth
+            "gpt_neox.layers.{bid}.attention.dense",       # gptneox
+            "transformer.h.{bid}.attn.c_proj",             # gpt2
+            "transformer.blocks.{bid}.attn.out_proj",      # mpt
+            "transformer.h.{bid}.self_attention.dense",    # falcon
+            "model.layers.{bid}.self_attn.o_proj",         # llama-hf
+            "layers.{bid}.attention.wo",                   # llama-pth
+            "encoder.layer.{bid}.attention.output.dense",  # bert
+            "transformer.h.{bid}.attn.out_proj",           # gpt-j
         ),
 
         # Rotary embeddings
@@ -323,21 +382,24 @@ class TensorNameMap:
 
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
-            "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-            "transformer.h.{bid}.ln_2",                       # gpt2
-            "transformer.blocks.{bid}.norm_2",                # mpt
-            "model.layers.{bid}.post_attention_layernorm",    # llama-hf
-            "layers.{bid}.ffn_norm",                          # llama-pth
+            "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
+            "transformer.h.{bid}.ln_2",                        # gpt2
+            "transformer.blocks.{bid}.norm_2",                 # mpt
+            "model.layers.{bid}.post_attention_layernorm",     # llama-hf
+            "layers.{bid}.ffn_norm",                           # llama-pth
+            "encoder.layer.{bid}.output.LayerNorm",            # bert
         ),
 
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
-            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
-            "transformer.h.{bid}.mlp.c_fc",            # gpt2
-            "transformer.blocks.{bid}.ffn.up_proj",    # mpt
-            "transformer.h.{bid}.mlp.dense_h_to_4h",   # falcon
-            "model.layers.{bid}.mlp.up_proj",          # llama-hf
-            "layers.{bid}.feed_forward.w3",            # llama-pth
+            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",  # gptneox
+            "transformer.h.{bid}.mlp.c_fc",             # gpt2
+            "transformer.blocks.{bid}.ffn.up_proj",     # mpt
+            "transformer.h.{bid}.mlp.dense_h_to_4h",    # falcon
+            "model.layers.{bid}.mlp.up_proj",           # llama-hf
+            "layers.{bid}.feed_forward.w3",             # llama-pth
+            "encoder.layer.{bid}.intermediate.dense",   # bert
+            "transformer.h.{bid}.mlp.fc_in",            # gpt-j
         ),
 
         # Feed-forward gate
@@ -348,12 +410,14 @@ class TensorNameMap:
 
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
-            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
-            "transformer.h.{bid}.mlp.c_proj",          # gpt2
-            "transformer.blocks.{bid}.ffn.down_proj",  # mpt
-            "transformer.h.{bid}.mlp.dense_4h_to_h",   # falcon
-            "model.layers.{bid}.mlp.down_proj",        # llama-hf
-            "layers.{bid}.feed_forward.w2",            # llama-pth
+            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
+            "transformer.h.{bid}.mlp.c_proj",           # gpt2
+            "transformer.blocks.{bid}.ffn.down_proj",   # mpt
+            "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
+            "model.layers.{bid}.mlp.down_proj",         # llama-hf
+            "layers.{bid}.feed_forward.w2",             # llama-pth
+            "encoder.layer.{bid}.output.dense",         # bert
+            "transformer.h.{bid}.mlp.fc_out",           # gpt-j
         ),
     }
 
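For reference, a minimal sketch of how the tables touched by this commit fit together when converting a BERT checkpoint. This is not the actual TensorNameMap interface; it assumes the definitions above are in scope and that the per-architecture tensor lists live in a dict named MODEL_TENSORS (that name is not visible in these hunks), so treat it as illustrative only.

# Illustrative sketch, not part of the commit. Assumes MODEL_ARCH, MODEL_TENSOR,
# TENSOR_NAMES, TensorNameMap and a MODEL_TENSORS dict (assumed name) are in scope.
from __future__ import annotations


def gguf_base_name(hf_name: str, arch: MODEL_ARCH = MODEL_ARCH.BERT) -> str | None:
    """Resolve a non-block checkpoint tensor name to its GGUF base name."""
    for tensor, candidates in TensorNameMap.mappings_cfg.items():
        # Only accept tensors that the target architecture actually declares.
        if hf_name in candidates and tensor in MODEL_TENSORS[arch]:
            return TENSOR_NAMES[tensor]
    return None


# The new BERT entries resolve like this:
assert gguf_base_name("embeddings.word_embeddings") == "token_embd"
assert gguf_base_name("embeddings.token_type_embeddings") == "token_types"
assert gguf_base_name("embeddings.position_embeddings") == "position_embd"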