@@ -31,9 +31,9 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
   const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform);
   const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape");
-  shader.AddInput("seqlen_k", ShaderUsage::None);
   // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
   if (prepare_indirect_dispatch_) {
+    shader.AddInput("seqlen_k", ShaderUsage::None);
     shader.AddOutput("indirect_buffer", ShaderUsage::None);
   }
 
@@ -42,8 +42,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                              << "  let head_size_id = output_indices[3];\n"
                                 "  let sequence_id = output_indices[2];\n"
                                 "  let num_head_id = output_indices[1];\n"
-                                "  let batch = output_indices[0];\n"
-                                "  let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
+                                "  let batch = output_indices[0];\n";
+  if (prepare_indirect_dispatch_) {
+    shader.MainFunctionBody() << "  let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
+  } else {
+    shader.MainFunctionBody() << "  let total_seq_length = uniforms.total_sequence_length;\n";
+  }
 
   // Add indirect dispatch logic for thread 0
   if (prepare_indirect_dispatch_) {
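The hunk above is truncated before the thread-0 body that actually fills `indirect_buffer`. As a rough sketch of what that preparation typically looks like (an assumption, not this PR's exact code: it mirrors the CPU-side `SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile)` used on the non-indirect path, and relies on the `tile_size` and `num_heads` uniforms added further down):

```cpp
// Hypothetical sketch: thread 0 derives the tile count from the runtime total
// sequence length and writes the three u32 workgroup counts (x, y, z) that an
// indirect dispatch consumes.
if (prepare_indirect_dispatch_) {
  shader.MainFunctionBody() << "  if (global_idx == 0u) {\n"
                               "    let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n"
                               "    indirect_buffer[0] = num_total_seq_length_tile * uniforms.num_heads;\n"
                               "    indirect_buffer[1] = 1u;\n"
                               "    indirect_buffer[2] = 1u;\n"
                               "  }\n";
}
```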
@@ -89,7 +93,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
 Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
                    const Tensor* K, const Tensor* past_key, Tensor* present_key,
                    const Tensor* V, const Tensor* past_value, Tensor* present_value,
-                   const Tensor* seqlen_k, Tensor* indirect_buffer) {
+                   uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) {
   // CopyKVCache takes past key/value and current key/value and copies them to present key and value.
   // This makes it so that FlashAttention only needs to look at present key and value, and saves
   // number of input buffers in the shader, which we run out of (<=8) without this optimization.
@@ -106,10 +110,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
 
   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
-  constexpr uint32_t tile_size = 64;
 
   CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, parameters.past_present_share_buffer_,
-                             prepare_indirect_dispatch, tile_size, static_cast<uint32_t>(parameters.num_heads_)};
+                             prepare_indirect_dispatch};
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
     program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                        {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -121,7 +124,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
                        {V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}});
   }
 
-  if (seqlen_k != nullptr) {
+  if (prepare_indirect_dispatch) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
   }
 
@@ -132,7 +135,6 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
                       {present_value, ProgramTensorMetadataDependency::Rank, components}});
 
-  // Add indirect_buffer output if preparing indirect dispatch
   if (prepare_indirect_dispatch) {
     program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
   }
@@ -142,6 +144,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
       .SetWorkgroupSize(64)
       .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch)
       .AddUniformVariables({{static_cast<uint32_t>(copy_size)},
+                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
                             {tile_size},
                             {static_cast<uint32_t>(parameters.num_heads_)}});
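Because the new `total_sequence_length` uniform is inserted between `copy_size` and `kv_sequence_length`, the program's uniform declarations need a matching entry in the same position. A minimal sketch of what the corresponding declaration in the header presumably looks like (the macro usage and the neighbouring names are assumptions based on the usual ORT WebGPU pattern, not taken from this PR):

```cpp
// Hypothetical sketch: the uniform declaration order must line up with the
// AddUniformVariables() order used above.
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"copy_size", ProgramUniformVariableDataType::Uint32},
                                        {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                        {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
                                        {"tile_size", ProgramUniformVariableDataType::Uint32},
                                        {"num_heads", ProgramUniformVariableDataType::Uint32});
```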
@@ -184,7 +187,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
 Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-  shader.AddInput("seqlens_k", ShaderUsage::None);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   if (has_attention_bias_) {
     shader.AddInput("attention_bias", ShaderUsage::UseUniform);
   }
@@ -197,7 +202,8 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
                              WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
 
 Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
@@ -209,10 +215,12 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
   const bool has_attention_bias = attention_bias != nullptr;
   const int components = 4;
 
-  FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size};
+  FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {present_key, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {seqlen_k, ProgramTensorMetadataDependency::None}});
+                     {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
+  }
   if (has_attention_bias) {
     program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
   }
@@ -226,8 +234,9 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
     program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
   }
   program.SetWorkgroupSize(64)
-      .CacheHint(tile_size, has_attention_bias)
+      .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
       .AddUniformVariables({{static_cast<uint32_t>(vectorized_head_size)},
+                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<float>(alpha)},
                             {static_cast<uint32_t>(present_sequence_length)},
                             {static_cast<uint32_t>(parameters.n_reps)},
@@ -241,7 +250,9 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
   shader.AddInput("metadata", ShaderUsage::UseUniform);
   shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-  shader.AddInput("seqlens_k", ShaderUsage::None);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
 
   const uint32_t tile_size_k_vec = 8u;
@@ -250,7 +261,8 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
                              WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
 
 Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context,
@@ -268,20 +280,21 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                                                uint32_t present_sequence_length) {
   const int components = 4;
   int head_size_vec = parameters.v_head_size_ / components;
-  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec};
+  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch};
   program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
-                     {present_value, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {seqlen_k, ProgramTensorMetadataDependency::None}});
+                     {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
   program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});  // [B, N, split_k, head_size]
   if (use_indirect_dispatch) {
-    program.SetIndirectDispatchTensor(indirect_buffer);
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
+        .SetIndirectDispatchTensor(indirect_buffer);
   } else {
     program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
   }
-  program.CacheHint(tile_size, head_size_vec)
+  program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
       .SetWorkgroupSize(64)
-      .AddUniformVariables({{static_cast<uint32_t>(head_size_vec)},
+      .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
+                            {static_cast<uint32_t>(head_size_vec)},
                             {static_cast<uint32_t>(present_sequence_length)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_present_sequence_length_tile,
@@ -292,31 +305,39 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
 
 Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("input", ShaderUsage::UseUniform);
-  shader.AddInput("seqlens_k", ShaderUsage::None);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
-                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_));
+                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }
 
 Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context,
                                            const Tensor* out_split_vx,
                                            Tensor* output,
                                            const Tensor* seqlen_k,
                                            const WebgpuAttentionParameters& parameters,
-                                           uint32_t num_present_sequence_length_tile) {
+                                           uint32_t num_total_seq_length_tile,
+                                           uint32_t num_present_sequence_length_tile,
+                                           bool use_indirect_dispatch) {
   const int components = 4;
   constexpr int tile_size = 8;
   int tile_head_size = tile_size * components;
-  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size};
-  program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {seqlen_k, ProgramTensorMetadataDependency::None}});
+  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, use_indirect_dispatch};
+  program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
+  }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
   program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
-      .CacheHint(tile_size)
+      .CacheHint(tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
+                            num_total_seq_length_tile,
                             num_present_sequence_length_tile,
                             {num_head_size_tile},
                             {static_cast<uint32_t>(parameters.num_heads_)}});
@@ -332,10 +353,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const int present_sequence_length = static_cast<int>(present_key->Shape()[2]);
 
   if (parameters.sequence_length_ > 1) {
-    // For encode path, use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, seqlen_k, nullptr));
-
     const uint32_t tile_size = 64;
+    // For encode path, use the original CopyKVCache without indirect dispatch preparation
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
     bool has_attention_bias = attention_bias != nullptr;
     bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
     bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -394,10 +414,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
     indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
     indirect_buffer_ptr = &indirect_buffer;
     // Use the fused CopyKVCache that also prepares the indirect dispatch buffer
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, seqlen_k, indirect_buffer_ptr));
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
   } else {
     // Use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, seqlen_k, nullptr));
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
   }
 
   // The metadata is used to store the max and sum of each tile.
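For a sense of the numbers involved: the decode kernels split `total_sequence_length` into tiles of `tile_size` rows, and the `indirect_buffer` created above is a u32 tensor holding the three workgroup counts an indirect dispatch consumes. A small self-contained sketch of that arithmetic (example values only; the exact buffer layout is an assumption kept consistent with the 1-D `SetDispatchGroupSize` used on the non-indirect path):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t tile_size = 64;               // same tile size passed to CopyKVCache above
  const uint32_t total_sequence_length = 300;  // example runtime value (past + new token)
  const uint32_t num_heads = 32;               // example value

  // Ceil division, mirroring SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile).
  const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;

  // Assumed contents of the 3 x u32 indirect buffer written on the GPU by CopyKVCache's thread 0.
  const uint32_t indirect[3] = {num_total_seq_length_tile * num_heads, 1u, 1u};
  std::printf("tiles=%u dispatch=(%u, %u, %u)\n",
              num_total_seq_length_tile, indirect[0], indirect[1], indirect[2]);
  return 0;
}
```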
@@ -420,7 +440,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                                        num_present_sequence_length_tile, tile_size,
                                                        use_indirect_dispatch, present_sequence_length));
   ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
-                                                          num_present_sequence_length_tile));
+                                                          num_total_seq_length_tile,
+                                                          num_present_sequence_length_tile, use_indirect_dispatch));
 
   return Status::OK();
 }