
Add a simple sdpa (#3037) #3166


Merged
1 commit merged into pytorch:release/0.2 on Apr 19, 2024

Conversation

@cccclai (Contributor) commented Apr 19, 2024

Summary:
Pull Request resolved: #3037

Add a simple sdpa so that attention decomposes into simpler ops, instead of relying on the decomposition of `F.scaled_dot_product_attention`, which expands into 29 ops, including `torch.where`:

```
def forward(self, q, k, v):
    aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605);  q = None
    aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2);  aten_arange_start_step = None
    aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1);  aten_arange_start_step_1 = None
    aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1);  aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None
    aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0);  aten_sub_tensor = None
    aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default);  aten_le_scalar = aten_full_default = None
    aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format)
    aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default);  aten_logical_and_default = None
    aten_scalar_tensor_default = executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
    aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default);  aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None
    aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]);  k = None
    aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605);  aten_permute_copy_default = None
    aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]);  aten_mul_scalar = None
    aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]);  aten_expand_copy_default = None
    aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]);  aten_mul_scalar_1 = None
    aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]);  aten_expand_copy_default_1 = None
    aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1);  aten_view_copy_default = aten_view_copy_default_1 = None
    aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]);  aten_bmm_default = None
    aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self);  aten_view_copy_default_2 = aten_where_self = None
    aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False);  aten_add_tensor = None
    aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]);  aten__softmax_default = None
    aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]);  aten_expand_copy_default_2 = None
    aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]);  v = None
    aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]);  aten_expand_copy_default_3 = None
    aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4);  aten_view_copy_default_3 = aten_view_copy_default_4 = None
    aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]);  aten_bmm_default_1 = None
    return (aten_view_copy_default_5,)
```
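
For comparison, the same attention computation written directly as two matmuls, an additive mask, and a softmax needs only a handful of ops. The sketch below is illustrative only, not the exact module added in this diff; the class name `SimpleSDPA`, the `[bsz, n_heads, seq_len, head_dim]` layout, and the precomputed float mask are assumptions:

```
# Minimal sketch (assumed names/shapes, not the exact module in this PR) of a
# "simple" sdpa expressed with basic ops: two matmuls, an additive mask, softmax.
import math

import torch


class SimpleSDPA(torch.nn.Module):
    def forward(self, q, k, v, attn_mask):
        # q, k, v: [bsz, n_heads, seq_len, head_dim]
        # attn_mask: float tensor, 0 where attention is allowed, -inf elsewhere
        scale = 1.0 / math.sqrt(q.size(-1))
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = attn + attn_mask            # additive mask: no torch.where needed
        attn = torch.softmax(attn, dim=-1)
        return torch.matmul(attn, v)


if __name__ == "__main__":
    bsz, n_heads, seq_len, head_dim = 1, 1, 8, 8
    q, k, v = (torch.randn(bsz, n_heads, seq_len, head_dim) for _ in range(3))
    # Causal mask as a float bias: -inf strictly above the diagonal, 0 elsewhere.
    mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)
    out = SimpleSDPA()(q, k, v, mask)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(out, ref, atol=1e-5))
```

Scaling the product once by `1/sqrt(head_dim)` is mathematically the same as the two `mul.Scalar(..., 0.5946035575013605)` nodes in the dump above, which scale `q` and the transposed `k` by `8 ** -0.25` each.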

After applying the diff, we remove the following ops:
```
    %aten_full_like_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.full_like.default](args = (%aten_index_tensor_2, 0), kwargs = {dtype: torch.float32, pin_memory: False, memory_format: torch.preserve_format})

    %aten_logical_not_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.logical_not.default](args = (%aten_index_tensor_2,), kwargs = {})

    %aten_scalar_tensor_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.scalar_tensor.default](args = (-inf,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})

    %aten_where_self : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default, %aten_scalar_tensor_default, %aten_full_like_default), kwargs = {})

    %aten_mul_scalar : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_3, 0.5946035575013605), kwargs = {})
    ...
    %aten_mul_scalar_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Scalar](args = (%aten_permute_copy_default_6, 0.5946035575013605), kwargs = {})
```

but introduce an add:
```
    %aten_add_tensor_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mul_tensor_11, %aten_index_tensor_2), kwargs = {})
```
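
In other words, the graph no longer has to turn a boolean mask into a float bias via `full_like` / `logical_not` / `scalar_tensor` / `where`; the mask (`%aten_index_tensor_2`, presumably already in additive 0 / -inf form) is folded into the scores with a single `add.Tensor`. A small sketch of the two masking styles, with names and shapes that are illustrative assumptions rather than taken from this PR:

```
# Two equivalent ways to apply a causal mask to attention scores.
import torch

seq_len = 8
scores = torch.randn(1, 1, seq_len, seq_len)                   # scaled q @ k^T
keep = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()   # True = attend

# Before: materialize a float bias from the boolean mask with where()
# (roughly what the removed full_like / logical_not / scalar_tensor / where did).
bias_from_where = torch.where(keep, torch.zeros_like(scores), torch.tensor(float("-inf")))
masked_before = scores + bias_from_where

# After: keep the mask as a float bias (0 / -inf) up front and just add it.
float_bias = torch.zeros(seq_len, seq_len).masked_fill(~keep, float("-inf"))
masked_after = scores + float_bias

print(torch.equal(masked_before, masked_after))  # True
```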

ghstack-source-id: 223152096
exported-using-ghexport

Reviewed By: mergennachin, kimishpatel

Differential Revision: D56119737

fbshipit-source-id: ec8e875f0a4c4ec67b7493e4872c9a5b081e6de7
(cherry picked from commit cf781073f8dd369930d00cfa95807a96cbb08705)


@pytorch-bot (bot) commented Apr 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3166

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit dc0b5bd with merge base d3326a2:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Apr 19, 2024
@guangy10 (Contributor) commented:

@cccclai There is a llama-runner failure

@cccclai (Contributor, Author) commented Apr 19, 2024:

That's the llava_encoder failure.

@guangy10 merged commit efb7cf3 into pytorch:release/0.2 on Apr 19, 2024
@mergennachin mentioned this pull request on Apr 25, 2024