

class LlamaDecoderLayer(LlamaDecoderLayer):
+
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
    ) -> None:
        super().__init__(vllm_config, prefix=prefix, config=config)

        config = config or vllm_config.model_config.hf_config
        quant_config = self.get_quant_config(vllm_config)

+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = (2 * self.hidden_size
+                          if layer_idx == 0 else self.hidden_size)
+
        # override qkv
        self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
@@ -51,6 +58,7 @@ def __init__(
        )

        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx

        if getattr(config, "norm_before_residual", False):
            self._residual_norm = self._norm_before_residual
@@ -89,11 +97,18 @@ def forward(
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)

-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(
+                hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: only process hidden_states
+            # (no embeds, no input_layernorm)
+            hidden_states, residual = self._residual_norm(
+                hidden_states=hidden_states)

-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
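To make the per-layer change above concrete, here is a minimal stand-alone sketch, not the vLLM classes: the ToyEagleLayer name is hypothetical, LayerNorm stands in for RMSNorm, and the attention/MLP calls are omitted. It only illustrates why layer 0 needs a 2 * hidden_size QKV input width (it concatenates the target-model embeddings with the incoming hidden states) while later layers work on hidden_size alone.

import torch
import torch.nn as nn


class ToyEagleLayer(nn.Module):
    """Hypothetical stand-in for the modified decoder layer (attention omitted,
    LayerNorm instead of RMSNorm)."""

    def __init__(self, hidden_size: int, layer_idx: int = 0) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        # Mirrors the diff: only layer 0 sees the concatenated input.
        qkv_input_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.qkv_proj = nn.Linear(qkv_input_size, 3 * hidden_size, bias=False)
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.hidden_norm = nn.LayerNorm(hidden_size)

    def forward(self, embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layer_idx == 0:
            # First layer: normalize both streams and concatenate along the
            # feature dimension -> width 2*hidden_size.
            x = torch.cat(
                [self.input_layernorm(embeds), self.hidden_norm(hidden_states)],
                dim=-1,
            )
        else:
            # Subsequent layers: embeds are ignored, width stays hidden_size.
            x = self.hidden_norm(hidden_states)
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        return q  # the real layer would run attention + MLP here


embeds = torch.randn(2, 5, 64)
hidden = torch.randn(2, 5, 64)
for idx in (0, 1):
    out = ToyEagleLayer(64, layer_idx=idx)(embeds, hidden)
    assert out.shape == (2, 5, 64)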
@@ -128,15 +143,15 @@ def __init__(
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

-        self.layers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(
-                    current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
-                    config=self.config,
-                )
-            ]
-        )
+        self.layers = nn.ModuleList([
+            LlamaDecoderLayer(current_vllm_config,
+                              prefix=maybe_prefix(
+                                  prefix,
+                                  f"layers.{layer_idx + start_layer_id}"),
+                              config=self.config,
+                              layer_idx=layer_idx)
+            for layer_idx in range(self.config.num_hidden_layers)
+        ])
        if hasattr(self.config, "target_hidden_size"):
            self.fc = torch.nn.Linear(
                self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
@@ -165,13 +180,13 @@ def forward(
        assert hidden_states.shape[-1] == input_embeds.shape[-1]

        residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
        return hidden_states, hidden_prenorm

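The model-level wiring can be sketched the same way. The TinyDraftLayer/TinyDraftModel names below are hypothetical and the layer is stripped to its interface (no attention, residuals, or weight-name prefixing); the point is simply one draft layer per config.num_hidden_layers, each built with its own layer_idx, and a forward loop that hands the input embeddings to every layer even though only layer 0 consumes them.

import torch
import torch.nn as nn


class TinyDraftLayer(nn.Module):
    """Reduced to the interface that matters here: layer 0 consumes
    [embeds, hidden_states] (2*hidden_size), later layers only hidden_states."""

    def __init__(self, hidden_size: int, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.proj = nn.Linear(in_size, hidden_size, bias=False)

    def forward(self, embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layer_idx == 0:
            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
        return self.proj(hidden_states)


class TinyDraftModel(nn.Module):
    def __init__(self, hidden_size: int, num_hidden_layers: int) -> None:
        super().__init__()
        # One layer per num_hidden_layers, each told its own layer_idx,
        # instead of a single hard-coded layer.
        self.layers = nn.ModuleList([
            TinyDraftLayer(hidden_size, layer_idx)
            for layer_idx in range(num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, input_embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Same control flow as the new forward(): every layer is called with
        # input_embeds, but only layer 0 actually uses them.
        for layer in self.layers:
            hidden_states = layer(input_embeds, hidden_states)
        return self.norm(hidden_states)


model = TinyDraftModel(hidden_size=64, num_hidden_layers=3)
out = model(torch.randn(2, 5, 64), torch.randn(2, 5, 64))
assert out.shape == (2, 5, 64)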