Skip to content

Conversation

@qjia7
Copy link
Contributor

@qjia7 qjia7 commented Oct 22, 2025

This pull request enables conditionally register GQA with total_sequence_length on gpu or not. It resolves the issue that a MemcpyToHost is generated when graph capture is enabled (refer to #25868). This is the last functionality part to support graph capture in webgpu ep in ORT.

The main changes ensure that when graph capture is enabled, sequence length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models.

In this PR, we still get total sequence length from seqlen_k tensor not total_seqlen_tensor tensor to keep consistent with other parts. In the next PR, we can refactor all places to directly use total_seqlen_tensor instead of seqlen_k when graph capture enabled.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Oct 22, 2025
@guschmue guschmue requested a review from Copilot October 28, 2025 15:44
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR enables conditional registration of the GroupQueryAttention (GQA) operator based on whether graph capture is enabled in the WebGPU execution provider. When graph capture is enabled, the operator reads total sequence length from GPU buffers instead of CPU memory, eliminating the need for a MemcpyToHost operation that was blocking graph capture support.

Key changes:

  • Modified GQA kernel registration to conditionally set InputMemoryType based on graph capture status
  • Updated flash attention shader templates and programs to support reading sequence length from GPU buffers
  • Added validation logic to handle total_seqlen tensor when it resides on GPU during graph capture

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Passes enable_graph_capture flag to RegisterWebGpuContribKernels
onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h Adds enable_graph_capture parameter to RegisterWebGpuContribKernels signature
onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc Replaces static GQA registration with conditional registration via CreateGroupQueryAttentionKernelInfo
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h Declares CreateGroupQueryAttentionKernelInfo function for conditional kernel creation
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc Implements conditional kernel registration and updates ApplyFlashAttention signature to accept seqlen_k
onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template Adds get_total_sequence_length() function that reads from either GPU buffer or uniforms based on use_seqlen_k flag
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Adds use_seqlen_k member to CopyKVCacheProgram and FlashAttentionProgram classes
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Implements use_seqlen_k logic in shader code generation and removes past_sequence_length uniform
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h Updates validation logic to skip CPU-specific checks when total_seqlen is on GPU

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@qjia7 qjia7 merged commit f7fd3b5 into main Oct 29, 2025
94 of 96 checks passed
@qjia7 qjia7 deleted the dynamic_register_gqa branch October 29, 2025 02:28
naomiOvad pushed a commit to naomiOvad/onnxruntime that referenced this pull request Nov 2, 2025
This pull request enables conditionally register GQA with
total_sequence_length on gpu or not. It resolves the issue that a
MemcpyToHost is generated when graph capture is enabled (refer to
microsoft#25868). This is the last functionality part to support graph capture in
webgpu ep in ORT.

The main changes ensure that when graph capture is enabled, sequence
length information is read from GPU buffers instead of CPU memory, and
shader code generation adapts accordingly. This enables more efficient
execution and compatibility with graph-captured models.

In this PR, we still get total sequence length from `seqlen_k` tensor
not `total_seqlen_tensor` tensor to keep consistent with other parts. In
the next PR, we can refactor all places to directly use
`total_seqlen_tensor` instead of `seqlen_k` when graph capture enabled.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants