
Conversation

@IlyasMoutawwakil (Contributor)

This PR makes the ONNX sdpa export match the behavior of aten sdpa when a boolean mask is used. The following repro

import onnxruntime as ort
import torch


class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)


model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    opset_version=18,
    dynamo=True, # or False
)
ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")

np_inputs = {"query": query.numpy(), "key": key.numpy(), "value": value.numpy(), "attn_mask": attn_mask.numpy()}
onnx_outputs = ort_session.run(None, np_inputs)[0]

# equal_nan=True: NaNs in matching positions compare equal, so any failure means the values actually differ
torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)

fails the assertion because the ORT model outputs NaNs.
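
For context, a minimal sketch (not part of the PR, plain PyTorch) of where the NaNs come from: when an entire row is masked out, every score in that row becomes -inf, and softmax over an all-(-inf) row is 0/0, i.e. NaN. The op.Where(op.IsNaN(...)) line added by this PR (shown in the review below) maps those NaN attention weights to zero:

import torch

scores = torch.full((1, 4), float("-inf"))  # fully masked row: every score is -inf
weights = torch.softmax(scores, dim=-1)     # exp(-inf) / sum(exp(-inf)) = 0 / 0 -> all NaN
print(weights)                              # tensor([[nan, nan, nan, nan]])

cleaned = torch.where(weights.isnan(), torch.zeros_like(weights), weights)
print(cleaned)                              # tensor([[0., 0., 0., 0.]])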

@IlyasMoutawwakil (Contributor, Author)

@titaiwangms @justinchuby

titaiwangms enabled auto-merge (squash) August 7, 2025 15:55
codecov bot commented Aug 7, 2025

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 69.81%. Comparing base (32f2196) to head (0068e40).
⚠️ Report is 3 commits behind head on main.

Files with missing lines                       Patch %   Lines
onnxscript/function_libs/torch_lib/ops/nn.py   0.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2479      +/-   ##
==========================================
- Coverage   69.81%   69.81%   -0.01%     
==========================================
  Files         209      209              
  Lines       25313    25314       +1     
  Branches     2525     2525              
==========================================
  Hits        17673    17673              
- Misses       6762     6763       +1     
  Partials      878      878              


titaiwangms disabled auto-merge August 7, 2025 16:26
titaiwangms merged commit ecb7677 into microsoft:main Aug 7, 2025
25 of 32 checks passed
github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Aug 7, 2025
titaiwangms added a commit that referenced this pull request Aug 8, 2025
@justinchuby (Collaborator)

Do we need to update https://github.com/pytorch/pytorch/blob/main/torch/onnx/_internal/exporter/_torchlib/ops/nn.py as well, or improve specs of the Attention op? @gramalingam @titaiwangms

# This is because there's no safe/masked softmax implementation in ONNX, so we need to handle NaN values
# explicitly to match the behavior of PyTorch with boolean masks.
attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
Collaborator

@titaiwangms we should probably conditionally skip this line (even though there is a rewrite rule already)

Contributor

OK

Collaborator

If you fix this, can you also please add a reference to pytorch/pytorch#103749 in the comments for the previous line fixing NaN?

Contributor

Should we skip it when dropout_p is 0?
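
Purely for illustration, a rough sketch of how the suggestions above could look (hypothetical, not the actual patch), assuming dropout_p is available as a plain Python float when the torchlib function builds the graph:

# Hypothetical sketch; names follow the nn.py snippet quoted above.
# No safe/masked softmax exists in ONNX, so NaNs from fully masked rows are zeroed
# explicitly to match PyTorch's boolean-mask behavior, see pytorch/pytorch#103749.
attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
if dropout_p != 0.0:  # assumption: dropout_p is a trace-time constant, so the node can be skipped
    attn_weight, _ = op.Dropout(attn_weight, dropout_p)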

justinchuby added the module: torchlib (Related to the torch/aten function lib in development) label Aug 8, 2025
