[Bugfix] Correct KV cache tensor dimension handling in FlashInfer backend's block operations #15603
For now, `swap_blocks` and `copy_blocks` in the FlashInfer backend directly reuse the PagedAttention implementations, as follows:
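A minimal sketch of that pre-fix delegation (the class name, import path, and exact signatures are reconstructions and may not match the tree exactly): both methods simply forward to the `PagedAttention` helpers.

```python
from typing import List

import torch

# Assumed import path for the PagedAttention helpers; sketch only.
from vllm.attention.ops.paged_attn import PagedAttention


class FlashInferBackend:
    """Sketch of the pre-fix delegation, not the exact source."""

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # Forwards to PagedAttention, which assumes a
        # (2, num_blocks, block_size * num_kv_heads * head_size) layout.
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        # Same delegation; FlashInfer's own layout is never accounted for.
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
```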
However, there is a critical shape mismatch between the two implementations. FlashInfer assumes the KV cache block shape is `(num_blocks, 2, block_size, num_kv_heads, head_size)`, whereas PagedAttention uses `(2, num_blocks, block_size * num_kv_heads * head_size)`. This discrepancy causes output errors when `swap_blocks` and `copy_blocks` are used to move cache blocks.

For `swap_blocks`, we can simply reuse the existing `ops.swap_blocks`, since it only considers the index along the first `num_blocks` dimension:
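A sketch of that path, assuming `ops` refers to vLLM's custom cache ops (the import alias below is an assumption, and the function is shown standalone rather than as a backend static method):

```python
import torch

# Assumed import alias for the custom cache ops used by the backends.
from vllm import _custom_ops as ops


def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dst: torch.Tensor,
) -> None:
    # src_to_dst holds (src_block, dst_block) index pairs. The op copies
    # whole blocks addressed by the leading num_blocks dimension, so
    # FlashInfer's (num_blocks, 2, block_size, num_kv_heads, head_size)
    # layout can be passed through unchanged.
    ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
```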
For `copy_blocks`, we need to split KV along the second dimension:
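A sketch of the split, assuming indices 0 and 1 along dim 1 are the key and value planes under FlashInfer's layout and that `ops.copy_blocks` takes separate key- and value-cache lists:

```python
from typing import List

import torch

# Assumed import alias for the custom cache ops used by the backends.
from vllm import _custom_ops as ops


def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
) -> None:
    # With the (num_blocks, 2, block_size, num_kv_heads, head_size) layout,
    # slicing along dim 1 yields the per-layer key and value caches.
    key_caches = [kv_cache[:, 0] for kv_cache in kv_caches]
    value_caches = [kv_cache[:, 1] for kv_cache in kv_caches]
    ops.copy_blocks(key_caches, value_caches, src_to_dists)
```

Note that `kv_cache[:, 0]` and `kv_cache[:, 1]` are views rather than copies, so the op writes directly into the original cache tensors.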
After fixing this issue, features such as CPU offloading (e.g., the unmerged #13377) will work properly with the FlashInfer backend, and it also ensures the correctness of any future functionality that relies on these interfaces.