Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,10 @@ class SDPACustom(torch.nn.Module):
def __init__(
self,
kv_cache: KVCache,
mask,
dim: int,
):
super().__init__()
self.kv_cache = kv_cache
self.mask = mask
self.dim = dim

def forward(
Expand All @@ -112,6 +110,7 @@ def forward(
v: torch.Tensor,
bsz,
seqlen,
mask,
):
output = torch.ops.llama.sdpa_with_kv_cache(
q,
Expand All @@ -131,7 +130,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
setattr(
module,
name,
SDPACustom(child.kv_cache, child.mask, child.dim),
SDPACustom(child.kv_cache, child.dim),
)
else:
_replace_sdpa_with_custom_op(child)
Expand Down
19 changes: 10 additions & 9 deletions examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,14 @@ class SDPA(nn.Module):
def __init__(
self,
kv_cache: KVCache,
mask,
dim: int,
head_dim: int,
n_rep: int,
):
super().__init__()
self.kv_cache = kv_cache
self.mask = mask
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep

def forward(
Expand All @@ -215,17 +215,18 @@ def forward(
v: torch.Tensor,
bsz,
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
mask = self.mask[None, None, input_pos]
attn_mask = mask[None, None, input_pos]

k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

Expand Down Expand Up @@ -271,10 +272,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v
)
self.SDPA = SDPA(
self.kv_cache,
self.mask,
self.dim,
self.n_rep,
kv_cache=self.kv_cache,
dim=self.dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
)

def forward(
Expand All @@ -298,7 +299,7 @@ def forward(

if self.use_kv_cache:
assert input_pos is not None
output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
Expand Down