49 commits
b3a01dc
copying jit_cache, adapting jit cache for 2d kernel
bringlein Apr 10, 2025
3490bfc
some cleanup
bringlein Apr 10, 2025
4cc5407
formatting, typos...
bringlein Apr 10, 2025
59755e2
ruff....
bringlein Apr 10, 2025
f114090
adding assume const to jit cache
bringlein Apr 11, 2025
c43006e
experimenting with static launch grid again
bringlein Apr 11, 2025
9da4df6
recovering good performance
bringlein Apr 11, 2025
d7fc0af
going back to static launch grid
bringlein Apr 14, 2025
bf64b6d
formatting...
bringlein Apr 14, 2025
f3fb7e9
make type checking of key arguments more helpful
bringlein Apr 14, 2025
dc3b28c
applying jit cache for prefix prefill
bringlein Apr 14, 2025
e717040
fmt & ruff
bringlein Apr 14, 2025
fe2f6a5
ci
bringlein Apr 14, 2025
14cca7e
remove changed requirements by mistake/pre-hook?
bringlein Apr 14, 2025
d37ef48
fmt...
bringlein Apr 14, 2025
5e4bb2f
removing jit cache from prefix prefill again
bringlein Apr 15, 2025
c711433
cleanup
bringlein Apr 15, 2025
f8c6610
address review comments
bringlein Apr 24, 2025
f8b5001
fix type hints
bringlein Apr 24, 2025
ef3d6a3
add transparency as fallback mode
bringlein Apr 24, 2025
edf8633
CI whacamole
bringlein Apr 24, 2025
10df1df
CI whacamole...
bringlein Apr 24, 2025
cf1cea9
Merge branch 'main' into ngl_jit_cache_pr
bringlein May 7, 2025
f6852ed
adding triton 3.3 support
bringlein May 8, 2025
b93de23
Merge branch 'main' into ngl_jit_cache_pr
bringlein May 8, 2025
72d9858
fixing triton 3.3 support (1/x); add support for unified kernel
bringlein May 8, 2025
eeaab8d
fixing triton 3.3 support (2/2)
bringlein May 9, 2025
9ffc6e4
cleanup and add env var
bringlein May 9, 2025
1c65d75
adding assume_const
bringlein May 9, 2025
43b500b
make argument passing (slightly) faster
bringlein May 9, 2025
43aed8c
Merge branch 'main' into ngl_jit_cache_pr (moving envs content)
bringlein May 13, 2025
e50534a
fixing env var merge conflict
bringlein May 13, 2025
450770c
adding attention metadata specific for triton_backend
bringlein May 13, 2025
f7705c0
fixing env file again
bringlein May 13, 2025
3a5c63e
Revert "adding attention metadata specific for triton_backend"
bringlein May 13, 2025
e2ef23e
more elegant fix on dependency of flash attention
bringlein May 13, 2025
8f5735b
thrid way to un-break triton backend
bringlein May 13, 2025
ccd22c9
CI...
bringlein May 13, 2025
a94e99b
making jitcache safe to use with autotuner
cyang49 May 14, 2025
af094a3
CI whacamole...
bringlein May 14, 2025
c1b21d5
fixup spelling in a few spots
tlrmchlsmth May 20, 2025
be9d7d4
Merge branch 'main' into ngl_jit_cache_pr
tdoublep May 23, 2025
791b8b2
Added support for specialization.
tdoublep May 23, 2025
f4a436a
Merge branch 'main' into ngl_jit_cache_pr
bringlein Jun 12, 2025
f72a768
minor cleanup; remove copy of launch grid
bringlein Jun 13, 2025
02a6ea4
improve docstring
bringlein Jun 18, 2025
cd987c2
Merge branch 'main' into ngl_jit_cache_pr
bringlein Jun 18, 2025
d52af9b
ruff....
bringlein Jun 18, 2025
e1cf444
fixing merge error
bringlein Jun 18, 2025
97 changes: 62 additions & 35 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -12,6 +12,7 @@
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton
from vllm.triton_utils.jit_cache import jitcache

from .prefix_prefill import context_attention_fwd

@@ -21,45 +22,67 @@ def cdiv_fn(x, y):
return (x + y - 1) // y


@jitcache(
check_keys=["USE_ALIBI_SLOPES", "SLIDING_WINDOW", "filter_by_query_len"],
assume_const=[
"scale",
"k_scale",
"v_scale",
"query_stride_1",
"output_stride_1",
"stride_k_cache_0",
"stride_k_cache_1",
"stride_k_cache_2",
"stride_k_cache_4",
"stride_v_cache_0",
"stride_v_cache_1",
"stride_v_cache_2",
"stride_v_cache_2",
],
cache_launch_grid=True,
)
@triton.jit
def kernel_paged_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale: float, # float32
k_scale: float, # float32
v_scale: float, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
num_seqs: int,
):
seq_idx = tl.program_id(0)
if seq_idx >= num_seqs:
return
kv_head_idx = tl.program_id(1)

if filter_by_query_len:
@@ -324,6 +347,9 @@ def chunked_prefill_paged_decode(
v_scale=v_scale,
)
else:
# We use a "static launch grid" for the kernel in order to cache it.
# Therefore, we assume a maximum batch size of 4096.
assert num_seqs <= 4096
kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
@@ -363,4 +389,5 @@ def chunked_prefill_paged_decode(
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
num_seqs=num_seqs,
)
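
For context on what the @jitcache decorator above is doing: rather than letting Triton recompute its full specialization key on every kernel launch, the wrapper caches the compiled kernel and keys lookups only on the arguments named in check_keys, while arguments listed in assume_const are trusted to keep their first-seen values; cache_launch_grid=True additionally reuses the launch grid from the first call. The sketch below illustrates only the caching idea, not the actual implementation in vllm/triton_utils/jit_cache.py; the names jitcache_sketch and compile_fn are invented for this illustration.

from typing import Any, Callable


def jitcache_sketch(check_keys: list[str]) -> Callable:
    """Toy stand-in for the real jitcache decorator: memoize an
    expensive compile step, keyed only on the values of check_keys."""

    def decorator(compile_fn: Callable[..., Any]) -> Callable[..., Any]:
        cache: dict[tuple, Any] = {}

        def wrapper(**kwargs: Any) -> Any:
            # Only check_keys contribute to the cache key; all other
            # arguments (e.g. those in assume_const) are assumed stable.
            key = tuple(kwargs[k] for k in check_keys)
            if key not in cache:
                cache[key] = compile_fn(**kwargs)  # slow path: compile once
            return cache[key]  # fast path: reuse the cached result

        return wrapper

    return decorator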
18 changes: 18 additions & 0 deletions vllm/attention/ops/triton_unified_attention.py
@@ -10,6 +10,7 @@
import triton.language as tl

from vllm.logger import init_logger
from vllm.triton_utils.jit_cache import jitcache

logger = init_logger(__name__)

@@ -27,6 +28,23 @@ def apply_softcap(S, x):
return x * (p1 - p2) / (p1 + p2)


@jitcache(
check_keys=[],
assume_const=[
"scale",
"k_scale",
"v_scale",
"query_stride_1",
"output_stride_1",
"stride_k_cache_0",
"stride_k_cache_1",
"stride_k_cache_2",
"stride_k_cache_4",
"stride_v_cache_0",
"stride_v_cache_1",
"stride_v_cache_2",
],
)
@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
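
Note that this decorator passes check_keys=[] and omits cache_launch_grid: with no check keys, every launch maps to the same single cache entry, presumably safe here because the unified kernel's constexpr arguments do not change between launches for a given model configuration. In terms of the toy jitcache_sketch above (a hypothetical helper, not vLLM API), the behavior looks like this:

# Hypothetical illustration using jitcache_sketch from above: with
# check_keys=[], the cache key is always the empty tuple, so every
# call after the first one reuses the single cached entry.
compile_calls = 0

def fake_compile(**kwargs):
    global compile_calls
    compile_calls += 1
    return "compiled-kernel"

cached = jitcache_sketch(check_keys=[])(fake_compile)
cached(scale=0.5)
cached(scale=0.25)  # different argument value, same cache entry
assert compile_calls == 1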
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -70,6 +70,7 @@
VLLM_PLUGINS: Optional[list[str]] = None
VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_TRITON_ENABLE_JITCACHE: bool = False
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@@ -516,6 +517,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),

# Enable the JIT cache for Triton kernels,
# see vllm/triton_utils/jit_cache.py
"VLLM_TRITON_ENABLE_JITCACHE":
lambda: bool(int(os.getenv("VLLM_TRITON_ENABLE_JITCACHE", "0"))),

# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
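
The new flag follows vLLM's usual boolean environment-variable convention (any nonzero integer enables it). A minimal way to exercise it, assuming a standard vLLM install, is to set the variable before vLLM first reads it:

import os

# The value must be set before it is first read through vllm.envs;
# "1" enables the Triton JIT cache, "0" (the default) disables it.
os.environ["VLLM_TRITON_ENABLE_JITCACHE"] = "1"

import vllm.envs as envs

assert envs.VLLM_TRITON_ENABLE_JITCACHE is True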