Commit 3f2f7d3

gramalingam, Copilot, and justinchuby authored
Attention mask for GQA fusion (#2452)
Expand the GQA fusion rule to handle the attention mask better.

* The patterns are extended to handle variations found in the attention-mask logic of various models.
* The rule incorporates some optimizations from ModelBuilder that are arguably not general-purpose, but rest on assumptions about the intended use case (the GenAI usage pattern).

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: gramalingam <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
1 parent c33fce2 commit 3f2f7d3
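
For readers trying the change out, the module's entry point is the fuse_gqa callable built at the bottom of the diff via _fusion_utils.apply_fusion_rules(gqa_rules). Below is a minimal usage sketch, assuming onnx_ir exposes ir.load/ir.save and that the callable returns the number of fusions applied (conventions of the surrounding codebase, not shown in this commit):

```python
# Usage sketch only: assumes onnx_ir exposes ir.load/ir.save and that fuse_gqa(model)
# returns the number of fusions applied, as with other ort_fusions entry points.
import onnx_ir as ir

from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa

model = ir.load("model.onnx")           # hypothetical input path
count = fuse_gqa(model)                 # apply the (expanded) GQA fusion rule set
print(f"GQA fusions applied: {count}")
ir.save(model, "model_gqa_fused.onnx")  # hypothetical output path
```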

File tree

1 file changed: +140 -101 lines changed

  • onnxscript/rewriter/ort_fusions/gqa.py


onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 140 additions & 101 deletions
@@ -8,7 +8,7 @@
 import onnx_ir as ir

 import onnxscript.rewriter._fusion_utils as _fusion_utils
-from onnxscript.rewriter import _ir_utils, pattern
+from onnxscript.rewriter import _basics, _ir_utils, pattern

 """
 GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different
@@ -32,7 +32,20 @@
 Dim = Union[int, ir.SymbolicDim]


-def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
+def _is_model_input(value: ir.Value, name: str, model: ir.Model) -> bool:
+    return value in model.graph.inputs and value.name == name
+
+
+def _causal_mask(
+    op,
+    input_ids,
+    past_kv_cache,
+    shape_B111,
+    min_val,
+    window_size,
+    dtype,
+):
+    """Defines a pattern for a pure causal mask, with optional sliding window support."""
     seq_len = op.Shape(input_ids, end=2, start=1)
     seq_len_0D = op.Squeeze(seq_len)

@@ -42,28 +55,93 @@ def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
     total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D)
     total_seq_len = op.Reshape(total_seq_len_0D, [-1])

-    # The Phi modeling code generates the following +1 as the target-length, which seems
-    # unnecessary in this context. But using it for pattern-matching against
-    # generated onnx model.
-    total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1)
-    total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1])
-
     current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1)
-    mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
-    min_float32 = float(np.finfo(np.float32).min)
-    mask_all_min = op.Expand(min_float32, mask_shape)
-    total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1)
+    mask_shape = op.Concat(seq_len, total_seq_len, axis=0)
+    mask_all_min_expand = op.Expand(min_val, mask_shape)
+    # The following Trilu is optional: not used in Phi models, but used in LLama.
+    mask_all_min_trilu = op.Trilu(mask_all_min_expand, 1, upper=1)
+    mask_all_min = pattern.OrValue([mask_all_min_expand, mask_all_min_trilu])
+    total_range_as_row = op.Range(0, total_seq_len_0D, 1)
     current_range_as_column = op.Reshape(current_range, [-1, 1])
-    boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
-    float_0_1_mask = op.Cast(boolean_mask, to=1)
+
+    non_causal = op.Greater(total_range_as_row, current_range_as_column)
+
+    # sliding window support:
+    current_range_minus_window = op.Sub(current_range_as_column, window_size)
+    out_of_sliding_window = op.LessOrEqual(total_range_as_row, current_range_minus_window)
+    non_causal_sliding_window = op.Or(non_causal, out_of_sliding_window)
+
+    boolean_mask = pattern.OrValue([non_causal, non_causal_sliding_window])
+
+    float_0_1_mask = op.Cast(boolean_mask, to=dtype)
     float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask)
-    mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1])
-    mask_B1ST_plus = op.Expand(mask_4d, shape_B111)
+    mask_4d_11ST = op.Unsqueeze(float_0_min_mask, [0, 1])
+    mask_4d_B1ST = op.Expand(mask_4d_11ST, shape_B111)
+
+    return mask_4d_B1ST
+
+
+class _CausalMaskPattern(pattern.PatternBase):
+    def pattern(
+        self,
+        op,
+        input_ids,
+        past_kv_cache,
+        shape_B111,
+        min_val,
+        window_size,
+        dtype1,
+        attn_mask_2d,
+        dtype2,
+    ):
+        causal_mask = _causal_mask(
+            op,
+            input_ids,
+            past_kv_cache,
+            shape_B111,
+            min_val,
+            window_size,
+            dtype1,
+        )
+
+        attn_mask_4d = op.Unsqueeze(attn_mask_2d, [1, 2])
+        attn_mask_4d_cast = op.Cast(attn_mask_4d, to=dtype2)
+
+        sum = op.Add(causal_mask, attn_mask_4d_cast)
+        sum_fp32 = op.Cast(sum, to=ir.DataType.FLOAT)
+        # The cast is optional, and may be absent if the sum is already in float32.
+        sum_fp32 = pattern.OrValue([sum_fp32, sum])
+        is_zero = op.Equal(sum_fp32, 0.0)
+        result = op.Where(is_zero, min_val, causal_mask)
+        return result
+
+    def check(self, context, dtype1, dtype2, min_val, attn_mask_2d, sliding_window=None, **_):
+        # Check that attn_mask_2d is the model input "attention_mask"
+        if not _is_model_input(attn_mask_2d, "attention_mask", context.model):
+            return pattern.MatchResult().fail("Invalid attention_mask input", attn_mask_2d)
+
+        if dtype1.as_int() != dtype2.as_int():
+            return pattern.MatchResult().fail("Dtype mismatch", [dtype1, dtype2])
+
+        # Check that min_val is a constant and matches the expected minimum value for the dtype.
+        min_value = _ir_utils.get_singleton_value(min_val)
+        if min_value is None:
+            return pattern.MatchResult().fail("Minval is not a constant.", min_val)
+        expected_min_value = np.finfo(min_val.dtype.numpy()).min
+        if min_value != expected_min_value:
+            return pattern.MatchResult().fail(
+                f"Expected min value {expected_min_value}, got {min_value}", min_val
+            )
+
+        # TODO(rama) Sliding window: not yet supported.
+        if sliding_window:
+            return pattern.MatchResult().fail(
+                "Sliding window not yet supported", sliding_window
+            )
+        return True

-    # Get rid of the extra +1 added above: total_seq_len is enough, no
-    # need for total_seq_len+1.
-    mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1])
-    return mask_B1ST
+
+_causal_mask_pattern = _CausalMaskPattern()


 class GroupQueryAttention(pattern.RewriteRuleClassBase):
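
The mask arithmetic that _causal_mask and _CausalMaskPattern are written to match can be illustrated outside the pattern DSL. The NumPy sketch below uses made-up sizes and mirrors the variable names above; it is an illustration of the intended semantics, not code from this commit:

```python
# NumPy illustration of the mask semantics matched above (not code from this commit).
import numpy as np

past_seq_len, seq_len = 2, 3                 # cached tokens, new tokens in this step
total_seq_len = past_seq_len + seq_len       # 5
min_val = np.finfo(np.float32).min           # the constant the check() method expects
window_size = 4                              # only used by the sliding-window branch

# Rows = positions of the new tokens, columns = positions of all keys.
current_range = np.arange(past_seq_len, total_seq_len).reshape(-1, 1)  # [seq_len, 1]
total_range = np.arange(total_seq_len)                                 # [total_seq_len]

non_causal = total_range > current_range                       # future keys
out_of_window = total_range <= (current_range - window_size)   # keys outside the window
boolean_mask = non_causal | out_of_window                      # sliding-window variant

# 0.0 where attention is allowed, min_val where it is disallowed.
causal_mask = min_val * boolean_mask.astype(np.float32)        # [seq_len, total_seq_len]

# _CausalMaskPattern then folds in the model's 2D 0/1 "attention_mask" input
# (here flattened to one batch row): padded key positions become min_val as well.
attention_mask = np.array([0, 1, 1, 1, 1], dtype=np.float32)   # key 0 is padding
combined = causal_mask + attention_mask
final = np.where(combined == 0.0, min_val, causal_mask)        # the Where/Equal step
```
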
@@ -78,8 +156,7 @@ def pattern(
         value_BSDkv,
         past_key,
         past_value,
-        position_ids_q,
-        position_ids_k,
+        position_ids,
         cos,
         sin,
         mask,
@@ -101,15 +178,15 @@ def pattern(

         query_BHSDh_rope = op.RotaryEmbedding(
             query_BHSDh,
-            position_ids_q,
+            position_ids,
             cos,
             sin,
             _domain="com.microsoft",
             _outputs=["query_BHSDh_rope"],
         )
         key_BHkvSDh_rope = op.RotaryEmbedding(
             key_BHkvSDh,
-            position_ids_k,
+            position_ids,
             cos,
             sin,
             _domain="com.microsoft",
@@ -154,7 +231,7 @@ def pattern(

     def check(
         self,
-        op,
+        context: _basics.MatchContext,
         query_BSD,
         key_BSDkv,
         value_BSDkv,
@@ -164,6 +241,7 @@ def check(
         key_BHkvSDh_rope,
         query_BSHDh,
         key_BSHkvDh,
+        mask,
         **_,
     ):
         bindings: dict[str, Dim] = {}
@@ -210,6 +288,20 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         )
         self._interleaved = query_interleaved

+        # Check mask:
+        mask_node = mask.producer()
+        if mask_node is None:
+            return pattern.MatchResult().fail("Unhandled mask pattern", mask)
+        mask_match_result = _causal_mask_pattern.match(
+            context.model,
+            context.graph_or_function,
+            mask_node,
+            check_nodes_are_removable=False,
+        )
+        if mask_match_result is None:
+            return pattern.MatchResult().fail("Mask does not match causal mask pattern", mask)
+        # TODO: handle sliding window support in mask
+
         return True

     def rewrite(
@@ -220,104 +312,51 @@ def rewrite(
         value_BSDkv,
         past_key,
         past_value,
-        position_ids_q,
-        position_ids_k,
+        position_ids,
         cos,
         sin,
         mask,
         **_,
     ):
-        return op.GQA(
-            mask,
-            position_ids_k,
-            position_ids_q,
+        # Note that the following optimization is specific to current ORT GenAI attention-mask
+        # usage. Specifically, it assumes that the model-input "attention_mask" is a 2D
+        # mask with shape [batch_size, sequence_length], and that the mask is a 0/1 mask
+        # that is used only to indicate the current tokens. Hence, the input attention_mask
+        # is redundant as long as past-sequence-length and current-sequence-length can be
+        # computed.
+
+        # Construct seqlens_k and total_seq_length_int32 from position_ids
+        # seqlens_k : int32[batch_size] indicates total_sequence-length-1 for each batch
+        # position_ids: int64[batch_size, sequence_length] indicates the position of each token
+        one_int32_0d = op.Constant(value=ir.tensor(1, dtype=ir.DataType.INT32))
+        one_int64_1d = op.Constant(value=ir.tensor([1], dtype=ir.DataType.INT64))
+        zero_int64_1d = op.Constant(value=ir.tensor([0], dtype=ir.DataType.INT64))
+        seqlens_k_int64 = op.ReduceMax(position_ids, one_int64_1d, keepdims=0)
+        seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32)
+        max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0)
+        total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d)
+        return op.GroupQueryAttention(
             query_BSD,
             key_BSDkv,
             value_BSDkv,
             past_key,
             past_value,
-            None,  # seqlens_k,
-            None,  # total_seq_length_int32,
+            seqlens_k,
+            total_seq_length_int32,
             cos,
             sin,
             num_heads=self.num_heads,
             kv_num_heads=self.kv_num_heads,
             do_rotary=1,
             rotary_interleaved=self._interleaved,
             # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
-            _domain="ai.onnxruntime._fusion",
+            _domain="com.microsoft",
             _outputs=3,
         )


-class GQACausalMask(pattern.RewriteRuleClassBase):
-    def __init__(self):
-        super().__init__("GQACausalMask", remove_nodes=False)
-
-    def pattern(
-        self,
-        op,
-        mask,
-        input_ids,
-        some_kv_cache,
-        shape_B111,
-        past_seq_length,
-        total_seq_length,
-    ):
-        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
-        position_ids = op.Range(past_seq_length, total_seq_length, 1)
-        position_ids_q = op.Unsqueeze(position_ids, [0])
-        position_ids_k = op.Unsqueeze(position_ids, [0])
-        return op.GQA(
-            mask,
-            position_ids_k,
-            position_ids_q,
-            _allow_other_inputs=True,
-            _domain="ai.onnxruntime._fusion",
-            _outputs=["attn_output", "key_seq", "value_seq"],
-        )
-
-    def rewrite(
-        self,
-        op,
-        total_seq_length,
-        attn_output,
-        **_,
-    ):
-        # Construct total_seq_length_int32 and seqlens_k
-        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
-        one_0D = op.Constant(value_int=1)
-        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
-        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
-        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
-        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
-
-        gqa_node = attn_output.producer()
-        assert len(gqa_node.inputs) == 12, (
-            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
-        )
-        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
-        cos, sin = gqa_node.inputs[10:12]
-        updated_inputs = [
-            query,
-            key,
-            value,
-            past_key,
-            past_value,
-            seqlens_k,
-            total_seq_length_int32,
-            cos,
-            sin,
-        ]
-        attributes = gqa_node.attributes
-        return op.GroupQueryAttention(
-            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
-        )
-
-
 _basic_gqa_rule = GroupQueryAttention.rule()
-_gqa_causal_mask_rule = GQACausalMask.rule()

-gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])
+gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

 fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
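
The comments in the new rewrite method can be made concrete with a small NumPy example (made-up values, not code from this commit): the per-batch maximum of position_ids is total-sequence-length minus one, which is what seqlens_k carries, and adding one to the global maximum gives total_seq_length_int32.

```python
# NumPy illustration of the seqlens_k / total_seq_length_int32 construction (made-up values).
import numpy as np

# position_ids: int64[batch_size, sequence_length], positions of the new tokens.
# Batch row 0 already has 5 cached tokens, row 1 has 2; both decode 3 new tokens.
position_ids = np.array([[5, 6, 7],
                         [2, 3, 4]], dtype=np.int64)

# ReduceMax over axis 1 with keepdims=0, then Cast to int32:
seqlens_k = position_ids.max(axis=1).astype(np.int32)   # [7, 4] = total length - 1 per batch
# ReduceMax over axis 0, then Add 1:
total_seq_length_int32 = np.int32(seqlens_k.max() + 1)   # 8

print(seqlens_k, total_seq_length_int32)                 # [7 4] 8
```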

0 commit comments
