@@ -94,7 +94,6 @@ def forward(
    ) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.size()
        query = self.q_proj_swiftkv(hidden_states)
-
        # Reshape the query, key, and value tensors.
        query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -107,10 +106,9 @@ def forward(
                    "with a layer index."
                )
            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs)

-            key_states, value_states = past_key_value.read_only(
-                self.layer_idx, position_ids=position_ids, batch_index=batch_index
-            )
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1)
        query_states, _ = qeff_apply_rotary_pos_emb(
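The hunk above folds `position_ids` and `batch_index` into a single `cache_kwargs` dict passed to the cache's `read_only` call, in the spirit of the `cache_kwargs` convention used by `transformers`-style caches. Below is a minimal sketch of that calling convention; `ToyCache` and its `read_only` body are hypothetical stand-ins, not QEfficient's actual cache implementation.

```python
import torch


class ToyCache:
    """Hypothetical stand-in cache, only to illustrate the cache_kwargs calling convention."""

    def __init__(self, keys, values):
        # One key/value tensor per layer, shaped [batch, heads, seq_len, head_dim].
        self.keys, self.values = keys, values

    def read_only(self, layer_idx, cache_kwargs=None):
        # A real implementation would use cache_kwargs["position_ids"] / cache_kwargs["batch_index"]
        # to gather the relevant cache rows; this toy version just returns the stored tensors.
        return self.keys[layer_idx], self.values[layer_idx]


cache = ToyCache([torch.randn(1, 8, 16, 64)], [torch.randn(1, 8, 16, 64)])
cache_kwargs = {"position_ids": torch.arange(16).unsqueeze(0), "batch_index": None}
key_states, value_states = cache.read_only(0, cache_kwargs=cache_kwargs)
print(key_states.shape, value_states.shape)
```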
@@ -121,10 +119,8 @@ def forward(
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
        if attention_mask is not None:  # no matter the length, we just slice it
            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
-
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@@ -148,7 +144,6 @@ def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_key_value_heads = config.num_key_value_heads
-
        self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -343,7 +338,6 @@ def forward(

        bsz, q_len, _ = hidden_states.size()
        swiftkv_hidden_states = self.norm_swiftkv(hidden_states)
-
        ####################################
        ## THE MAGIC OF SWIFT KV BEGINS HERE
        ####################################
@@ -374,24 +368,30 @@ def forward(
        last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True)
        orig_hidden_states = hidden_states

-        hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
-
-        causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+        # Extract only the last valid position id to be processed by the self-attn of the remaining half of the layers, as the KV cache is already filled.
+        if batch_index is not None:
+            hidden_states = orig_hidden_states[batch_index, last_pos_id, :]
+            causal_mask = causal_mask[batch_index, :, last_pos_id, :]
+        else:
+            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]

        hidden_states, next_decoder_cache = self._run_swiftkv_layers(
            hidden_states, position_ids, past_key_values, causal_mask, batch_index
        )
-
-        orig_hidden_states[torch.arange(bsz), last_pos_id, :] = hidden_states
+        # We could write the processed hidden_states back into orig_hidden_states here, but it is not needed: for next-token prediction
+        # we only need the hidden_states at the last valid position indices.
+        # Here the shape of hidden_states is [batch_size, 1, hidden_dim] instead of [batch_size, seq_len, hidden_dim].
+        # This avoids unnecessary data movement on device.
        ####################################
        ## THE MAGIC OF SWIFT KV ENDS HERE
        ####################################

        next_cache = next_decoder_cache.to_legacy_cache()
-        return orig_hidden_states, next_cache
+        return hidden_states, next_cache


-class LlamaSwiftKVForCausalLM(PreTrainedModel):
+class LlamaSwiftKVForCausalLM(PreTrainedModel):  #
    config_class = LlamaSwiftKVConfig

    def __init__(self, config: LlamaSwiftKVConfig):
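The change above runs only the last valid token of each sequence through the remaining (SwiftKV) half of the layers and returns that single-position tensor directly, instead of scattering it back into `orig_hidden_states`. A minimal sketch of the gather, assuming a `-1` padding convention for `position_ids` and illustrative sizes; it uses the `.view(-1, 1)` batch-index form (as in the removed `LlamaSwiftKVForCausalLM` lines below) so the advanced indexing broadcasts to `[batch_size, 1, hidden_dim]`.

```python
import torch

bsz, seq_len, hidden = 2, 8, 16                  # illustrative sizes, not from the PR
hidden_states = torch.randn(bsz, seq_len, hidden)

# position_ids grow monotonically over valid tokens (padding assumed to be -1 here),
# so the argmax index is the last valid position of each sequence.
position_ids = torch.tensor([[0, 1, 2, 3, 4, -1, -1, -1],
                             [0, 1, 2, 3, 4,  5,  6,  7]])
last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True)  # [bsz, 1]

# Gather only the last valid position per sequence: the result is [bsz, 1, hidden],
# so the SwiftKV half of the layers attends for a single query token per sequence.
batch_idx = torch.arange(bsz).view(-1, 1)        # [bsz, 1]
last_hidden = hidden_states[batch_idx, last_pos_id]
assert last_hidden.shape == (bsz, 1, hidden)
```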
@@ -412,8 +412,6 @@ def forward(
        batch_index: Optional[torch.LongTensor] = None,
    ):
        hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index)
-        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
        logits = self.lm_head(hidden_states)
        return CausalLMOutputWithPast(
            loss=None,
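Since the model now returns hidden states of shape `[batch_size, 1, hidden_dim]` (only the last valid position), the `logit_index` gather removed above is redundant and `lm_head` is applied directly. A rough shape-flow sketch with made-up sizes and a plain `nn.Linear` standing in for the real `lm_head`:

```python
import torch
import torch.nn as nn

bsz, hidden, vocab = 2, 16, 32                  # illustrative sizes
lm_head = nn.Linear(hidden, vocab, bias=False)  # stand-in for the model's lm_head

last_hidden = torch.randn(bsz, 1, hidden)       # what the model's forward now returns
logits = lm_head(last_hidden)                   # [bsz, 1, vocab]: one next-token distribution per sequence
assert logits.shape == (bsz, 1, vocab)
```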