 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 class InternLM2MLP(nn.Module):
@@ -41,16 +42,23 @@ def __init__(
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.w2 = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            quant_config=quant_config)
-        self.w2 = RowParallelLinear(intermediate_size,
-                                    hidden_size,
-                                    bias=False,
-                                    quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.w2",
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -75,6 +83,7 @@ def __init__(
         max_position_embeddings: int = 8192,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -108,12 +117,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.wqkv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.wo",
         )

         self.rotary_emb = get_rope(
@@ -123,12 +134,15 @@ def __init__(
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )

     def split_qkv(self, qkv: torch.Tensor):
         seq_len = qkv.shape[0]
@@ -176,6 +190,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -192,12 +207,14 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=f"{prefix}.attention",
         )
         self.feed_forward = InternLM2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
         self.attention_norm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
@@ -251,8 +268,8 @@ def __init__(
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: InternLMDecoderLayer(config, cache_config,
-                                                quant_config),
+            lambda prefix: InternLMDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
@@ -306,14 +323,19 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
    ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = InternLM2Model(config, cache_config, quant_config)
+        self.model = InternLM2Model(config,
+                                    cache_config,
+                                    quant_config,
+                                    prefix=maybe_prefix(prefix, "model"))
         self.output = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=maybe_prefix(prefix, "output"))
         if self.config.tie_word_embeddings:
             self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
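Note on the change: the diff threads a `prefix` argument top-down so each quantizable submodule knows its dotted module path. Below is a minimal sketch of how the composed prefixes are expected to look, assuming `maybe_prefix` from `.utils` joins non-empty prefixes with a dot and that `make_layers` passes `"<prefix>.<layer_index>"` to each layer factory; the standalone helper here is illustrative, not vLLM's actual code.

```python
def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behavior: return the bare name when there is no outer prefix,
    # otherwise join with a dot.
    return name if not prefix else f"{prefix}.{name}"


# Illustrative composition for decoder layer 0:
model_prefix = maybe_prefix("", "model")      # "model"
layer_prefix = f"{model_prefix}.layers.0"     # passed by make_layers (assumed)
attn_prefix = f"{layer_prefix}.attention"     # set in InternLMDecoderLayer
print(f"{attn_prefix}.wqkv")                  # model.layers.0.attention.wqkv
print(f"{layer_prefix}.feed_forward.w2")      # model.layers.0.feed_forward.w2
```

The resulting strings mirror the checkpoint parameter names, which is what lets quantization configs resolve per-layer settings for each linear layer.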