[Don't review][webgpu] Make graph capture work on LLM #25868
Conversation
Commits:
- Add seqlen_k to dynamically compute total_seq_length
- Add indirect buffer usage
- Fuse PrepareIndirectDispatch shader into CopyKVCache
- Code reuse
- Update the conditions
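To illustrate the fused PrepareIndirectDispatch step, here is a minimal WGSL sketch, embedded as a C++ string. The binding layout, tile size, and the `seqlen_k` convention (`total_sequence_length - 1`) are assumptions for illustration, not the actual ORT shader; in the PR this logic is fused into the CopyKVCache kernel rather than being its own dispatch.

```cpp
// Hypothetical sketch only: produce the indirect dispatch arguments on the GPU
// so the CPU never needs to read seqlen_k back.
const char* kWriteIndirectArgsWGSL = R"WGSL(
@group(0) @binding(0) var<storage, read> seqlen_k : array<i32>;
// Consumed by dispatchWorkgroupsIndirect: [x, y, z] workgroup counts.
@group(0) @binding(1) var<storage, read_write> indirect_args : array<u32, 3>;

const TILE_SIZE : u32 = 64u;

@compute @workgroup_size(1)
fn write_indirect_args() {
  let total_seq_len = u32(seqlen_k[0]) + 1u;  // convention assumed: seqlen_k = total - 1
  indirect_args[0] = (total_seq_len + TILE_SIZE - 1u) / TILE_SIZE;  // tiles over sequence
  indirect_args[1] = 1u;
  indirect_args[2] = 1u;
}
)WGSL";
```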
Force-pushed from d3e1ae0 to 1197a17.
Review comments on onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc (outdated, resolved).
This reverts commit bc4b41e.
### Description
This PR unifies `present_sequence_length` in FlashAttention and removes the dependency on `total_sequence_length`. This is preparation to support graph capture (#25868).
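A minimal C++ sketch of the idea, with hypothetical names: size the dispatch by the statically known present-KV capacity taken from the tensor shape, and let the shader mask out positions beyond the runtime total, so no GPU value has to be read back to the CPU.

```cpp
#include <cstdint>

// Hypothetical struct; in ORT these values come from the present_key tensor
// shape, which is known at program creation time without a GPU readback.
struct AttentionShapes {
  int64_t present_sequence_length;  // capacity of the present KV cache
};

uint32_t WorkgroupsForSequence(const AttentionShapes& s, uint32_t tile) {
  // Upper-bound the dispatch by present_sequence_length; the shader masks out
  // positions beyond the GPU-resident runtime total sequence length.
  return static_cast<uint32_t>((s.present_sequence_length + tile - 1) / tile);
}
```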
### Description
This PR adds the dispatchWorkgroupsIndirect capability for the program. It's part of the work to enable graph capture in phi4 (#25868).

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
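At the WebGPU API level this maps to `DispatchWorkgroupsIndirect`; below is a minimal sketch using Dawn's C++ bindings. The ORT program abstraction wraps something equivalent; the indirect buffer must be created with `wgpu::BufferUsage::Indirect`.

```cpp
#include <webgpu/webgpu_cpp.h>

// Sketch only: records a compute pass whose workgroup counts are read from
// indirect_args (3 x u32: x, y, z) on the GPU at execution time, so the
// recorded commands stay valid when the sequence length changes.
void DispatchIndirect(wgpu::CommandEncoder& encoder, wgpu::ComputePipeline pipeline,
                      wgpu::BindGroup bind_group, wgpu::Buffer indirect_args) {
  wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
  pass.SetPipeline(pipeline);
  pass.SetBindGroup(0, bind_group);
  pass.DispatchWorkgroupsIndirect(indirect_args, /*indirectOffset=*/0);
  pass.End();
}
```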
This pull request introduces support for indirect dispatch in the WebGPU FlashAttention implementation, enabling more dynamic and efficient kernel launches based on runtime sequence lengths. The changes add new logic and parameters to propagate sequence-length information and indirect dispatch buffers through the attention pipeline, with conditional code paths to maintain compatibility with the existing direct dispatch approach. It's part of the work to enable graph capture in phi4 (#25868).
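The conditional path might look like the following self-contained sketch (hypothetical names, not the actual ORT code): direct dispatch when the total sequence length is known on the CPU, indirect dispatch otherwise.

```cpp
#include <cstdint>
#include <optional>

constexpr uint32_t kTileSize = 64;

// Illustrative only: chooses between the legacy direct dispatch (CPU-known
// total_sequence_length) and the indirect path used under graph capture.
struct DispatchPlan {
  bool indirect;                 // true: read counts from the indirect buffer
  uint32_t x = 1, y = 1, z = 1;  // used only when indirect == false
};

DispatchPlan PlanFlashAttentionDispatch(bool graph_capture_enabled,
                                        std::optional<uint32_t> cpu_total_seq_len,
                                        uint32_t num_heads) {
  if (graph_capture_enabled || !cpu_total_seq_len.has_value()) {
    return {/*indirect=*/true};
  }
  const uint32_t tiles = (*cpu_total_seq_len + kTileSize - 1) / kTileSize;
  return {/*indirect=*/false, tiles, num_heads, 1};
}
```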
This pull request extends the WebGPU execution provider to support int64 data type casting in the `Cast` operator, with conditional support based on whether graph capture is enabled. It refactors kernel registration to allow toggling int64 support and updates the shader code and kernel logic to handle int64 tensors efficiently. It's part of the work to enable graph capture in phi4 (#25868).
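Since WGSL has no 64-bit integer type, an int64 tensor has to be represented as pairs of 32-bit words on the GPU. The following is a sketch of the idea, not ORT's actual shader; the (lo, hi) layout and the assumption that values fit in 32 bits (typical for index-like data) are illustrative.

```cpp
// Hypothetical sketch: an int64 input bound as (low, high) u32 word pairs and
// narrowed to i32, assuming each value fits in 32 bits.
const char* kCastInt64ToInt32WGSL = R"WGSL(
@group(0) @binding(0) var<storage, read> input64 : array<vec2<u32>>;  // (lo, hi)
@group(0) @binding(1) var<storage, read_write> output32 : array<i32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  if (gid.x < arrayLength(&output32)) {
    output32[gid.x] = bitcast<i32>(input64[gid.x].x);  // keep the low 32 bits
  }
}
)WGSL";
```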
This pull request enables conditionally registering GQA with `total_sequence_length` on GPU or on CPU. It resolves the issue that a MemcpyToHost node is generated when graph capture is enabled (see #25868); this is the last functional piece needed to support graph capture in the WebGPU EP in ORT. The main changes ensure that when graph capture is enabled, sequence-length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models. In this PR we still read the total sequence length from the `seqlen_k` tensor rather than the `total_seqlen_tensor` tensor, to stay consistent with other parts of the code; a follow-up PR can refactor all call sites to use `total_seqlen_tensor` directly when graph capture is enabled.
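A sketch of what the shader-codegen branch could look like (hypothetical helper name; the `seqlen_k = total_sequence_length - 1` convention follows GQA):

```cpp
#include <string>

// Illustrative only: pick where the shader reads the total sequence length from.
std::string TotalSeqLenExpr(bool graph_capture_enabled) {
  if (graph_capture_enabled) {
    // GPU-resident path: seqlen_k holds total_sequence_length - 1 per batch.
    return "u32(seqlen_k[0]) + 1u";
  }
  // CPU path: the value was uploaded as a uniform before dispatch.
  return "uniforms.total_sequence_length";
}
```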
Closing this one since all the functionality to support graph capture has been merged separately.
Description
This PR includes all necessary changes to enable graph capture in LLM.
It mainly introduces support for indirect dispatch in the WebGPU FlashAttention implementation, enabling kernel launches sized by the runtime total sequence length, along with the other changes needed so that the whole model can run on the GPU.
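For context on why the whole model must stay on the GPU: replaying captured GPU commands cannot interleave CPU work, so capture is only valid when no node falls back to CPU (which is what the MemcpyToHost fix above addresses). A toy check in that spirit, not ORT's actual implementation:

```cpp
#include <string>
#include <vector>

// Sketch of the kind of validation graph capture implies: every node must be
// assigned to the GPU EP, otherwise replaying the captured command buffer
// would skip required CPU work such as a MemcpyToHost.
bool GraphCaptureAllowed(const std::vector<std::string>& node_ep_assignments,
                         const std::string& gpu_ep_name) {
  for (const auto& ep : node_ep_assignments) {
    if (ep != gpu_ep_name) return false;
  }
  return true;
}
```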
This pull request is intended to facilitate discussion and provide a comprehensive overview of the overall changes. Subsequently, it will be divided into smaller pull requests to make the review process more manageable.