6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -163,6 +163,7 @@
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
@@ -1145,6 +1146,10 @@ def get_vllm_port() -> Optional[int]:
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),

# If set to 1, do not quantize Q to fp8 when using an fp8 KV cache
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":
lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))),

# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN":
@@ -1299,6 +1304,7 @@ def compute_hash() -> str:
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
"VLLM_USE_CUDNN_PREFILL",
"VLLM_USE_TRTLLM_ATTENTION",
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
"VLLM_ROCM_USE_AITER",
"VLLM_ROCM_USE_AITER_PAGED_ATTN",
"VLLM_ROCM_USE_AITER_LINEAR",
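For context, a minimal sketch (not part of this diff) of how the new flag is parsed: any value that reads as a nonzero integer enables it, while "0" or an unset variable leaves query quantization on. The helper name below is hypothetical; only the environment variable name comes from the diff.

import os

def _env_flag(name: str, default: str = "0") -> bool:
    # Mirrors the lambda pattern above: parse the value as an int, then as a bool.
    return bool(int(os.getenv(name, default)))

os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"
print(_env_flag("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"))  # True: Q stays unquantized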
11 changes: 6 additions & 5 deletions vllm/utils/flashinfer.py
@@ -200,11 +200,6 @@ def use_trtllm_attention(
logger.info_once("Using TRTLLM attention (query is quantized).")
return True

# TRTLLM prefill attention does not support FP8 kv cache with
# non-quantized query
if is_prefill and kv_cache_dtype.startswith("fp8"):
return False

# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if has_sinks:
@@ -353,6 +348,12 @@ def flashinfer_scaled_fp8_mm(
return output


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
Comment on lines +351 to +354 (Collaborator): Ideally, we could get rid of this after #26146


__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
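One behavioral note on the functools.cache wrapper added above: the environment is read once per process and the result is then frozen, so flipping the variable at runtime has no effect. A small self-contained sketch (generic variable name, not vLLM's) of that behavior:

import functools
import os

@functools.cache
def disable_q_quant() -> bool:
    # The environment is read on the first call only; later calls hit the cache.
    return bool(int(os.getenv("EXAMPLE_DISABLE_Q_QUANT", "0")))

os.environ["EXAMPLE_DISABLE_Q_QUANT"] = "1"
print(disable_q_quant())  # True: first call reads the environment
os.environ["EXAMPLE_DISABLE_Q_QUANT"] = "0"
print(disable_q_quant())  # still True: cached for the lifetime of the process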
118 changes: 109 additions & 9 deletions vllm/v1/attention/backends/flashinfer.py
@@ -25,7 +25,8 @@
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention,
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
supports_trtllm_attention,
use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block
@@ -48,8 +49,89 @@
logger = init_logger(__name__)


class FlashInferBackend(AttentionBackend):
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
kv_cache_ptr,
block_tables_prefill_ptr,
block_table_stride,
mock_kv_cache_ptr,
k_scale_ptr,
v_scale_ptr,
K_CACHE_STRIDE: tl.constexpr,
KV_CACHE_STRIDE: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
mock_block_table_idx = tl.program_id(1).to(tl.int64)
orig_page_num = tl.load(block_tables_prefill_ptr +
batch_idx * block_table_stride +
mock_block_table_idx).to(tl.int64)
if orig_page_num <= 0:
return
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty

# Dequantize K
k_scale_val = tl.load(k_scale_ptr)
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+ 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)

# Dequantize V
v_scale_val = tl.load(v_scale_ptr)
offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
tl.arange(0, K_CACHE_STRIDE))
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
mock_cache_offset = (
(batch_idx * block_table_stride + mock_block_table_idx + 1) *
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)


def trtllm_prefill_attn_kvfp8_dequant(
Contributor: Not sure, but I think these functions could be placed in vllm/utils/flashinfer.py?

Contributor Author: I put this Triton kernel in this file simply because there are already Triton kernels in this file:
https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/flashinfer.py#L1016

kv_cache: torch.Tensor,
block_tables_prefill: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, num_of_page_per_token = block_tables_prefill.shape
s = kv_cache.shape
assert s[1] == 2
assert dequant_dtype in (torch.bfloat16, torch.float16)
Contributor Author: @mgoin revised to pass in q_dtype, supporting bf16 and fp16 and asserting on other dtypes.

k_cache_stride = s[2] * s[3] * s[4]
kv_cache_stride = k_cache_stride * s[1]
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s,
dtype=dequant_dtype,
device=kv_cache.device)
# we simply sequentially index the pages needed by this prefill
mock_block_table = torch.arange(
start=1,
end=batch_size * num_of_page_per_token + 1,
dtype=torch.int32,
device=block_tables_prefill.device,
).reshape(batch_size, num_of_page_per_token)
grid = (batch_size, num_of_page_per_token)
_trtllm_prefill_attn_kvfp8_dequant[grid](
kv_cache,
block_tables_prefill,
num_of_page_per_token,
mock_kv_cache,
k_scale,
v_scale,
k_cache_stride,
kv_cache_stride,
)
return mock_kv_cache, mock_block_table


class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True

@classmethod
@@ -122,7 +204,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:

@dataclass
class FlashInferMetadata:

num_actual_tokens: int # Number of tokens excluding padding.

# The data type of the query
@@ -175,8 +256,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL)
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
@@ -201,7 +282,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype

if supports_trtllm_attention()[0]:
if supports_trtllm_attention()[0] and \
not flashinfer_disable_q_quantization():
self.q_data_type = self.kv_cache_dtype
else:
self.q_data_type = self.model_config.dtype
@@ -795,11 +877,29 @@ def forward(
assert self.o_sf_scale is None
out = output[num_decode_tokens:]

if attn_metadata.q_data_type != FP8_DTYPE \
and self.kv_cache_dtype.startswith("fp8"):
# TRTLLM prefill attention does not support BF16 Q with
# an fp8 kv cache. To run prefill attention with an fp8
# kv cache, construct a mock block table and a mock kv
# cache holding the dequantized BF16 KV for this prefill.
mock_kv_cache, mock_block_table = (
trtllm_prefill_attn_kvfp8_dequant(
kv_cache_permute,
block_tables_prefill,
layer._k_scale,
layer._v_scale,
attn_metadata.q_data_type,
))
else:
mock_kv_cache = kv_cache_permute
mock_block_table = block_tables_prefill
Comment on lines +895 to +896 (Member): nit: I don't like using the term "mock" here. I realize we need to keep kv_cache_permute around for decode so we can't just reuse the name, but it is misleading to pass into the kernel.


trtllm_batch_context_with_kv_cache(
query=prefill_query,
kv_cache=kv_cache_permute,
kv_cache=mock_kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables_prefill,
block_tables=mock_block_table,
seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len,
max_kv_len=attn_metadata.max_seq_len,
@@ -837,7 +937,7 @@ def forward(
decode_query = decode_query.contiguous()
workspace_buffer = decode_wrapper._float_workspace_buffer
block_tables_decode = attn_metadata.\
block_table_tensor[:num_decode_tokens]
block_table_tensor[:num_decode_tokens]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
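To make the prefill dequantization path easier to follow, here is a hedged, loop-based PyTorch reference of what trtllm_prefill_attn_kvfp8_dequant computes; it is illustrative only, the real implementation is the Triton kernel in the diff. It assumes the KV-cache layout (num_pages, 2, ...) with K at index 0 and V at index 1, per-tensor k/v scales, and page id 0 treated as padding, matching the kernel's orig_page_num <= 0 early return.

import torch

def reference_kvfp8_dequant(kv_cache, block_tables, k_scale, v_scale, out_dtype):
    """Illustrative sketch of the kernel's effect: gather the fp8 pages used by
    this prefill, dequantize them to out_dtype, and index them sequentially."""
    batch, pages_per_req = block_tables.shape
    # Page 0 of the temporary cache is never referenced by the new block table.
    out_shape = (batch * pages_per_req + 1, *kv_cache.shape[1:])
    dequant_cache = torch.zeros(out_shape, dtype=out_dtype, device=kv_cache.device)
    new_block_table = torch.arange(
        1, batch * pages_per_req + 1, dtype=torch.int32,
        device=block_tables.device).reshape(batch, pages_per_req)
    for i, page in enumerate(block_tables.reshape(-1).tolist()):
        if page <= 0:  # padding entry, nothing to copy
            continue
        dequant_cache[i + 1, 0] = (kv_cache[page, 0].to(torch.float32) * k_scale).to(out_dtype)
        dequant_cache[i + 1, 1] = (kv_cache[page, 1].to(torch.float32) * v_scale).to(out_dtype)
    return dequant_cache, new_block_table

The resulting cache and block table are then what the prefill call passes to trtllm_batch_context_with_kv_cache in place of the fp8 cache, as shown in the forward() diff above.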