Commit ebed490

kimishpatel authored and pytorchmergebot committed
[sdpa decomp] change sdpa decomp to be consistent with flash attention (pytorch#108608)
Summary: See the comment in the code for the reasons for the change.
Test Plan: buck2 test executorch/examples/export/test:test_export -- test_vit_export_to_executorch
Differential Revision: D48992180
Pull Request resolved: pytorch#108608
Approved by: https://github.com/larryliu0820
1 parent 6edd064

File tree

1 file changed: +34 -0 lines changed

torch/_decomp/decompositions.py

@@ -3967,6 +3967,40 @@ def scaled_dot_product_flash_attention(
     output, _ = aten._scaled_dot_product_attention_math.default(
         query, key, value, attn_mask, dropout_p, is_causal, None, scale=scale
     )
+    # Why this change?
+    # In pre-dispatch export, scaled_dot_product_attention is executed via
+    # flash_attention.
+    # flash_attention allocates the output tensor as (N, L, H, E);
+    # it then transposes that to get (N, H, L, E), which is supposed to be
+    # the return tensor dim for scaled_dot_product_attention.
+    # Assume x: [N, H, L, E] is the output of sdpa.
+    # In MHA code, this output is then permuted via (2, 0, 1, 3) to get
+    # a (L, N, H, E) dim tensor:
+    # x = x.permute(2, 0, 1, 3).contiguous() and then viewed via
+    # x = x.view(L * N, H * E)
+    # During pre-autograd dispatch, the call to contiguous is not traced
+    # because the flash_attention output after x.permute is already
+    # contiguous, so the view is valid on it.
+    # However, during 2nd-stage export, post-dispatch, we run the _math
+    # variant instead of flash* to get the decomposition. The _math variant
+    # returns x: [N, H, L, E]; applying x.permute(2, 0, 1, 3) returns
+    # x: [L, N, H, E], and without converting this to a contiguous tensor
+    # the subsequent view is not valid and the export fails.
+    # The solution is to keep the return tensor view from the decomp
+    # exactly the same as the *flash* variant's:
+    # the flash variants' output is contiguous as [N, L, H, E];
+    # the _math variant's output is contiguous as [N, H, L, E].
+    # out = out.transpose(1, 2).contiguous() gets the output contiguous
+    # in [N, L, H, E].
+    # A subsequent transpose(1, 2) then returns a view on which the
+    # aforementioned code snippet, as shown below, is valid:
+    # x = x.permute(2, 0, 1, 3).contiguous() and then viewed via
+    # x = x.view(L * N, H * E)
+
+    # Really, the invariant you want to maintain is:
+    # a pre-dispatch op output and its decomposed representation must
+    # return tensors with the same view and dims.
+    output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format)
     return (
         output.transpose(1, 2),
         logsumexp,
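The contiguity mismatch described in the diff's comment can be reproduced directly. Below is a small sketch (assuming batch size N = 1, as in a single-image export like the ViT test plan; the `mha_postprocess` helper is hypothetical, standing in for the permute-then-view step in the MHA code path) showing that the view succeeds on the flash-style layout, fails on the raw math-decomposition layout, and succeeds again after the transpose/contiguous/transpose fix:

```python
import torch

# Shapes: batch N, heads H, sequence length L, head dim E.
# N = 1 matches a single-image export such as the ViT test plan.
N, H, L, E = 1, 4, 8, 16

# _math decomposition output: allocated contiguous as [N, H, L, E].
math_out = torch.randn(N, H, L, E)

# flash variant output: allocated contiguous as (N, L, H, E), then
# transposed to (N, H, L, E) before being returned from sdpa.
flash_out = torch.randn(N, L, H, E).transpose(1, 2)

def mha_postprocess(x):
    # Hypothetical stand-in for the MHA code path: permute to
    # (L, N, H, E), then view WITHOUT an intervening .contiguous()
    # (the contiguous call was not traced pre-dispatch).
    return x.permute(2, 0, 1, 3).view(L * N, H * E)

# On the flash layout, the permuted tensor is already contiguous,
# so the view succeeds.
y_flash = mha_postprocess(flash_out)
assert y_flash.shape == (L * N, H * E)

# On the raw math layout, the permuted tensor is not contiguous,
# so the same view raises.
try:
    mha_postprocess(math_out)
    raise AssertionError("expected view to fail on the math layout")
except RuntimeError:
    pass

# The fix from the diff: transpose/contiguous/transpose restores the
# flash-style memory layout, after which the same view succeeds.
fixed = math_out.transpose(1, 2).contiguous(
    memory_format=torch.contiguous_format
).transpose(1, 2)
assert fixed.stride() == flash_out.stride()
y_fixed = mha_postprocess(fixed)
assert y_fixed.shape == (L * N, H * E)
```

Matching the strides of the flash output is exactly the invariant the comment asks for: the decomposed representation returns a tensor with the same view and dims as the pre-dispatch op output.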
