@@ -34,15 +34,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)

         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)

+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@ def __init__(
         )

         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx

         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
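The sizing rule introduced above can be checked in isolation. The sketch below is illustrative only: it assumes a plain `nn.Linear` stand-in (`DraftAttentionStub`, a hypothetical name) rather than vLLM's `QKVParallelLinear`, and shows only that layer 0 projects the concatenated `[embeds; hidden_states]` width of `2 * hidden_size` while every later layer projects `hidden_size`.

```python
import torch.nn as nn


class DraftAttentionStub(nn.Module):
    """Illustrative stand-in: only the QKV input width depends on layer_idx."""

    def __init__(self, hidden_size: int, layer_idx: int = 0) -> None:
        super().__init__()
        # Layer 0 sees [embeds ; hidden_states] concatenated -> 2 * hidden_size;
        # every later layer sees hidden_states alone -> hidden_size.
        qkv_input_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.qkv_proj = nn.Linear(qkv_input_size, 3 * hidden_size, bias=False)
        self.layer_idx = layer_idx


# Shape check only.
assert DraftAttentionStub(64, layer_idx=0).qkv_proj.in_features == 128
assert DraftAttentionStub(64, layer_idx=1).qkv_proj.in_features == 64
```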
@@ -90,11 +96,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -133,9 +143,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]

         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm

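Taken together, the diff makes only layer 0 consume the target model's embeddings, while later layers run on hidden states alone, and the model now loops over all `num_hidden_layers` draft layers instead of a single one. Below is a minimal sketch of that data flow, using hypothetical `ToyDraftLayer`/`ToyDraftModel` modules with plain `nn.Linear`/`nn.LayerNorm` in place of the attention, RMSNorm, and residual handling in the real code:

```python
import torch
import torch.nn as nn


class ToyDraftLayer(nn.Module):
    """Hypothetical simplification: only layer 0 fuses embeds with hidden states."""

    def __init__(self, hidden_size: int, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        in_features = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.proj = nn.Linear(in_features, hidden_size, bias=False)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layer_idx == 0:
            # First layer: concatenate normalized embeds with hidden states.
            x = torch.cat([self.norm(embeds), self.norm(hidden_states)], dim=-1)
        else:
            # Later layers: hidden states only, embeds are ignored.
            x = self.norm(hidden_states)
        return self.proj(x)


class ToyDraftModel(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            ToyDraftLayer(hidden_size, layer_idx) for layer_idx in range(num_layers)
        )

    def forward(self, embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        # Same loop shape as the patched forward(): every layer is visited,
        # but only layer 0 actually consumes the embeddings.
        for layer in self.layers:
            hidden_states = layer(embeds, hidden_states)
        return hidden_states


# Shapes only; attention, residuals, and RMSNorm are deliberately omitted.
out = ToyDraftModel(hidden_size=64, num_layers=3)(
    torch.randn(2, 8, 64), torch.randn(2, 8, 64)
)
assert out.shape == (2, 8, 64)
```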