
Conversation

@qjia7 (Contributor) commented Aug 27, 2025

Description

This PR includes all the changes necessary to enable graph capture for LLMs.
It mainly introduces support for indirect dispatch in the WebGPU FlashAttention implementation, enabling kernel launches based on runtime total sequence lengths, along with the other changes needed so that the whole model can run on the GPU.

This pull request is intended to facilitate discussion and give a comprehensive overview of the overall change. It will subsequently be split into smaller pull requests to make the review process more manageable.
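For context, here is a minimal sketch of what indirect dispatch changes (the tile size and helper names are illustrative, not the actual kernel parameters): with direct dispatch the host computes the workgroup count from the total sequence length, which therefore has to be readable on the CPU; with indirect dispatch the three workgroup counts live in a GPU buffer in the layout that WebGPU's `dispatchWorkgroupsIndirect` expects, so a shader can produce them instead.

```cpp
#include <array>
#include <cstdint>

// Illustrative tile size; the real value belongs to the FlashAttention shaders.
constexpr uint32_t kTileSize = 64;

// dispatchWorkgroupsIndirect reads three u32 values from a GPU buffer:
// {workgroupCountX, workgroupCountY, workgroupCountZ}.
using IndirectDispatchArgs = std::array<uint32_t, 3>;

// Direct dispatch: the CPU computes the counts, so total_seq_len must be known on
// the CPU (a GPU->CPU readback if it was produced on the GPU, which breaks capture).
IndirectDispatchArgs ComputeDispatchArgs(uint32_t total_seq_len, uint32_t num_heads) {
  const uint32_t seq_tiles = (total_seq_len + kTileSize - 1) / kTileSize;  // ceil-div
  return {seq_tiles, num_heads, 1};
}
// With indirect dispatch, the same ceil-div runs in a small GPU pass that writes the
// three counts into the indirect buffer, so the CPU never touches total_seq_len.
```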


@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

Add seqlen_k to dynamically compute total_seq_length

Add Indirect buffer usage

fuse PrepareIndirectDispatch shader into CopyKVCache

code reuse

Update the conditions
@qjia7 qjia7 force-pushed the indirect_dispatch branch from d3e1ae0 to 1197a17 on September 3, 2025 06:45

@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

@qjia7 qjia7 changed the title from "[WIP][webgpu] Indirect dispatch support" to "[Don't review][webgpu] Make graph capture work on LLM" on Sep 3, 2025
guschmue pushed a commit that referenced this pull request Sep 15, 2025
### Description
This PR unifies the present_sequence_length in flash attention and
removes the dependency on total_sequence_length. This is preparation to
support graph capture. #25868
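A rough sketch of why removing that dependency matters (my reading, not spelled out in the commit): a captured graph replays a fixed command sequence, so buffer shapes and loop bounds should not be derived from the per-step total_sequence_length on the CPU; sizing work against the fixed present (KV-cache) length keeps them stable across decode steps. The names below are hypothetical.

```cpp
#include <cstdint>

// Hypothetical shape info, for illustration only.
struct AttentionShapeInfo {
  uint32_t present_sequence_length;  // fixed KV-cache capacity, known at model setup
  uint32_t total_sequence_length;    // grows every decode step, only known at runtime
};

// Sizing scratch/intermediate buffers by the fixed capacity keeps shapes identical
// across steps, which is what graph capture needs; sizing by total_sequence_length
// would change every step and invalidate the captured graph.
uint32_t ScratchElements(const AttentionShapeInfo& info, uint32_t head_size) {
  return info.present_sequence_length * head_size;
}
```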
fs-eire pushed a commit that referenced this pull request Sep 18, 2025
### Description
This PR adds the dispatchWorkgroupsIndirect capability for the program.
It's part of the work to enable graph capture in phi4 #25868

qjia7 added a commit that referenced this pull request Oct 14, 2025
This pull request introduces support for indirect dispatch in the WebGPU
FlashAttention implementation, enabling more dynamic and efficient
kernel launches based on runtime sequence lengths. The changes add new
logic and parameters to propagate sequence length information and
indirect dispatch buffers through the attention pipeline, with
conditional code paths to maintain compatibility with the existing
direct dispatch approach.

It's part of the work to enable graph capture in phi4
#25868
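A sketch of the conditional path described above, with made-up names (the actual ORT WebGPU program interface is not reproduced here): when an indirect-args buffer is provided the pipeline records an indirect dispatch, otherwise it falls back to the existing direct dispatch computed on the CPU.

```cpp
#include <cstdint>
#include <optional>

// Stand-in for a real GPU buffer handle holding the 3 x u32 indirect args.
struct GpuBuffer {};

struct DispatchPlan {
  std::optional<GpuBuffer> indirect_args;  // set when counts are produced on the GPU
  uint32_t x = 1, y = 1, z = 1;            // used by the existing direct path
};

// Conceptually the attention pipeline keeps both paths: indirect dispatch when the
// sequence length only exists on the GPU (graph capture), direct dispatch otherwise.
void RecordDispatch(const DispatchPlan& plan) {
  if (plan.indirect_args.has_value()) {
    // pass.DispatchWorkgroupsIndirect(*plan.indirect_args, /*offset=*/0);
  } else {
    // pass.DispatchWorkgroups(plan.x, plan.y, plan.z);
  }
}
```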
fs-eire pushed a commit that referenced this pull request Oct 15, 2025
This pull request extends the WebGPU execution provider to support int64
data type casting in the `Cast` operator, with conditional support based
on whether graph capture is enabled. It refactors kernel registration to
allow toggling int64 support and updates the shader code and kernel
logic to handle int64 tensors efficiently.

It's part of the work to enable graph capture in phi4
#25868
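A hedged sketch of the shader-side idea (my reading of the commit, not its exact code): WGSL has no 64-bit integer type, so an int64 tensor is usually carried as pairs of 32-bit words, and a Cast to or from int32 only needs the low word plus a sign extension. The helper below shows that packing in plain C++.

```cpp
#include <cstdint>
#include <utility>

// An int64 element as a WGSL shader would see it in a storage buffer: {low word, high word}.
using Int64AsU32Pair = std::pair<uint32_t, uint32_t>;

// int64 -> int32: truncation keeps the low 32 bits.
int32_t CastToInt32(const Int64AsU32Pair& v) {
  return static_cast<int32_t>(v.first);
}

// int32 -> int64: sign-extend into the high word.
Int64AsU32Pair CastFromInt32(int32_t v) {
  const uint32_t low = static_cast<uint32_t>(v);
  const uint32_t high = (v < 0) ? 0xFFFFFFFFu : 0u;
  return {low, high};
}
```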
qjia7 added a commit that referenced this pull request Oct 29, 2025
This pull request makes it possible to conditionally register GQA with
total_sequence_length on the GPU or on the CPU. It resolves the issue that a
MemcpyToHost node is generated when graph capture is enabled (refer to
#25868). This is the last functional piece needed to support graph capture in
the WebGPU EP in ORT.

The main changes ensure that when graph capture is enabled, sequence
length information is read from GPU buffers instead of CPU memory, and
shader code generation adapts accordingly. This enables more efficient
execution and compatibility with graph-captured models.

In this PR, we still get the total sequence length from the `seqlen_k` tensor
rather than the `total_seqlen_tensor` tensor, to stay consistent with other parts.
In the next PR, we can refactor all places to use `total_seqlen_tensor` directly
instead of `seqlen_k` when graph capture is enabled.
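To make the MemcpyToHost point concrete, a rough sketch with hypothetical names (the real change goes through ORT's kernel registration, which is not reproduced here): if an operator declares total_sequence_length as a CPU input while the value is produced on the GPU, ORT has to insert a MemcpyToHost node, which breaks graph capture; registering it as a GPU input and reading it inside the shader avoids that copy.

```cpp
// Hypothetical model of the registration choice.
enum class InputLocation { kCpu, kGpu };

// With graph capture enabled, the total-sequence-length input stays on the GPU so no
// MemcpyToHost node is needed; otherwise the existing CPU-side path is kept.
InputLocation TotalSeqLenLocation(bool graph_capture_enabled) {
  return graph_capture_enabled ? InputLocation::kGpu : InputLocation::kCpu;
}
```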
@qjia7 (Contributor, Author) commented Oct 30, 2025

Close this one since all the functionality to support graph capture has been merged separately.

@qjia7 qjia7 closed this Oct 30, 2025
@qjia7 qjia7 deleted the indirect_dispatch branch October 30, 2025 06:29