From fa5197a92f0ec2068e0537182005820274f70ce8 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Sun, 14 Apr 2024 11:13:48 -0700
Subject: [PATCH 1/2] move mask as sdpa input instead of attribute

sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as an input; refactor the SDPA module so the mask is passed to its forward call, bringing its signature closer to sdpa's.

Differential Revision: [D56119739](https://our.internmc.facebook.com/intern/diff/D56119739/)

[ghstack-poisoned]
---
 examples/models/llama2/export_llama_lib.py  |  5 ++---
 examples/models/llama2/llama_transformer.py | 10 ++++------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index aa195209ad9..0e81715f350 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -96,12 +96,10 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
 
     def forward(
@@ -112,6 +110,7 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask,
     ):
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
@@ -131,7 +130,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.mask, child.dim),
+                SDPACustom(child.kv_cache, child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index c353a913bf0..7c3b4efb037 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -213,13 +213,11 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
         self.n_rep = n_rep
 
@@ -231,17 +229,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = self.mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
 
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
@@ -288,7 +287,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
             )
             self.SDPA = SDPA(
                 self.kv_cache,
-                self.mask,
                 self.dim,
                 self.n_rep,
             )
@@ -314,7 +312,7 @@ def forward(
 
         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)

From 3ea88cde0c913513b6a1ec692f78e16e97f26fef Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Sun, 14 Apr 2024 14:58:27 -0700
Subject: [PATCH 2/2] Update on "move mask as sdpa input instead of attribute"

sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as an input; refactor the SDPA module so the mask is passed to its forward call, bringing its signature closer to sdpa's.

Differential Revision: [D56119739](https://our.internmc.facebook.com/intern/diff/D56119739/)

[ghstack-poisoned]
---
 examples/models/llama2/llama_transformer.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 7c3b4efb037..4184861f091 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -214,11 +214,13 @@ def __init__(
         self,
         kv_cache: KVCache,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep
 
     def forward(
@@ -236,7 +238,7 @@ def forward(
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        attn_mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -286,9 +288,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
             self.SDPA = SDPA(
-                self.kv_cache,
-                self.dim,
-                self.n_rep,
+                kv_cache=self.kv_cache,
+                dim=self.dim,
+                head_dim=self.head_dim,
+                n_rep=self.n_rep,
             )
 
     def forward(
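For reference, a minimal, self-contained sketch of the call-pattern change these patches make: the mask is no longer stored on the module at construction time but is passed per forward call, mirroring how F.scaled_dot_product_attention itself accepts attn_mask as a call argument. This is not part of the patches; the tensor shapes and values below are illustrative assumptions.

# Illustrative sketch only; not part of the patches. Shapes/values are made up.
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Causal mask supplied per call, analogous to SDPA.forward(..., mask) after the
# refactor; True means "may attend", the boolean attn_mask convention of sdpa.
mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))[None, None]

# Before the refactor the module read a stored self.mask at this point; now the
# mask arrives as an argument and is forwarded directly to the sdpa call.
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(y.shape)  # torch.Size([1, 4, 8, 16])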