6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -163,6 +163,7 @@
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
@@ -1145,6 +1146,10 @@ def get_vllm_port() -> Optional[int]:
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),

# If set to 1, when we use fp8 kv, we do not quantize Q to fp8
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),

# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN":
@@ -1299,6 +1304,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
"VLLM_ROCM_USE_AITER",
"VLLM_ROCM_USE_AITER_PAGED_ATTN",
"VLLM_ROCM_USE_AITER_LINEAR",
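For reference, a minimal sketch of how the new flag would be consumed, assuming the usual lazy attribute lookup in vllm.envs; the snippet is illustrative and not part of this diff:

import os

# Opt out of Q quantization before any vLLM module reads the environment.
os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"

import vllm.envs as envs

# The lambda registered above parses the value with bool(int(...)),
# so the default "0" yields False and "1" yields True.
assert envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is True
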
11 changes: 6 additions & 5 deletions vllm/utils/flashinfer.py
@@ -200,11 +200,6 @@ def use_trtllm_attention(
logger.info_once("Using TRTLLM attention (query is quantized).")
return True

# TRTLLM prefill attention does not support FP8 kv cache with
# non-quantized query
if is_prefill and kv_cache_dtype.startswith("fp8"):
return False

# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if has_sinks:
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
return output


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
Comment on lines +351 to +354
Collaborator
Ideally, we could get rid of this after #26146



__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
114 changes: 105 additions & 9 deletions vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,8 @@
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention,
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block
@@ -48,8 +49,86 @@
logger = init_logger(__name__)


class FlashInferBackend(AttentionBackend):
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
kv_cache_ptr,
block_tables_prefill_ptr,
block_table_stride,
mock_kv_cache_ptr,
k_scale_ptr,
v_scale_ptr,
K_CACHE_STRIDE: tl.constexpr,
KV_CACHE_STRIDE: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
mock_block_table_idx = tl.program_id(1).to(tl.int64)
orig_page_num = tl.load(block_tables_prefill_ptr +
batch_idx * block_table_stride +
mock_block_table_idx).to(tl.int64)
if orig_page_num <= 0:
return

# Dequantize K
k_scale_val = tl.load(k_scale_ptr)
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+ 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
tl.store(mock_kv_cache_ptr + mock_cache_offset,
dequantized_vals.to(tl.bfloat16))
Member
Don't hardcode the dtype to bfloat16

Contributor Author
@mgoin Triton does not support templated dtypes, so it is difficult to support generic types here. Do you have specific types you want to support beyond bfloat16?

Member
Why can't you just use the dtype of mock_kv_cache?

Contributor Author (mxz297, Sep 11, 2025)
In the kernel we need a Triton dtype (tl.bfloat16 or tl.float16) rather than a tensor dtype (torch.bfloat16 or torch.float16). I may be wrong, but my understanding is that within a Triton kernel we cannot query a tensor's dtype. The only way I am aware of is passing a bool variable as a kernel input argument to indicate which type is being used; we cannot pass a torch.dtype as a kernel input argument.

In my latest revision, I changed the code to support both bf16 and fp16.

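A minimal sketch of the constexpr-flag approach described in this thread; the kernel name, signature, and sizes are hypothetical rather than the PR's actual revision, and it assumes a Triton/PyTorch build with float8_e4m3fn support:

import torch
import triton
import triton.language as tl


@triton.jit
def _dequant_fp8_to_out(
    src_ptr,
    dst_ptr,
    scale_ptr,
    N: tl.constexpr,            # must be a power of two (tl.arange requirement)
    OUT_IS_FP16: tl.constexpr,  # compile-time flag selecting the store dtype
):
    offs = tl.arange(0, N)
    scale = tl.load(scale_ptr)
    vals = tl.load(src_ptr + offs).to(tl.float32) * scale
    # constexpr branch: each flag value compiles to its own specialization
    if OUT_IS_FP16:
        tl.store(dst_ptr + offs, vals.to(tl.float16))
    else:
        tl.store(dst_ptr + offs, vals.to(tl.bfloat16))


# Host side: derive the flag from the destination tensor's dtype.
src = torch.randn(128, device="cuda").to(torch.float8_e4m3fn)
scale = torch.tensor(0.5, device="cuda")
dst = torch.empty(128, dtype=torch.float16, device="cuda")
_dequant_fp8_to_out[(1,)](src, dst, scale, N=128,
                          OUT_IS_FP16=dst.dtype == torch.float16)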

# Dequantize V
v_scale_val = tl.load(v_scale_ptr)
offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
tl.arange(0, K_CACHE_STRIDE))
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
mock_cache_offset = (
(batch_idx * block_table_stride + mock_block_table_idx + 1) *
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
tl.store(mock_kv_cache_ptr + mock_cache_offset,
dequantized_vals.to(tl.bfloat16))
Member
ditto



def trtllm_prefill_attn_kvfp8_dequant(
Contributor
Not sure, but I think these functions could be placed in vllm/utils/flashinfer.py?

Contributor Author
I put this Triton kernel in this file simply because there are already Triton kernels in this file:

https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flashinfer.py#L1016

kv_cache: torch.Tensor,
block_tables_prefill: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, num_of_page_per_token = block_tables_prefill.shape
s = kv_cache.shape
assert s[1] == 2
k_cache_stride = s[2] * s[3] * s[4]
kv_cache_stride = k_cache_stride * s[1]
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s,
dtype=torch.bfloat16,
Member
ditto

device=kv_cache.device)
# we simply sequentially index the pages needed by this prefill
mock_block_table = torch.arange(
start=1,
end=batch_size * num_of_page_per_token + 1,
dtype=torch.int32,
device=block_tables_prefill.device,
).reshape(batch_size, num_of_page_per_token)
grid = (batch_size, num_of_page_per_token)
_trtllm_prefill_attn_kvfp8_dequant[grid](
kv_cache,
block_tables_prefill,
num_of_page_per_token,
mock_kv_cache,
k_scale,
v_scale,
k_cache_stride,
kv_cache_stride,
)
return mock_kv_cache, mock_block_table

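To make the paging math above concrete, a small shape-only example with hypothetical sizes (2 prefill requests, up to 3 pages each; no kernel launch):

import torch

# kv_cache layout: (num_pages, 2, block_size, num_kv_heads, head_dim)
kv_cache = torch.empty(16, 2, 16, 8, 128, dtype=torch.float8_e4m3fn)
batch_size, num_of_page_per_token = 2, 3

# One page per (request, slot) pair plus an extra page at index 0, which
# stays unused because the sequential mock block table starts at 1.
new_s = (batch_size * num_of_page_per_token + 1, *kv_cache.shape[1:])
assert new_s == (7, 2, 16, 8, 128)

# Pages of the mock cache are consumed sequentially, starting at 1.
mock_block_table = torch.arange(
    1, batch_size * num_of_page_per_token + 1,
    dtype=torch.int32).reshape(batch_size, num_of_page_per_token)
# tensor([[1, 2, 3],
#         [4, 5, 6]], dtype=torch.int32)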

class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True

@classmethod
@@ -122,7 +201,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:

@dataclass
class FlashInferMetadata:

num_actual_tokens: int # Number of tokens excluding padding.

# The data type of the query
@@ -177,8 +255,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL)
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
@@ -203,7 +281,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype

if supports_trtllm_attention()[0]:
if supports_trtllm_attention()[0] and \
not flashinfer_disable_q_quantization():
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype
@@ -805,11 +884,28 @@ def forward(
assert self.o_sf_scale is None
out = output[num_decode_tokens:]

if attn_metadata.q_data_type != FP8_DTYPE \
and self.kv_cache_dtype.startswith("fp8"):
# TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention
# with fp8 kv cache, we can construct a mock block
# and mock kv cache with BF16 KV involved in the prefill
mock_kv_cache, mock_block_table = (
trtllm_prefill_attn_kvfp8_dequant(
kv_cache_permute,
block_tables_prefill,
layer._k_scale,
layer._v_scale,
))
else:
mock_kv_cache = kv_cache_permute
mock_block_table = block_tables_prefill
Comment on lines +895 to +896
Member
nit: I don't like using the term "mock" here. I realize we need to keep around kv_cache_permute for decode so we can't just reuse the name, but it is misleading to pass into the kernel


trtllm_batch_context_with_kv_cache(
query=prefill_query,
kv_cache=kv_cache_permute,
kv_cache=mock_kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables_prefill,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_kv_len=attn_metadata.max_seq_len,
@@ -847,7 +943,7 @@ def forward(
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND