Commit 1197a17

keep the dispatch path unchanged
1 parent c8231b6 commit 1197a17

File tree

5 files changed (+87 additions, -46 deletions)

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 56 additions & 35 deletions
@@ -31,9 +31,9 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
   const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform);
   const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape");
-  shader.AddInput("seqlen_k", ShaderUsage::None);
   // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
   if (prepare_indirect_dispatch_) {
+    shader.AddInput("seqlen_k", ShaderUsage::None);
     shader.AddOutput("indirect_buffer", ShaderUsage::None);
   }

@@ -42,8 +42,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << " let head_size_id = output_indices[3];\n"
                                " let sequence_id = output_indices[2];\n"
                                " let num_head_id = output_indices[1];\n"
-                               " let batch = output_indices[0];\n"
-                               " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
+                               " let batch = output_indices[0];\n";
+  if (prepare_indirect_dispatch_) {
+    shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
+  } else {
+    shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n";
+  }

   // Add indirect dispatch logic for thread 0
   if (prepare_indirect_dispatch_) {
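
Note on the hunk above: it splits the source of total_seq_length. In the indirect-dispatch (decode) path the length is only known on the GPU, so the shader reads it from the seqlen_k buffer; otherwise the host supplies it as a uniform and the extra storage-buffer binding is never declared. A standalone sketch of that rule (not ORT code; it assumes, as the else-branch implies, that u32(seqlen_k[0u]) + 1u and uniforms.total_sequence_length denote the same quantity):

#include <string>

// Emit the WGSL statement that defines total_seq_length for each variant.
std::string EmitTotalSeqLength(bool prepare_indirect_dispatch) {
  return prepare_indirect_dispatch
             ? " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"          // GPU-side value
             : " let total_seq_length = uniforms.total_sequence_length;\n";  // host-side uniform
}

Keeping the seqlen_k binding out of the direct path also saves one of the scarce storage-buffer slots mentioned in the CopyKVCache comment below.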
@@ -89,7 +93,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
 Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
                    const Tensor* K, const Tensor* past_key, Tensor* present_key,
                    const Tensor* V, const Tensor* past_value, Tensor* present_value,
-                   const Tensor* seqlen_k, Tensor* indirect_buffer) {
+                   uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) {
   // CopyKVCache takes past key/value and current key/value and copies them to present key and value.
   // This makes it so that FlashAttention only needs to look at present key and value, and saves
   // number of input buffers in the shader, which we run out of (<=8) without this optimization.
@@ -106,10 +110,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt

   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
-  constexpr uint32_t tile_size = 64;

   CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, parameters.past_present_share_buffer_,
-                             prepare_indirect_dispatch, tile_size, static_cast<uint32_t>(parameters.num_heads_)};
+                             prepare_indirect_dispatch};
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
     program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                        {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
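
Note: tile_size is no longer a local constant here; the caller passes it in (see the new CopyKVCache signature above), so the encode and decode call sites and the indirect-dispatch preparation all agree on one tiling. A hypothetical sketch of the arithmetic that has to match on both sides, mirroring the round-up pattern used later in this file:

#include <cstdint>

// Round-up division, as in the num_head_size_tile computation further down.
constexpr uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Number of sequence tiles a dispatch covers; e.g. tile_size = 64 and
// total_seq_length = 130 give 3 tiles. If CopyKVCache wrote the indirect
// buffer assuming one tile_size and the decode shaders assumed another,
// the dispatch would be too small or too large.
uint32_t NumTotalSeqLengthTiles(uint32_t total_seq_length, uint32_t tile_size) {
  return CeilDiv(total_seq_length, tile_size);
}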
@@ -121,7 +124,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
                        {V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}});
   }

-  if (seqlen_k != nullptr) {
+  if (prepare_indirect_dispatch) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
   }

@@ -132,7 +135,6 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
                       {present_value, ProgramTensorMetadataDependency::Rank, components}});

-  // Add indirect_buffer output if preparing indirect dispatch
   if (prepare_indirect_dispatch) {
     program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
   }
@@ -142,6 +144,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
       .SetWorkgroupSize(64)
       .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch)
       .AddUniformVariables({{static_cast<uint32_t>(copy_size)},
+                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
                             {tile_size},
                             {static_cast<uint32_t>(parameters.num_heads_)}});
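
Note: AddUniformVariables is positional, so inserting total_sequence_length between copy_size and kv_sequence_length only works because the shader-side declaration is updated in the same order. Illustrative sketch of the uniform block the WGSL template would need to declare (field names assumed, not taken from this diff):

// Illustrative only: the WGSL uniform block must mirror the host-side order
// {copy_size, total_sequence_length, kv_sequence_length, tile_size, num_heads}.
constexpr const char* kAssumedUniformStruct = R"(
struct Uniforms {
  copy_size : u32,
  total_sequence_length : u32,
  kv_sequence_length : u32,
  tile_size : u32,
  num_heads : u32,
};
)";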
@@ -184,7 +187,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
 Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
-  shader.AddInput("seqlens_k", ShaderUsage::None);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   if (has_attention_bias_) {
     shader.AddInput("attention_bias", ShaderUsage::UseUniform);
   }
@@ -197,7 +202,8 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
                              WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }

 Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
@@ -209,10 +215,12 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
   const bool has_attention_bias = attention_bias != nullptr;
   const int components = 4;

-  FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size};
+  FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {present_key, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {seqlen_k, ProgramTensorMetadataDependency::None}});
+                     {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
+  }
   if (has_attention_bias) {
     program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
   }
@@ -226,8 +234,9 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
     program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
   }
   program.SetWorkgroupSize(64)
-      .CacheHint(tile_size, has_attention_bias)
+      .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
       .AddUniformVariables({{static_cast<uint32_t>(vectorized_head_size)},
+                            {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<float>(alpha)},
                             {static_cast<uint32_t>(present_sequence_length)},
                             {static_cast<uint32_t>(parameters.n_reps)},
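
Note: use_indirect_dispatch joins the cache hint here because the flag changes both the generated WGSL and the bind group layout (seqlens_k may or may not be bound), so two variants must not share one cached pipeline. A hypothetical illustration of key composition; the real CacheHint serialization is internal to ORT:

#include <string>

// Hypothetical cache-key composition: any flag that alters the generated
// shader or its bind group layout must be folded into the key, or a pipeline
// compiled for one variant could be reused for the other.
std::string MakeQKTCacheKey(uint32_t tile_size, bool has_attention_bias,
                            bool use_indirect_dispatch) {
  return "FlashAttentionDecodeQKT|" + std::to_string(tile_size) + '|' +
         std::to_string(has_attention_bias) + '|' +
         std::to_string(use_indirect_dispatch);
}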
@@ -241,7 +250,9 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
   shader.AddInput("metadata", ShaderUsage::UseUniform);
   shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
-  shader.AddInput("seqlens_k", ShaderUsage::None);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);

   const uint32_t tile_size_k_vec = 8u;
@@ -250,7 +261,8 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
                              WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
-                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec));
+                             WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }

 Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context,
@@ -268,20 +280,21 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                                                uint32_t present_sequence_length) {
   const int components = 4;
   int head_size_vec = parameters.v_head_size_ / components;
-  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec};
+  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch};
   program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
-                     {present_value, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {seqlen_k, ProgramTensorMetadataDependency::None}});
+                     {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
   program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});  // [B, N, split_k, head_size]
   if (use_indirect_dispatch) {
-    program.SetIndirectDispatchTensor(indirect_buffer);
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None})
+        .SetIndirectDispatchTensor(indirect_buffer);
   } else {
     program.SetDispatchGroupSize(parameters.num_heads_ * num_total_seq_length_tile);
   }
-  program.CacheHint(tile_size, head_size_vec)
+  program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
       .SetWorkgroupSize(64)
-      .AddUniformVariables({{static_cast<uint32_t>(head_size_vec)},
+      .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
+                            {static_cast<uint32_t>(head_size_vec)},
                             {static_cast<uint32_t>(present_sequence_length)},
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_present_sequence_length_tile,
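
For context, SplitVx and the reduce that follows implement a split-k scheme: each sequence tile yields a partial V-weighted accumulation, and the per-tile max/sum metadata lets the partials be merged into an exact softmax. A self-contained sketch of the standard streaming combine, assuming the metadata carries each tile's max m and sum s (the actual WGSL template may organize the normalization differently):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One partial result per sequence tile.
struct TilePartial {
  float m;                 // max of the tile's attention logits
  float s;                 // sum of exp(logit - m) over the tile
  std::vector<float> acc;  // sum of exp(logit - m) * V over the tile
};

// Merge partials with rescaling so the result equals softmax(QK^T) * V over
// the full sequence. Assumes at least one tile.
std::vector<float> CombinePartials(const std::vector<TilePartial>& tiles) {
  float m = -INFINITY;
  float s = 0.0f;
  std::vector<float> acc(tiles.front().acc.size(), 0.0f);
  for (const auto& t : tiles) {
    const float m_new = std::max(m, t.m);
    const float a = std::exp(m - m_new);    // rescale running accumulator
    const float b = std::exp(t.m - m_new);  // rescale incoming tile
    for (size_t i = 0; i < acc.size(); ++i) acc[i] = acc[i] * a + t.acc[i] * b;
    s = s * a + t.s * b;
    m = m_new;
  }
  for (float& v : acc) v /= s;  // final softmax normalization
  return acc;
}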
@@ -292,31 +305,39 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte

 Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("input", ShaderUsage::UseUniform);
-  shader.AddInput("seqlens_k", ShaderUsage::None);
+  if (use_indirect_dispatch_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
-                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_));
+                             WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
+                             WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
 }

 Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context,
                                            const Tensor* out_split_vx,
                                            Tensor* output,
                                            const Tensor* seqlen_k,
                                            const WebgpuAttentionParameters& parameters,
-                                           uint32_t num_present_sequence_length_tile) {
+                                           uint32_t num_total_seq_length_tile,
+                                           uint32_t num_present_sequence_length_tile,
+                                           bool use_indirect_dispatch) {
   const int components = 4;
   constexpr int tile_size = 8;
   int tile_head_size = tile_size * components;
-  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size};
-  program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
-                     {seqlen_k, ProgramTensorMetadataDependency::None}});
+  FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, use_indirect_dispatch};
+  program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});
+  if (use_indirect_dispatch) {
+    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
+  }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
   const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
   program.SetDispatchGroupSize(parameters.num_heads_ * num_head_size_tile)
-      .CacheHint(tile_size)
+      .CacheHint(tile_size, use_indirect_dispatch)
       .SetWorkgroupSize(tile_size * tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
+                            num_total_seq_length_tile,
                             num_present_sequence_length_tile,
                             {num_head_size_tile},
                             {static_cast<uint32_t>(parameters.num_heads_)}});
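
Note: the reduce dispatches over heads × head-size tiles, not sequence tiles, so num_total_seq_length_tile now arrives as a uniform (or, in the indirect path, is derived from seqlens_k on the GPU). A worked sketch of the sizing, using the constants above and example values for the rest:

#include <cstdint>

// With tile_size = 8 and components = 4 as above, tile_head_size = 32.
// For example, v_head_size = 128 gives num_head_size_tile = 4, so with
// num_heads = 32 the reduce dispatches 128 workgroups of 64 threads
// (tile_size * tile_size).
uint32_t VxReduceDispatchSize(uint32_t num_heads, uint32_t v_head_size,
                              uint32_t tile_head_size = 32) {
  const uint32_t num_head_size_tile =
      (v_head_size + tile_head_size - 1) / tile_head_size;
  return num_heads * num_head_size_tile;
}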
@@ -332,10 +353,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   const int present_sequence_length = static_cast<int>(present_key->Shape()[2]);

   if (parameters.sequence_length_ > 1) {
-    // For encode path, use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, seqlen_k, nullptr));
-
     const uint32_t tile_size = 64;
+    // For encode path, use the original CopyKVCache without indirect dispatch preparation
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
     bool has_attention_bias = attention_bias != nullptr;
     bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
     bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -394,10 +414,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
       indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
       indirect_buffer_ptr = &indirect_buffer;
       // Use the fused CopyKVCache that also prepares the indirect dispatch buffer
-      ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, seqlen_k, indirect_buffer_ptr));
+      ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
     } else {
       // Use the original CopyKVCache without indirect dispatch preparation
-      ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, seqlen_k, nullptr));
+      ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
     }

     // The metadata is used to store the max and sum of each tile.
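
Background: a WebGPU indirect dispatch buffer is three consecutive u32 workgroup counts (x, y, z), which is why indirect_buffer is created as a uint32 GPU tensor. A sketch of what thread 0 of the fused CopyKVCache plausibly writes; the exact packing is an assumption, since the shader body is not shown in this hunk:

#include <cstdint>

// Assumed layout of the 3-element buffer consumed by
// dispatchWorkgroupsIndirect; the decode programs here use a 1-D dispatch
// (num_heads * num_total_seq_length_tile in the direct path), so y and z
// would stay 1.
void FillIndirectBuffer(uint32_t* buf, uint32_t total_seq_length,
                        uint32_t tile_size, uint32_t num_heads) {
  const uint32_t num_tiles = (total_seq_length + tile_size - 1) / tile_size;
  buf[0] = num_heads * num_tiles;  // workgroup count x
  buf[1] = 1;                      // workgroup count y
  buf[2] = 1;                      // workgroup count z
}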
@@ -420,7 +440,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                                           num_present_sequence_length_tile, tile_size,
                                                           use_indirect_dispatch, present_sequence_length));
   ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
-                                                          num_present_sequence_length_tile));
+                                                          num_total_seq_length_tile,
+                                                          num_present_sequence_length_tile, use_indirect_dispatch));

   return Status::OK();
 }
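
Recap of the decode path after this change (function names from this diff, argument lists abbreviated):

// 1. CopyKVCache(..., tile_size, seqlen_k, indirect_buffer_ptr)
//      copies K/V into the present cache and, when fused, fills the
//      indirect dispatch buffer on the GPU.
// 2. ComputeFlashAttentionDecodeQKT(...)
//      Q*K^T scores per sequence tile.
// 3. ComputeFlashAttentionDecodeSplitVxScore(...)
//      partial V-weighted sums per tile, plus max/sum metadata.
// 4. ComputeFlashAttentionDecodeVxReduce(..., num_total_seq_length_tile,
//        num_present_sequence_length_tile, use_indirect_dispatch)
//      combines the partials into the final output.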
