@@ -106,6 +106,7 @@ def __init__(
106
106
cache_config = cache_config ,
107
107
quant_config = quant_config ,
108
108
prefix = f"{ prefix } .attn" ,
109
+ attn_type = self .attn_type ,
109
110
)
110
111
111
112
def _init_qkv (
@@ -134,12 +135,7 @@ def forward(
134
135
qkv , _ = self .qkv_proj (hidden_states )
135
136
q , k , v = qkv .split ([self .q_size , self .kv_size , self .kv_size ], dim = - 1 )
136
137
137
- attn_output = self .attn (q ,
138
- k ,
139
- v ,
140
- kv_cache ,
141
- attn_metadata ,
142
- attn_type = self .attn_type )
138
+ attn_output = self .attn (q , k , v , kv_cache , attn_metadata )
143
139
144
140
output , _ = self .out_proj (attn_output )
145
141
@@ -164,6 +160,7 @@ def __init__(
164
160
cache_config = cache_config ,
165
161
quant_config = quant_config ,
166
162
prefix = prefix ,
163
+ attn_type = AttentionType .ENCODER_DECODER ,
167
164
)
168
165
169
166
def _init_qkv (
@@ -207,12 +204,13 @@ def forward(
207
204
else :
208
205
k = v = None
209
206
210
- attn_output = self .attn (q ,
211
- k ,
212
- v ,
213
- kv_cache ,
214
- attn_metadata ,
215
- attn_type = AttentionType .ENCODER_DECODER )
207
+ attn_output = self .attn (
208
+ q ,
209
+ k ,
210
+ v ,
211
+ kv_cache ,
212
+ attn_metadata ,
213
+ )
216
214
217
215
output , _ = self .out_proj (attn_output )
218
216
@@ -734,4 +732,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
734
732
loaded_weights = [(name , loaded_weight )
735
733
for name , loaded_weight in weights ]
736
734
mapper = WeightsMapper ({".fc1." : ".mlp.fc1." , ".fc2." : ".mlp.fc2." })
737
- return loader .load_weights (loaded_weights , mapper = mapper )
735
+ return loader .load_weights (loaded_weights , mapper = mapper )
0 commit comments