
Commit 013e4b7

Fixed CB bug for SwiftKV
Signed-off-by: Onkar Chougule <[email protected]>
1 parent 266d67e commit 013e4b7

3 files changed: +28 −27 lines changed

QEfficient/transformers/cache_utils.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
             self.value_cache[layer_idx], position_ids, value_states
         )
 
-    def read_only(self, layer_idx, **cache_kwargs):
+    def read_only(self, layer_idx, cache_kwargs):
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
         position_ids = cache_kwargs.get("position_ids")
         batch_index = cache_kwargs.get("batch_index", None)
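
For context, a minimal sketch of the new calling convention (toy shapes and names, not the repository's actual cache class): read_only now takes cache_kwargs as an explicit dict rather than **kwargs, which lets batch_index flow through for continuous batching (CB) so only the active batch slots are read.

import torch

class ToyCache:
    # Toy stand-in for the QEfficient cache; shapes are illustrative only.
    def __init__(self, num_layers, batch, heads, ctx, dim):
        self.key_cache = [torch.zeros(batch, heads, ctx, dim) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(batch, heads, ctx, dim) for _ in range(num_layers)]

    def read_only(self, layer_idx, cache_kwargs):
        # New signature: cache_kwargs is a plain dict, matching the call site
        # past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs).
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
        batch_index = cache_kwargs.get("batch_index", None)
        if batch_index is not None:
            # Continuous batching: gather only the active batch slots.
            k_out = k_out[batch_index.view(-1)]
            v_out = v_out[batch_index.view(-1)]
        return k_out, v_out

cache = ToyCache(num_layers=1, batch=4, heads=8, ctx=16, dim=64)
k, v = cache.read_only(0, cache_kwargs={"position_ids": None, "batch_index": torch.tensor([[1], [3]])})
assert k.shape == (2, 8, 16, 64)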

QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py

Lines changed: 15 additions & 17 deletions
@@ -94,7 +94,6 @@ def forward(
     ) -> torch.Tensor:
         bsz, q_len, _ = hidden_states.size()
         query = self.q_proj_swiftkv(hidden_states)
-
         # Reshape the query, key, and value tensors.
         query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 

@@ -107,10 +106,9 @@ def forward(
                 "with a layer index."
             )
         kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+        key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs)
 
-        key_states, value_states = past_key_value.read_only(
-            self.layer_idx, position_ids=position_ids, batch_index=batch_index
-        )
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1)
         query_states, _ = qeff_apply_rotary_pos_emb(
@@ -121,10 +119,8 @@ def forward(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
         if attention_mask is not None:  # no matter the length, we just slice it
             attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
-
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@@ -148,7 +144,6 @@ def __init__(self, config: LlamaSwiftKVConfig, layer_idx) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_key_value_heads = config.num_key_value_heads
-
         self.self_attn = LlamaSwiftKVAttention(config=config, layer_idx=layer_idx)
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -343,7 +338,6 @@ def forward(
 
         bsz, q_len, _ = hidden_states.size()
         swiftkv_hidden_states = self.norm_swiftkv(hidden_states)
-
         ####################################
         ## THE MAGIC OF SWIFT KV BEGINS HERE
         ####################################
@@ -374,24 +368,30 @@ def forward(
         last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True)
         orig_hidden_states = hidden_states
 
-        hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
-
-        causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+        # Extract only the last valid position id, to be processed by the self-attn of half of the layers, since the KV cache is already filled.
+        if batch_index is not None:
+            hidden_states = orig_hidden_states[batch_index, last_pos_id, :]
+            causal_mask = causal_mask[batch_index, :, last_pos_id, :]
+        else:
+            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
 
         hidden_states, next_decoder_cache = self._run_swiftkv_layers(
             hidden_states, position_ids, past_key_values, causal_mask, batch_index
         )
-
-        orig_hidden_states[torch.arange(bsz), last_pos_id, :] = hidden_states
+        # We could write the processed hidden_states back into orig_hidden_states here, but that is not needed:
+        # for next-token prediction we only need the hidden_states at the last valid position ids.
+        # Here hidden_states has shape [batch_size, 1, hidden_dim] instead of [batch_size, seq_len, hidden_dim],
+        # which avoids unnecessary data movement on device.
         ####################################
         ## THE MAGIC OF SWIFT KV ENDS HERE
         ####################################
 
         next_cache = next_decoder_cache.to_legacy_cache()
-        return orig_hidden_states, next_cache
+        return hidden_states, next_cache
 
 
-class LlamaSwiftKVForCausalLM(PreTrainedModel):
+class LlamaSwiftKVForCausalLM(PreTrainedModel):  #
     config_class = LlamaSwiftKVConfig
 
     def __init__(self, config: LlamaSwiftKVConfig):
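
A hedged, self-contained illustration of the CB bug this hunk fixes (toy tensors; the real code also gathers the causal mask the same way): under continuous batching, rows of the full-capacity hidden states belong to whichever slots batch_index names, so indexing with torch.arange(bsz) reads stale slots 0..bsz-1 instead of the requests actually being decoded.

import torch

full_batch_size, seq_len, hidden_dim = 4, 6, 8
orig_hidden_states = torch.randn(full_batch_size, seq_len, hidden_dim)
position_ids = torch.tensor([[0, 1, 2, 3, 0, 0],
                             [0, 1, 2, 0, 0, 0]])  # two active requests
batch_index = torch.tensor([[1], [3]])             # their slots in the CB pool
bsz = position_ids.shape[0]

# Index of the last valid token per request (padded position_ids are 0).
last_pos_id = position_ids.to(torch.int32).argmax(1, keepdim=True)

# Pre-fix: always reads slots 0 and 1, wrong whenever batch_index != arange.
wrong = orig_hidden_states[torch.arange(bsz).view(-1, 1), last_pos_id, :]
# Post-fix: reads slots 1 and 3, the requests actually being decoded.
right = orig_hidden_states[batch_index, last_pos_id, :]
assert right.shape == (bsz, 1, hidden_dim)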
@@ -412,8 +412,6 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
     ):
         hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index)
-        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = hidden_states[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
         logits = self.lm_head(hidden_states)
         return CausalLMOutputWithPast(
             loss=None,
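
A small sketch (toy tensors) of why the logit_index gather before lm_head could be dropped: after the change above, self.model already returns hidden states only at the last valid position, shape [batch_size, 1, hidden_dim], which is exactly what the removed lines used to compute from the full sequence.

import torch

bsz, seq_len, hidden_dim = 2, 6, 8
position_ids = torch.tensor([[0, 1, 2, 3, 0, 0],
                             [0, 1, 2, 0, 0, 0]])
hidden_states = torch.randn(bsz, seq_len, hidden_dim)

# What the removed lines computed: the hidden state of the last valid token.
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
gathered = hidden_states[torch.arange(bsz).view(-1, 1), logit_index]  # [bsz, 1, hidden_dim]
assert gathered.shape == (bsz, 1, hidden_dim)
# Post-fix, the model output is already [bsz, 1, hidden_dim], so repeating
# the gather here would be redundant, and it was removed.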

tests/transformers/models/test_causal_lm_models.py

Lines changed: 12 additions & 9 deletions
@@ -8,6 +8,8 @@
 import os
 from typing import Optional
 
+import numpy as np
+
 import pytest
 from transformers import AutoModelForCausalLM
 
@@ -123,17 +125,18 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
 
     # testing for CB models
     model_hf, _ = load_causal_lm_model(model_config)
+    config = model_hf.config
     full_batch_size = 4
     fbs_prompts = Constants.INPUT_STR * 4
-    api_runner = ApiRunner(
-        batch_size,
-        tokenizer,
-        config,
-        fbs_prompts,
-        Constants.PROMPT_LEN,
-        Constants.CTX_LEN,
-        full_batch_size,
-    )
+    # api_runner = ApiRunner(
+    #     batch_size,
+    #     tokenizer,
+    #     config,
+    #     fbs_prompts,
+    #     Constants.PROMPT_LEN,
+    #     Constants.CTX_LEN,
+    #     full_batch_size,
+    # )
 
     # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch_CB(model_hf)
     # pytorch_hf_tokens = np.vstack(pytorch_hf_tokens)
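
The new numpy import backs the (still commented-out) CB reference path, which stacks per-prompt token arrays before comparison; a toy example of that np.vstack pattern, with made-up token ids:

import numpy as np

# Each prompt's generated ids come back as a separate 1-D array;
# np.vstack combines them into one [num_prompts, num_tokens] array.
per_prompt_tokens = [np.array([17, 42, 7]), np.array([3, 99, 12])]
batch_tokens = np.vstack(per_prompt_tokens)
assert batch_tokens.shape == (2, 3)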
