Skip to content

Commit c8231b6

Browse files
committed
fix
1 parent bc4b41e commit c8231b6

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
3131
const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
3232
const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform);
3333
const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape");
34-
shader.AddInput("seqlen_k");
34+
shader.AddInput("seqlen_k", ShaderUsage::None);
3535
// seqlen_k is always added as an input above; if prepare_indirect_dispatch is enabled, also add the indirect_buffer output
3636
if (prepare_indirect_dispatch_) {
37-
shader.AddOutput("indirect_buffer", ShaderUsage::UseUniform);
37+
shader.AddOutput("indirect_buffer", ShaderUsage::None);
3838
}
3939

4040
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.copy_size")
@@ -184,7 +184,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
184184
Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
185185
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
186186
shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
187-
shader.AddInput("seqlens_k");
187+
shader.AddInput("seqlens_k", ShaderUsage::None);
188188
if (has_attention_bias_) {
189189
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
190190
}
@@ -241,7 +241,7 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
241241
shader.AddInput("metadata", ShaderUsage::UseUniform);
242242
shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
243243
shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
244-
shader.AddInput("seqlens_k");
244+
shader.AddInput("seqlens_k", ShaderUsage::None);
245245
shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
246246

247247
const uint32_t tile_size_k_vec = 8u;
@@ -292,7 +292,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
292292

293293
Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
294294
shader.AddInput("input", ShaderUsage::UseUniform);
295-
shader.AddInput("seqlens_k");
295+
shader.AddInput("seqlens_k", ShaderUsage::None);
296296
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
297297

298298
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",

onnxruntime/core/session/inference_session.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ static bool HasControlflowNodes(const Graph& graph) {
145145

146146
static bool HasMemcpyNodes(const Graph& graph) {
147147
for (const auto& node : graph.Nodes()) {
148-
if (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost") {
148+
if (node.OpType() == "MemcpyFromHost") {
149149
return true;
150150
}
151151
}

0 commit comments

Comments
 (0)