 import onnx_ir as ir

 import onnxscript.rewriter._fusion_utils as _fusion_utils
-from onnxscript.rewriter import _ir_utils, pattern
+from onnxscript.rewriter import _basics, _ir_utils, pattern

 """
 GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different
 Dim = Union[int, ir.SymbolicDim]


-def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
+def _is_model_input(value: ir.Value, name: str, model: ir.Model) -> bool:
+    return value in model.graph.inputs and value.name == name
+
+
+def _causal_mask(
+    op,
+    input_ids,
+    past_kv_cache,
+    shape_B111,
+    min_val,
+    window_size,
+    dtype,
+):
+    """Defines a pattern for a pure causal mask, with optional sliding window support."""
     seq_len = op.Shape(input_ids, end=2, start=1)
     seq_len_0D = op.Squeeze(seq_len)

@@ -42,28 +55,93 @@ def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
     total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D)
     total_seq_len = op.Reshape(total_seq_len_0D, [-1])

-    # The Phi modeling code generates the following +1 as the target-length, which seems
-    # unnecessary in this context. But using it for pattern-matching against
-    # generated onnx model.
-    total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1)
-    total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1])
-
     current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1)
-    mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
-    min_float32 = float(np.finfo(np.float32).min)
-    mask_all_min = op.Expand(min_float32, mask_shape)
-    total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1)
+    mask_shape = op.Concat(seq_len, total_seq_len, axis=0)
+    mask_all_min_expand = op.Expand(min_val, mask_shape)
+    # The following Trilu is optional: not used in Phi models, but used in LLama.
+    mask_all_min_trilu = op.Trilu(mask_all_min_expand, 1, upper=1)
+    mask_all_min = pattern.OrValue([mask_all_min_expand, mask_all_min_trilu])
+    total_range_as_row = op.Range(0, total_seq_len_0D, 1)
     current_range_as_column = op.Reshape(current_range, [-1, 1])
-    boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
-    float_0_1_mask = op.Cast(boolean_mask, to=1)
+
+    non_causal = op.Greater(total_range_as_row, current_range_as_column)
+
+    # sliding window support:
+    current_range_minus_window = op.Sub(current_range_as_column, window_size)
+    out_of_sliding_window = op.LessOrEqual(total_range_as_row, current_range_minus_window)
+    non_causal_sliding_window = op.Or(non_causal, out_of_sliding_window)
+
+    boolean_mask = pattern.OrValue([non_causal, non_causal_sliding_window])
+
+    float_0_1_mask = op.Cast(boolean_mask, to=dtype)
     float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask)
-    mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1])
-    mask_B1ST_plus = op.Expand(mask_4d, shape_B111)
+    mask_4d_11ST = op.Unsqueeze(float_0_min_mask, [0, 1])
+    mask_4d_B1ST = op.Expand(mask_4d_11ST, shape_B111)
+
+    return mask_4d_B1ST
+
+
+class _CausalMaskPattern(pattern.PatternBase):
+    def pattern(
+        self,
+        op,
+        input_ids,
+        past_kv_cache,
+        shape_B111,
+        min_val,
+        window_size,
+        dtype1,
+        attn_mask_2d,
+        dtype2,
+    ):
+        causal_mask = _causal_mask(
+            op,
+            input_ids,
+            past_kv_cache,
+            shape_B111,
+            min_val,
+            window_size,
+            dtype1,
+        )
+
+        attn_mask_4d = op.Unsqueeze(attn_mask_2d, [1, 2])
+        attn_mask_4d_cast = op.Cast(attn_mask_4d, to=dtype2)
+
+        sum = op.Add(causal_mask, attn_mask_4d_cast)
+        sum_fp32 = op.Cast(sum, to=ir.DataType.FLOAT)
+        # The cast is optional, and may be absent if the sum is already in float32.
+        sum_fp32 = pattern.OrValue([sum_fp32, sum])
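+        # Net effect (for illustration): where the causal mask allows attention (value 0) but
+        # the 2D attention_mask marks the position as padding (0), the sum is 0 and the
+        # position is re-masked to min_val below; all other positions keep their causal-mask value.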
+        is_zero = op.Equal(sum_fp32, 0.0)
+        result = op.Where(is_zero, min_val, causal_mask)
+        return result
+
+    def check(self, context, dtype1, dtype2, min_val, attn_mask_2d, sliding_window=None, **_):
+        # Check that attn_mask_2d is the model input "attention_mask"
+        if not _is_model_input(attn_mask_2d, "attention_mask", context.model):
+            return pattern.MatchResult().fail("Invalid attention_mask input", attn_mask_2d)
+
+        if dtype1.as_int() != dtype2.as_int():
+            return pattern.MatchResult().fail("Dtype mismatch", [dtype1, dtype2])
+
+        # Check that min_val is a constant and matches the expected minimum value for the dtype.
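+        # (For example, the expected minimum is -65504.0 for float16 and roughly
+        # -3.4028235e38 for float32.)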
+        min_value = _ir_utils.get_singleton_value(min_val)
+        if min_value is None:
+            return pattern.MatchResult().fail("Minval is not a constant.", min_val)
+        expected_min_value = np.finfo(min_val.dtype.numpy()).min
+        if min_value != expected_min_value:
+            return pattern.MatchResult().fail(
+                f"Expected min value {expected_min_value}, got {min_value}", min_val
+            )
+
+        # TODO(rama) Sliding window: not yet supported.
+        if sliding_window:
+            return pattern.MatchResult().fail(
+                "Sliding window not yet supported", sliding_window
+            )
+        return True

-    # Get rid of the extra +1 added above: total_seq_len is enough, no
-    # need for total_seq_len+1.
-    mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1])
-    return mask_B1ST
+
+_causal_mask_pattern = _CausalMaskPattern()


 class GroupQueryAttention(pattern.RewriteRuleClassBase):
@@ -78,8 +156,7 @@ def pattern(
         value_BSDkv,
         past_key,
         past_value,
-        position_ids_q,
-        position_ids_k,
+        position_ids,
         cos,
         sin,
         mask,
@@ -101,15 +178,15 @@ def pattern(

         query_BHSDh_rope = op.RotaryEmbedding(
             query_BHSDh,
-            position_ids_q,
+            position_ids,
             cos,
             sin,
             _domain="com.microsoft",
             _outputs=["query_BHSDh_rope"],
         )
         key_BHkvSDh_rope = op.RotaryEmbedding(
             key_BHkvSDh,
-            position_ids_k,
+            position_ids,
             cos,
             sin,
             _domain="com.microsoft",
@@ -154,7 +231,7 @@ def pattern(

     def check(
         self,
-        op,
+        context: _basics.MatchContext,
         query_BSD,
         key_BSDkv,
         value_BSDkv,
@@ -164,6 +241,7 @@ def check(
         key_BHkvSDh_rope,
         query_BSHDh,
         key_BSHkvDh,
+        mask,
         **_,
     ):
         bindings: dict[str, Dim] = {}
@@ -210,6 +288,20 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
             )
         self._interleaved = query_interleaved

+        # Check mask:
+        mask_node = mask.producer()
+        if mask_node is None:
+            return pattern.MatchResult().fail("Unhandled mask pattern", mask)
+        mask_match_result = _causal_mask_pattern.match(
+            context.model,
+            context.graph_or_function,
+            mask_node,
+            check_nodes_are_removable=False,
+        )
+        if mask_match_result is None:
+            return pattern.MatchResult().fail("Mask does not match causal mask pattern", mask)
+        # TODO: handle sliding window support in mask
+
         return True

     def rewrite(
@@ -220,104 +312,51 @@ def rewrite(
         value_BSDkv,
         past_key,
         past_value,
-        position_ids_q,
-        position_ids_k,
+        position_ids,
         cos,
         sin,
         mask,
         **_,
     ):
-        return op.GQA(
-            mask,
-            position_ids_k,
-            position_ids_q,
+        # Note that the following optimization is specific to current ORT GenAI attention-mask
+        # usage. Specifically, it assumes that the model-input "attention_mask" is a 2D
+        # mask with shape [batch_size, sequence_length], and that the mask is a 0/1 mask
+        # that is used only to indicate the current tokens. Hence, the input attention_mask
+        # is redundant as long as past-sequence-length and current-sequence-length can be
+        # computed.
+
+        # Construct seqlens_k and total_seq_length_int32 from position_ids:
+        # seqlens_k: int32[batch_size] indicates the total sequence length minus 1 for each batch entry
+        # position_ids: int64[batch_size, sequence_length] indicates the position of each token
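+        # For example (illustrative values only): position_ids = [[5, 6, 7]] for a batch of
+        # size 1 with 5 past tokens and 3 new tokens gives seqlens_k = [7] (total sequence
+        # length minus 1) and total_seq_length_int32 = 8.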
+        one_int32_0d = op.Constant(value=ir.tensor(1, dtype=ir.DataType.INT32))
+        one_int64_1d = op.Constant(value=ir.tensor([1], dtype=ir.DataType.INT64))
+        zero_int64_1d = op.Constant(value=ir.tensor([0], dtype=ir.DataType.INT64))
+        seqlens_k_int64 = op.ReduceMax(position_ids, one_int64_1d, keepdims=0)
+        seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32)
+        max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0)
+        total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d)
+        return op.GroupQueryAttention(
             query_BSD,
             key_BSDkv,
             value_BSDkv,
             past_key,
             past_value,
-            None,  # seqlens_k,
-            None,  # total_seq_length_int32,
+            seqlens_k,
+            total_seq_length_int32,
             cos,
             sin,
             num_heads=self.num_heads,
             kv_num_heads=self.kv_num_heads,
             do_rotary=1,
             rotary_interleaved=self._interleaved,
             # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
-            _domain="ai.onnxruntime._fusion",
+            _domain="com.microsoft",
             _outputs=3,
         )


-class GQACausalMask(pattern.RewriteRuleClassBase):
-    def __init__(self):
-        super().__init__("GQACausalMask", remove_nodes=False)
-
-    def pattern(
-        self,
-        op,
-        mask,
-        input_ids,
-        some_kv_cache,
-        shape_B111,
-        past_seq_length,
-        total_seq_length,
-    ):
-        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
-        position_ids = op.Range(past_seq_length, total_seq_length, 1)
-        position_ids_q = op.Unsqueeze(position_ids, [0])
-        position_ids_k = op.Unsqueeze(position_ids, [0])
-        return op.GQA(
-            mask,
-            position_ids_k,
-            position_ids_q,
-            _allow_other_inputs=True,
-            _domain="ai.onnxruntime._fusion",
-            _outputs=["attn_output", "key_seq", "value_seq"],
-        )
-
-    def rewrite(
-        self,
-        op,
-        total_seq_length,
-        attn_output,
-        **_,
-    ):
-        # Construct total_seq_length_int32 and seqlens_k
-        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
-        one_0D = op.Constant(value_int=1)
-        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
-        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
-        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
-        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
-
-        gqa_node = attn_output.producer()
-        assert len(gqa_node.inputs) == 12, (
-            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
-        )
-        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
-        cos, sin = gqa_node.inputs[10:12]
-        updated_inputs = [
-            query,
-            key,
-            value,
-            past_key,
-            past_value,
-            seqlens_k,
-            total_seq_length_int32,
-            cos,
-            sin,
-        ]
-        attributes = gqa_node.attributes
-        return op.GroupQueryAttention(
-            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
-        )
-
-
 _basic_gqa_rule = GroupQueryAttention.rule()
-_gqa_causal_mask_rule = GQACausalMask.rule()

-gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])
+gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

 fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)