Description
🚀 Feature
To improve performance and enable richer features for vLLM on TPU, we want to add ragged input support to the multi-query paged attention kernel.
Motivation
- Batching multiple prefill inputs
Currently, the vLLM TPU backend limits the batch size of prefill inputs to 1 because of compilation time. The current 2D input shape [batch_size, seq_len] requires compiling approximately log2(max_seq_len) different XLA graphs per batch size. Since log2(max_seq_len) is ~14 in practice, we can only support batch_size=1 within a reasonable compilation time. With the 1D input shape [num_tokens], however, log2(max_seq_len) XLA graphs in total cover all batch sizes. This can significantly improve performance when processing many short sequences. For example, with 4 sequences of 128 tokens each, we currently have to run them one at a time, in 4 steps. With the 1D input graphs, we can process all 4 sequences in a single step with the XLA graph for num_tokens=512. Since model performance is memory-bound for small inputs, this could theoretically give a ~4x speedup (although the actual speedup will certainly be smaller in practice).
- Mixing prefill & decode inputs in the same batch
Currently, the vLLM TPU backend alternates between “prefill mode” and “decode mode” because the model can only run a single type of input at a time. With the new kernel, we can mix both in the same batch and achieve higher hardware utilization. This is expected to give a 5-10% performance improvement.
- Enabling chunked prefills
Once prefill and decode inputs can be mixed in the same batch, we can fully integrate chunked prefills, which split a prefill input into multiple chunks and process them over multiple steps. While this could theoretically improve performance, we did not see noticeable improvements on the GPU backend. However, it is important for controlling TPOT (time per output token), since without this feature we see latency spikes whenever a new prefill request joins.
- Reducing compilation time
Currently, we compile ~28 XLA graphs by default, and 42 when prefix caching is enabled. This takes 10-20 minutes for first-time compilation and 1-5 minutes afterwards (when the compiled XLA graphs are cached on disk). With the new kernel, we only need to compile 14 XLA graphs in total, so compilation time should decrease by 2-3x (28/14 and 42/14).
- Migrating to vLLM v1
The new kernel enables vLLM on TPU to migrate to vLLM v1, the latest and most optimized version of vLLM. vLLM on GPU has already migrated from v0 to v1, which reportedly brought a 33% performance gain in the proof of concept.
- Unifying the GPU and TPU interfaces
The new kernel enables vLLM to have a unified interface between GPU and TPU. On GPU, vLLM already uses a single kernel to handle all 3 use cases: flash attention, single-query paged attention, and multi-query paged attention.
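The graph-count argument in the batching and compilation-time items can be illustrated with a small sketch. It assumes, hypothetically, that num_tokens is padded up to the next power of two with a floor of 16 (matching the 16 to 8K range mentioned in the pitch); the actual bucketing scheme in vLLM may differ. With a 1D input, one compiled graph per bucket covers every batch composition.

```python
def token_buckets(min_tokens: int = 16, max_tokens: int = 8192) -> list[int]:
    """Hypothetical power-of-two buckets; one XLA graph would be compiled per bucket."""
    buckets = []
    size = min_tokens
    while size <= max_tokens:
        buckets.append(size)
        size *= 2
    return buckets


def pad_to_bucket(num_tokens: int, buckets: list[int]) -> int:
    """Smallest bucket that fits num_tokens, i.e. the padded 1D input length."""
    for size in buckets:
        if size >= num_tokens:
            return size
    raise ValueError(f"{num_tokens} tokens exceed the largest bucket {buckets[-1]}")


buckets = token_buckets()
# Four 128-token prefills packed into one 1D input of 512 tokens:
# a single step with the num_tokens=512 graph instead of four batch_size=1 steps.
q_lens = [128, 128, 128, 128]
print(len(buckets), pad_to_bucket(sum(q_lens), buckets))  # 10 512
# A mixed batch (one 300-token prefill plus 5 decode tokens) reuses the same graph.
print(pad_to_bucket(300 + 5, buckets))  # 512
```

Unlike the 2D [batch_size, seq_len] layout, the number of graphs here does not grow with the batch size; only the padded token count selects the graph.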
Pitch
import jax

def paged_attention(
    q: jax.Array,  # [num_tokens, num_heads, head_dim]
    k_pages: jax.Array,  # [num_kv_heads, total_num_pages, page_size, head_dim]
    v_pages: jax.Array,  # [num_kv_heads, total_num_pages, page_size, head_dim]
    lengths: jax.Array,  # i32[num_tokens]
    page_indices: jax.Array,  # i32[num_tokens, pages_per_sequence]
    cu_q_lens: jax.Array,  # i32[num_tokens + 1]
    num_seqs: jax.Array,  # i32[1]
) -> jax.Array:  # [num_tokens, num_heads, head_dim]
    ...
Since we pack multiple sequences into the 1D tensor of shape [num_tokens], we need to keep some metadata to reconstruct the sequences:
- num_seqs is an array with a single integer element that indicates the actual number of sequences packed in the input q. The value should be in [1, num_tokens] (inclusive).
- cu_q_lens is an integer array holding the cumulative sum of the query lengths. For example, if we pack three sequences with query lengths 2, 9, and 5, cu_q_lens should be [0, 2, 11, 16, x, x, x, ...] where x is any padding value. While cu_q_lens’s shape is [num_tokens + 1], only the first num_seqs + 1 values are valid. The rest should be ignored.
- lengths refers to the effective kv length for each sequence. For example, if we have three sequences, lengths could be [16, 3, 1024, x, x, x, x, ...] where x is any value for padding. While lengths’s shape is [num_tokens], only the first num_seqs values are valid. The rest should be ignored.
- page_indices stores, for each sequence, the indices of the KV cache pages it uses. For example, assuming we have 2 sequences, if query tokens 0 and 1 are from the 0th sequence and query tokens 2, 3, 4, 5 are from the 1st sequence, then tokens 0 and 1 attend to the KV cache pages in page_indices[0], and tokens 2, 3, 4, 5 attend to the KV cache pages in page_indices[1]. The rest (page_indices[2:]) should be ignored.
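To make the layout of these arrays concrete, here is a minimal pure-Python sketch that packs per-sequence metadata into the padded shapes of the signature and maps a query token back to its sequence. The helper names (pack_metadata, sequence_of_token) and the padding value 0 are hypothetical; any padding value works because only the first num_seqs entries are read. The example KV lengths are chosen to be at least the query lengths, assuming each sequence's new query tokens are also present in its KV cache.

```python
from bisect import bisect_right
from itertools import accumulate


def pack_metadata(q_lens, kv_lens, page_tables, num_tokens, pages_per_sequence, pad=0):
    """Hypothetical helper: pack per-sequence metadata into fixed, padded shapes.

    q_lens[s]      -- number of query tokens of sequence s
    kv_lens[s]     -- effective KV length of sequence s (the `lengths` argument)
    page_tables[s] -- KV cache page indices used by sequence s
    """
    num_seqs = len(q_lens)
    assert 1 <= num_seqs <= num_tokens and sum(q_lens) <= num_tokens
    cu = [0, *accumulate(q_lens)]
    cu_q_lens = cu + [pad] * (num_tokens + 1 - len(cu))  # i32[num_tokens + 1]
    lengths = list(kv_lens) + [pad] * (num_tokens - num_seqs)  # i32[num_tokens]
    page_indices = [  # i32[num_tokens, pages_per_sequence]
        list(pages) + [pad] * (pages_per_sequence - len(pages)) for pages in page_tables
    ] + [[pad] * pages_per_sequence for _ in range(num_tokens - num_seqs)]
    return [num_seqs], cu_q_lens, lengths, page_indices


def sequence_of_token(token, cu_q_lens, num_seqs):
    """Return the sequence s with cu_q_lens[s] <= token < cu_q_lens[s + 1]."""
    return bisect_right(cu_q_lens[: num_seqs + 1], token) - 1


num_seqs, cu_q_lens, lengths, page_indices = pack_metadata(
    q_lens=[2, 9, 5], kv_lens=[16, 9, 40], page_tables=[[7], [3], [0, 1, 2]],
    num_tokens=16, pages_per_sequence=4)
print(cu_q_lens[:4])  # [0, 2, 11, 16]
print([sequence_of_token(t, cu_q_lens, 3) for t in (0, 1, 2, 10, 11, 15)])  # [0, 0, 1, 1, 2, 2]
```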
This kernel should handle a wide range of inputs. One extreme case is a single sequence with num_tokens tokens. Another extreme case is num_tokens sequences, each with a single token. Both are valid cases that need to be optimized.
Also, the page_indices tensor should be loaded into SMEM on demand, since only its first num_seqs rows are valid and the entire tensor is unlikely to be needed (whenever num_seqs < num_tokens).
Finally, num_tokens can range from 16 to 8K in practice.
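Because both extremes (one long sequence, or num_tokens single-token sequences) must produce correct results, it can help to test the kernel against a slow reference. The NumPy function below is a hypothetical correctness oracle, not part of the proposal. It loops over the first num_seqs sequences, gathers each sequence's pages, and applies causal attention under three assumptions: a sequence's query tokens are the last q_len positions of its lengths[s] KV tokens, num_heads is a multiple of num_kv_heads (GQA/MQA), and scores are scaled by 1/sqrt(head_dim).

```python
import numpy as np


def ref_ragged_paged_attention(q, k_pages, v_pages, lengths, page_indices, cu_q_lens, num_seqs):
    """Slow NumPy reference for ragged paged attention (hypothetical oracle).

    All arguments are NumPy arrays with the shapes of the proposed signature.
    Padding rows of the output (tokens past cu_q_lens[num_seqs]) are left as zeros.
    """
    num_tokens, num_heads, head_dim = q.shape
    num_kv_heads, _, page_size, _ = k_pages.shape
    group = num_heads // num_kv_heads  # query heads per KV head
    out = np.zeros_like(q)
    for s in range(int(num_seqs[0])):
        q_start, q_end = int(cu_q_lens[s]), int(cu_q_lens[s + 1])
        q_len, kv_len = q_end - q_start, int(lengths[s])
        n_pages = -(-kv_len // page_size)  # ceil division
        pages = page_indices[s, :n_pages]
        # Gather pages -> [num_kv_heads, n_pages * page_size, head_dim], keep kv_len tokens.
        k = k_pages[:, pages].reshape(num_kv_heads, -1, head_dim)[:, :kv_len]
        v = v_pages[:, pages].reshape(num_kv_heads, -1, head_dim)[:, :kv_len]
        # Query i sits at absolute position kv_len - q_len + i (assumed causal layout).
        q_pos = kv_len - q_len + np.arange(q_len)[:, None]
        mask = np.arange(kv_len)[None, :] <= q_pos  # [q_len, kv_len]
        for h in range(num_heads):
            kv_h = h // group
            scores = q[q_start:q_end, h] @ k[kv_h].T / np.sqrt(head_dim)
            scores = np.where(mask, scores, -np.inf)
            probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
            probs /= probs.sum(axis=-1, keepdims=True)
            out[q_start:q_end, h] = probs @ v[kv_h]
    return out
```

A mixed batch in the test packs one decode token (KV length 1) next to a 3-token chunk whose KV spans two pages; the decode token's output must equal its single value vector exactly.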
Implementation
We will handle the raggedness as follows.
Pseudocode:
Step 1: calculate metadata
Similar to https://github.com/jax-ml/jax/blob/6e1f060ad3b0b7d08980d65c867186e914eb90b1/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79, we need to calculate some metadata:
- num_logical_tiles_q: the number of logical tiles along the query dimension that we need to process.
- sequence_ids: for each logical q tile index in [0, num_logical_tiles_q), the sequence it works on.
- query_physical_tile_ids: for each logical q tile index in [0, num_logical_tiles_q), the physical query-dimension tile it works on.
# See the appendix for an example to understand this better.
# NB: no need to calculate `group_offsets` as in the gmm kernel because the paged_attention argument `cu_q_lens` provides the information.
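Step 1 can be sketched in a few lines of plain Python. This is a minimal, hypothetical version (the function name and the q_block_size parameter are assumptions) that computes on the host what the kernel wrapper would compute with array ops, in the spirit of the linked gmm metadata computation: each sequence contributes one logical tile for every physical q tile its tokens touch, so a physical tile shared by two sequences is visited twice.

```python
def make_q_tile_metadata(cu_q_lens, num_seqs, q_block_size):
    """Hypothetical sketch of Step 1.

    Returns (num_logical_tiles_q, sequence_ids, query_physical_tile_ids): logical tile i
    processes the rows of physical q tile query_physical_tile_ids[i] that belong to
    sequence sequence_ids[i].
    """
    sequence_ids, query_physical_tile_ids = [], []
    for s in range(num_seqs):
        start, end = cu_q_lens[s], cu_q_lens[s + 1]
        if end == start:
            continue  # an empty sequence touches no q tile
        first_tile = start // q_block_size
        last_tile = (end - 1) // q_block_size
        for tile in range(first_tile, last_tile + 1):
            sequence_ids.append(s)
            query_physical_tile_ids.append(tile)
    return len(sequence_ids), sequence_ids, query_physical_tile_ids


# Three sequences with query lengths 2, 9 and 5, packed into 8-row q tiles.
print(make_q_tile_metadata([0, 2, 11, 16], num_seqs=3, q_block_size=8))
# (4, [0, 1, 1, 2], [0, 0, 1, 1])
```

Each sequence boundary that falls strictly inside a physical tile adds one extra visit, so num_logical_tiles_q is at most ceil(total_q_tokens / q_block_size) + num_seqs - 1. Since a Pallas grid has a static size, one option would be to size the grid with this bound and skip the surplus iterations.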
Step 2: Run the kernel
We will set the `grid` and `BlockSpec` according to the pseudocode below.
q = jnp.permute_dims(q, (1, 0, 2))  # [num_q_heads, num_tokens, head_dim]
for kv_head_idx in range(num_kv_heads):
    for logical_q_blk_idx in range(num_logical_tiles_q):
        for kv_blk_idx in range(num_kv_len_blocks):
            # Within the Pallas kernel:
            # q.shape=[num_q_heads_per_kv_head, q_block_size, head_size]
            # cur_sequence_id=sequence_ids[logical_q_blk_idx]
            # cur_physical_query_tile_id=query_physical_tile_ids[logical_q_blk_idx]
            # Load the KV pages of the current sequence from HBM to VMEM. (This may be tricky because a physical q tile can contain tokens from more than one sequence.)
            # - Use sequence_id, kv_head_idx, kv_blk_idx to identify the k/v pages.
            # - Use sequence_id, kv_head_idx, logical_q_blk_idx, kv_blk_idx to calculate the next (sequence_id, kv_head_idx, kv_blk_idx) and prefetch the next KV pages.
            for q_head_idx in range(num_q_heads_per_kv_head):  # for GQA and MQA
                # Within the flash attention kernel, run the regular FlashAttention-2 algorithm.
                # q.shape=[q_block_size, head_size]
                # k.shape=[k_blk_size, head_size] (i.e. [pages_per_compute_block*page_size, head_size])
                # If the q-dim physical tile has changed (i.e. it is a new physical q-dim tile not visited before), initialize acc_scratch_ref, m_scratch_ref, and l_scratch_ref for the FlashAttention-2 algorithm.
                # Load the part of the q block that belongs to the current sequence and compute the attention. The q-dim may need padding due to Mosaic matmul constraints.
                # Modify the mask accordingly.
                # Along the q-dim of acc_scratch_ref, m_scratch_ref, and l_scratch_ref, only update the rows that belong to the current sequence, via a mask.
                # Store the result from VMEM to HBM only when this is the last kv block and the next logical q-dim tile maps to a different physical q-dim tile.
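The bookkeeping in the innermost loop (when to initialize the scratch buffers, which rows of a physical tile to update, and when to write back) can be sketched in plain Python. This is a hypothetical host-side illustration, not the kernel itself: the helper names are made up, the inputs are the Step 1 metadata arrays defined in this issue, and the causal mask assumes a sequence's query tokens are the last q_len positions of its KV. In the kernel, the init flag would apply at kv_blk_idx == 0 and the store flag at the last kv block.

```python
def scratch_flags(query_physical_tile_ids):
    """Per logical q tile: (initialize scratch at the first kv block, store at the last kv block)."""
    ids = query_physical_tile_ids
    n = len(ids)
    init = [i == 0 or ids[i] != ids[i - 1] for i in range(n)]
    store = [i == n - 1 or ids[i] != ids[i + 1] for i in range(n)]
    return init, store


def tile_masks(cu_q_lens, lengths, seq_id, physical_tile_id, q_block_size, kv_start, kv_block_size):
    """Row mask and causal mask for one (logical q tile, kv block) pair.

    row_mask[r]  -- row r of the physical q tile belongs to sequence seq_id
    causal[r][c] -- row r may attend to KV position kv_start + c
    """
    q_start, q_end = cu_q_lens[seq_id], cu_q_lens[seq_id + 1]
    q_len, kv_len = q_end - q_start, lengths[seq_id]
    row_mask, causal = [], []
    for r in range(q_block_size):
        token = physical_tile_id * q_block_size + r
        in_seq = q_start <= token < q_end
        row_mask.append(in_seq)
        # Absolute KV position of this query token (assumed layout). Since q_pos < kv_len,
        # the comparison also masks KV padding past the end of the sequence.
        q_pos = kv_len - q_len + (token - q_start)
        causal.append([in_seq and kv_start + c <= q_pos for c in range(kv_block_size)])
    return row_mask, causal


init, store = scratch_flags([0, 0, 1, 1])
print(init, store)  # [True, False, True, False] [False, True, False, True]
```

Only the rows where row_mask is true should update the acc/m/l scratch buffers; the other rows keep the partial results that belong to the neighbouring sequence sharing the same physical tile.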
Alternatives
Additional context
cc: @miladm