Support gqa in aten spda #2408

Draft · wants to merge 4 commits into base: main

Changes from all commits
177 changes: 158 additions & 19 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1772,10 +1772,6 @@ def aten_scaled_dot_product_attention(
"is_causal and attn_mask cannot be set at the same time"
)

assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
)

# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if scale is None:
scale = _attention_scale(query)
@@ -1790,7 +1786,7 @@
)

return _aten_scaled_dot_product_attention_float_mask_onnx(
query, key, value, attn_mask, scale, dropout_p
query, key, value, attn_mask, scale, dropout_p, enable_gqa
)


@@ -1982,28 +1978,24 @@ def aten_scaled_dot_product_attention_bool_mask(
"is_causal and attn_mask cannot be set at the same time"
)

assert not enable_gqa, (
"conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
)

if scale is None:
scale = _attention_scale(query)
scale = op.CastLike(scale, query)

if is_causal:
attn_mask = _causal_attention_mask(query, key)
# The causal mask is always float
return _aten_scaled_dot_product_attention_float_mask_onnx(
query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
)

if attn_mask is None:
return _aten_scaled_dot_product_attention_no_mask_onnx(
query, key, value, scale, dropout_p
query, key, value, scale, dropout_p, enable_gqa=enable_gqa
)

return _aten_scaled_dot_product_attention_bool_mask_onnx(
query, key, value, attn_mask, scale, dropout_p
if attn_mask.dtype == ir.DataType.BOOL:
return _aten_scaled_dot_product_attention_bool_mask_onnx(
query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
)
Comment on lines +1994 to +1996

Check failure — Code scanning / CodeQL

Wrong name for an argument in a call

Keyword argument 'enable_gqa' is not a supported parameter name of function _aten_scaled_dot_product_attention_bool_mask_onnx.

Copilot Autofix (AI, 6 days ago)

To fix the issue, remove the keyword argument enable_gqa from the call to _aten_scaled_dot_product_attention_bool_mask_onnx on line 1994 so that the function is called with only the parameters it supports. Removing enable_gqa would not change the behavior of _aten_scaled_dot_product_attention_bool_mask_onnx, since it does not use this argument.

Suggested changeset 1: onnxscript/function_libs/torch_lib/ops/nn.py

@@ -1994,3 +1994,3 @@
         return _aten_scaled_dot_product_attention_bool_mask_onnx(
-            query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
+            query, key, value, attn_mask, scale, dropout_p
         )

This autofix suggestion is now outdated: the current revision of the diff instead adds an enable_gqa parameter to _aten_scaled_dot_product_attention_bool_mask_onnx itself, so the keyword argument is supported.
return _aten_scaled_dot_product_attention_float_mask_onnx(
query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
)


@@ -2013,7 +2005,55 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
value: TFloat,
scale: TFloat,
dropout_p: float,
enable_gqa: bool,
) -> TFloat:
# Handle Grouped Query Attention (GQA) if enabled
if enable_gqa:
# Get head dimensions
query_shape = op.Shape(query)
key_shape = op.Shape(key)
query_heads = op.Slice(query_shape, [-3], [-2]) # query.size(-3)
key_heads = op.Slice(key_shape, [-3], [-2]) # key.size(-3)

# Calculate the repeat factor: query_heads // key_heads
repeat_factor = op.Div(query_heads, key_heads)
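# Reference semantics this block mirrors (PyTorch SDPA with enable_gqa=True,
# assuming 4-D [batch, heads, seq, embed] inputs):
#   key = key.repeat_interleave(query_heads // key_heads, dim=-3)
#   value = value.repeat_interleave(query_heads // key_heads, dim=-3)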

# Expand key and value to match the query head dimension
# Implement key.repeat_interleave(repeat_factor, -3) with Unsqueeze + Tile + Reshape
# First, build the target key shape with the head dimension multiplied out
key_shape_expanded = op.Concat(
op.Slice(key_shape, [0], [-3]), # leading (batch) dims
op.Mul(key_heads, repeat_factor), # expanded head dimension
op.Slice(key_shape, [-2], [_INT64_MAX]), # remaining dims (seq, embed)
axis=0
)
# Repeat each key head 'repeat_factor' times (assumes 4-D [batch, heads, seq, embed] inputs)
key_unsqueezed = op.Unsqueeze(key, [-3]) # insert a repeat axis right after the head dim
key_tiled = op.Tile(key_unsqueezed, op.Concat(
op.Constant(value_ints=[1, 1]), # keep batch and head dims
repeat_factor, # repeat factor for the inserted axis
op.Constant(value_ints=[1, 1]), # keep seq and embed dims
axis=0
))
key = op.Reshape(key_tiled, key_shape_expanded)

# Same for value
value_shape = op.Shape(value)
value_shape_expanded = op.Concat(
op.Slice(value_shape, [0], [-3]), # batch and other dims
op.Mul(key_heads, repeat_factor), # expanded head dimension
op.Slice(value_shape, [-2], [_INT64_MAX]), # remaining dims
axis=0
)
value_unsqueezed = op.Unsqueeze(value, [-3]) # insert a repeat axis right after the head dim
value_tiled = op.Tile(value_unsqueezed, op.Concat(
op.Constant(value_ints=[1, 1]), # keep batch and head dims
repeat_factor, # repeat factor for the inserted axis
op.Constant(value_ints=[1, 1]), # keep seq and embed dims
axis=0
))
value = op.Reshape(value_tiled, value_shape_expanded)

# Swap the last two axes of key
key_shape = op.Shape(key)
key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
@@ -2037,7 +2077,8 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
op.MatMul(query_scaled, key_transposed_scaled),
axis=-1,
)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
if dropout_p > 0.0:
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
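# When dropout_p == 0 the Dropout node would be an identity, so it is omitted from the graph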
return op.MatMul(attn_weight, value)


@@ -2048,7 +2089,55 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
attn_mask: BOOL,
scale: TFloat,
dropout_p: float,
enable_gqa: bool,
) -> TFloat:
# Handle Grouped Query Attention (GQA) if enabled
if enable_gqa:
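# The GQA expansion below is identical to the one in _aten_scaled_dot_product_attention_no_mask_onnx; see the comments there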
# Get head dimensions
query_shape = op.Shape(query)
key_shape = op.Shape(key)
query_heads = op.Slice(query_shape, [-3], [-2]) # query.size(-3)
key_heads = op.Slice(key_shape, [-3], [-2]) # key.size(-3)

# Calculate the repeat factor: query_heads // key_heads
repeat_factor = op.Div(query_heads, key_heads)

# Expand key and value to match the query head dimension
# Implement key.repeat_interleave(repeat_factor, -3) with Unsqueeze + Tile + Reshape
# First, build the target key shape with the head dimension multiplied out
key_shape_expanded = op.Concat(
op.Slice(key_shape, [0], [-3]), # leading (batch) dims
op.Mul(key_heads, repeat_factor), # expanded head dimension
op.Slice(key_shape, [-2], [_INT64_MAX]), # remaining dims (seq, embed)
axis=0
)
# Repeat each key head 'repeat_factor' times (assumes 4-D [batch, heads, seq, embed] inputs)
key_unsqueezed = op.Unsqueeze(key, [-3]) # insert a repeat axis right after the head dim
key_tiled = op.Tile(key_unsqueezed, op.Concat(
op.Constant(value_ints=[1, 1]), # keep batch and head dims
repeat_factor, # repeat factor for the inserted axis
op.Constant(value_ints=[1, 1]), # keep seq and embed dims
axis=0
))
key = op.Reshape(key_tiled, key_shape_expanded)

# Same for value
value_shape = op.Shape(value)
value_shape_expanded = op.Concat(
op.Slice(value_shape, [0], [-3]), # batch and other dims
op.Mul(key_heads, repeat_factor), # expanded head dimension
op.Slice(value_shape, [-2], [_INT64_MAX]), # remaining dims
axis=0
)
value_unsqueezed = op.Unsqueeze(value, [-3]) # insert a repeat axis right after the head dim
value_tiled = op.Tile(value_unsqueezed, op.Concat(
op.Constant(value_ints=[1, 1]), # keep batch and head dims
repeat_factor, # repeat factor for the inserted axis
op.Constant(value_ints=[1, 1]), # keep seq and embed dims
axis=0
))
value = op.Reshape(value_tiled, value_shape_expanded)

# Swap the last two axes of key
key_shape = op.Shape(key)
key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
@@ -2076,7 +2165,8 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
axis=-1,
)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
if dropout_p > 0.0:
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
return op.MatMul(attn_weight, value)


@@ -2087,7 +2177,55 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
attn_mask: TFloat,
scale: TFloat,
dropout_p: float,
enable_gqa: bool,
) -> TFloat:
# Handle Grouped Query Attention (GQA) if enabled
if enable_gqa:
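# The GQA expansion below is identical to the one in _aten_scaled_dot_product_attention_no_mask_onnx; see the comments there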
# Get head dimensions
query_shape = op.Shape(query)
key_shape = op.Shape(key)
query_heads = op.Slice(query_shape, [-3], [-2]) # query.size(-3)
key_heads = op.Slice(key_shape, [-3], [-2]) # key.size(-3)

# Calculate the repeat factor: query_heads // key_heads
repeat_factor = op.Div(query_heads, key_heads)

# Expand key and value to match the query head dimension
# Implement key.repeat_interleave(repeat_factor, -3) with Unsqueeze + Tile + Reshape
# First, build the target key shape with the head dimension multiplied out
key_shape_expanded = op.Concat(
op.Slice(key_shape, [0], [-3]), # leading (batch) dims
op.Mul(key_heads, repeat_factor), # expanded head dimension
op.Slice(key_shape, [-2], [_INT64_MAX]), # remaining dims (seq, embed)
axis=0
)
# Repeat each key head 'repeat_factor' times (assumes 4-D [batch, heads, seq, embed] inputs)
key_unsqueezed = op.Unsqueeze(key, [-3]) # insert a repeat axis right after the head dim
key_tiled = op.Tile(key_unsqueezed, op.Concat(
op.Constant(value_ints=[1, 1]), # keep batch and head dims
repeat_factor, # repeat factor for the inserted axis
op.Constant(value_ints=[1, 1]), # keep seq and embed dims
axis=0
))
key = op.Reshape(key_tiled, key_shape_expanded)

# Same for value
value_shape = op.Shape(value)
value_shape_expanded = op.Concat(
op.Slice(value_shape, [0], [-3]), # batch and other dims
op.Mul(key_heads, repeat_factor), # expanded head dimension
op.Slice(value_shape, [-2], [_INT64_MAX]), # remaining dims
axis=0
)
value_unsqueezed = op.Unsqueeze(value, [-3]) # insert a repeat axis right after the head dim
value_tiled = op.Tile(value_unsqueezed, op.Concat(
op.Constant(value_ints=[1, 1]), # keep batch and head dims
repeat_factor, # repeat factor for the inserted axis
op.Constant(value_ints=[1, 1]), # keep seq and embed dims
axis=0
))
value = op.Reshape(value_tiled, value_shape_expanded)

# Swap the last two axes of key
key_shape = op.Shape(key)
key_last_dim = op.Slice(key_shape, [-1], op.Constant(value_ints=[_INT64_MAX]))
@@ -2111,7 +2249,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
axis=-1,
)
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
if dropout_p > 0.0:
attn_weight, _ = op.Dropout(attn_weight, dropout_p)
return op.MatMul(attn_weight, value)


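For context, a minimal end-to-end sketch of the pattern this change targets: exporting a module that calls torch.nn.functional.scaled_dot_product_attention with enable_gqa=True through the dynamo-based ONNX exporter. The module name, the shapes (8 query heads sharing 2 key/value heads), and the export call are illustrative assumptions, not part of the diff above; they presume a PyTorch build whose SDPA accepts enable_gqa and an onnxscript torch_lib that includes this change.

import torch
import torch.nn.functional as F

class GQASDPA(torch.nn.Module):
    def forward(self, query, key, value):
        # query: [batch, 8, seq, 64]; key/value: [batch, 2, seq, 64]
        # enable_gqa=True lets the 8 query heads share the 2 key/value heads.
        return F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

batch, seq, embed = 2, 16, 64
query = torch.randn(batch, 8, seq, embed)
key = torch.randn(batch, 2, seq, embed)
value = torch.randn(batch, 2, seq, embed)

onnx_program = torch.onnx.export(GQASDPA(), (query, key, value), dynamo=True)
onnx_program.save("sdpa_gqa.onnx")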