diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index a0478a359f91..7e3ea561fd29 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
     cache_seqlens_ptr,  # circular buffer
     conv_state_indices_ptr,
     num_accepted_tokens_ptr,
+    query_start_loc_ptr,  # (batch + 1)
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
     batch: int,
@@ -652,6 +653,7 @@
     HAS_BIAS: tl.constexpr,
     KERNEL_WIDTH: tl.constexpr,
     SILU_ACTIVATION: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     IS_SPEC_DECODING: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
         # not processing as this is not the actual sequence
         return
 
+    if IS_VARLEN:
+        query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
+        query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
+            tl.int64)
+        # revise state_len and seqlen
+        state_len = state_len - (seqlen -
+                                 (query_end_index - query_start_index))
+        seqlen = query_end_index - query_start_index
+        x_offset = query_start_index * stride_x_token
+        o_offset = query_start_index * stride_o_token
+    else:
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq
+
+    if query_start_index == query_end_index:
+        return
+
     if IS_SPEC_DECODING:
         # The rolling of conv state:
         #
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
         # - accept 1 tokens: [history2, ..., historyM, draft1]
         # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
         # - and so on.
-        conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq)
-                                   - 1)
+        conv_state_token_offset = (
+            tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
     else:
         conv_state_token_offset = 0
 
@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
         col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
-    if KERNEL_WIDTH == 5:
+    if KERNEL_WIDTH >= 5:
         conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
         col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 6:
+        conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok  # [BLOCK_N]
+        col4 = tl.load(conv_states_ptrs, mask_w, 0.0)
 
     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
     conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
 
     VAL = state_len - seqlen
-    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
-                                                 )  # [BLOCK_N]
+    x_base = x_ptr + x_offset + (idx_feats * stride_x_dim)  # [BLOCK_N]
 
     x_ptrs = x_base[None, :] + (
         (idx_tokens - VAL) * stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
         w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 5:
+        w_ptrs = w_base + (4 * stride_w_width)  # [BLOCK_N] tensor
+        w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 6:
+        w_ptrs = w_base + (5 * stride_w_width)  # [BLOCK_N] tensor
+        w_col5 = tl.load(w_ptrs, mask_w, other=0.0)
 
     x_base_1d = x_base  # starting of chunk [BLOCK_N]
     mask_x_1d = idx_feats < dim
 
     # STEP 5: compute each token
-    for idx_token in tl.static_range(seqlen):
+    for idx_token in tl.range(seqlen):
         acc = acc_preload
 
         matrix_w = w_col0
@@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
                     matrix_w = w_col3
                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 5:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 6:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    matrix_x = col4
+                elif j == 5:
+                    matrix_w = w_col5
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
 
             acc += matrix_x * matrix_w  # [BLOCK_N]
 
@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
             col0 = col1
             col1 = col2
             col2 = matrix_x
+        elif KERNEL_WIDTH == 5:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = matrix_x
+        elif KERNEL_WIDTH == 6:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = col4
+            col4 = matrix_x
 
         if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
         mask_1d = (idx_token < seqlen) & (idx_feats < dim
                                           )  # token-index  # feature-index
-        o_ptrs = o_ptr + (
-            idx_seq) * stride_o_seq + idx_token * stride_o_token + (
-                idx_feats * stride_o_dim)
+        o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
+                                                                  stride_o_dim)
 
         tl.store(o_ptrs, acc, mask=mask_1d)
 
@@ -850,14 +920,18 @@ def causal_conv1d_update(
     cache_seqlens: Optional[torch.Tensor] = None,
     conv_state_indices: Optional[torch.Tensor] = None,
     num_accepted_tokens: Optional[torch.Tensor] = None,
+    query_start_loc: Optional[torch.Tensor] = None,
+    max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
     metadata=None,
     validate_data=False,
 ):
     """
-    x: (batch, dim) or (batch, dim, seqlen)
+    x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
         [shape=2: single token prediction]
         [shape=3: single or multiple tokens prediction]
+        [shape=2 with num_tokens: continuous batching, where num_tokens is the
+         total number of tokens of all sequences in that batch]
     conv_state: (..., dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
@@ -870,13 +944,24 @@ def causal_conv1d_update(
         If not None, the conv_state is a larger tensor along the batch dim,
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
+    num_accepted_tokens: (batch,), dtype int32
+        If not None, it indicates the number of accepted tokens for each
+        sequence in the batch.
+        This is used in speculative decoding, where the conv_state is updated
+        in a sliding window manner.
+    query_start_loc: (batch + 1,) int32
+        If not None, the input is given in a varlen fashion and this indicates
+        the starting index of each sequence in the batch.
+    max_query_len: int
+        If query_start_loc is not None, this indicates the maximum query
+        length in the batch.
     pad_slot_id: int
         if cache_indices is passed, lets the kernel identify padded
         entries that will not be processed,
         for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
         in this case, the kernel will not process entries at
         indices 0 and 3
-    out: (batch, dim) or (batch, dim, seqlen)
+    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
     if validate_data:
         assert cache_seqlens is None  # not implemented yet - ok for vLLM
@@ -886,11 +971,17 @@ def causal_conv1d_update(
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
-    unsqueeze = x.dim() == 2
+    unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
         x = x.unsqueeze(-1)
-    batch, dim, seqlen = x.shape
+    if query_start_loc is None:
+        batch, dim, seqlen = x.shape
+    else:
+        assert conv_state_indices is not None
+        batch = conv_state_indices.size(0)
+        dim = x.size(1)
+        seqlen = max_query_len
     _, width = weight.shape
     # conv_state: (..., dim, state_len), where state_len >= width - 1
     num_cache_lines, _, state_len = conv_state.size()
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
     out = x
 
     stride_w_dim, stride_w_width = weight.stride()
-    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
-    )  # X (batch, dim, seqlen)
+    if query_start_loc is None:
+        # X (batch, dim, seqlen)
+        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
+        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    else:
+        # X (dim, cu_seqlen)
+        stride_x_token, stride_x_dim = x.stride()
+        stride_x_seq = 0
+        stride_o_token, stride_o_dim = out.stride()
+        stride_o_seq = 0
 
-    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
     stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
     )
     stride_state_indices = conv_state_indices.stride(
@@ -945,6 +1043,7 @@ def grid(META):
         cache_seqlens,
         conv_state_indices,
         num_accepted_tokens,
+        query_start_loc,
         out,
         # Matrix dimensions
         batch,
@@ -971,6 +1070,7 @@ def grid(META):
         HAS_BIAS=bias is not None,
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
+        IS_VARLEN=query_start_loc is not None,
         IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
         IS_SPEC_DECODING=num_accepted_tokens is not None,
         NP2_STATELEN=np2_statelen,
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 3c5407916c0b..fe63e9303235 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -417,9 +417,7 @@ def _forward(
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
-        num_actual_tokens = (attn_metadata.num_prefill_tokens +
-                             attn_metadata.num_decode_tokens +
-                             attn_metadata.num_spec_decode_tokens)
+        num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
         # 1. Set up dimensions for reshapes later
@@ -458,9 +456,6 @@ def _forward(
 
         # 2.1: process the mutli-query part
         if spec_sequence_masks is not None:
-            mixed_qkv_spec = mixed_qkv_spec.view(
-                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
@@ -470,9 +465,10 @@ def _forward(
                 conv_state_indices=spec_state_indices_tensor[:, 0]
                 [:attn_metadata.num_spec_decodes],
                 num_accepted_tokens=num_accepted_tokens,
+                query_start_loc=spec_query_start_loc,
+                max_query_len=spec_state_indices_tensor.size(-1),
                 validate_data=False,
             )
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
 
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 74eb9ae9d325..ba89f93e8b56 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
     num_decode_tokens: int
     num_spec_decodes: int
     num_spec_decode_tokens: int
+    num_actual_tokens: int
 
     has_initial_state: Optional[torch.Tensor] = None
 
@@ -74,8 +75,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
+            self.vllm_config.scheduler_config.max_num_seqs *
+            (self.num_spec + 1), self.compilation_config.max_capture_size)
 
         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
@@ -194,9 +195,8 @@ def build(  # type: ignore[override]
                 dim=0,
                 out=non_spec_query_start_loc[1:])
 
-            num_spec_decode_tokens = min(
-                num_spec_decodes * (self.num_spec + 1),
-                spec_token_masks.size(0))
+            num_spec_decode_tokens = (query_lens.sum().item() -
+                                      num_prefill_tokens - num_decode_tokens)
 
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
@@ -206,14 +206,22 @@ def build(  # type: ignore[override]
                 has_initial_state = has_initial_state[~spec_sequence_masks]
         else:
             has_initial_state = None
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
+            num_spec_decode_tokens
 
         # prepare tensors for cudagraph
+        #
+        # With speculative decoding, the xgrammar backend may roll back tokens,
+        # causing some sequences to have fewer draft tokens than self.num_spec.
+        #
+        # In this case, the max possible batch size for n tokens can be
+        # min(n, cudagraph_max_bs).
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens // (self.num_spec + 1)
+            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
 
             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True)
@@ -229,7 +237,7 @@ def build(  # type: ignore[override]
             assert spec_token_masks is not None
             self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                 spec_token_masks, non_blocking=True)
-            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
+            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
             spec_token_masks[spec_token_masks.size(0):].fill_(False)
 
             self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
@@ -248,9 +256,9 @@ def build(  # type: ignore[override]
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_spec_decodes == 0
                 and num_decodes <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens
+            batch_size = num_actual_tokens
 
             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True)
@@ -274,6 +282,7 @@ def build(  # type: ignore[override]
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
+            num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
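
Editor's note (not part of the patch): a minimal sketch of how the varlen path added to causal_conv1d_update might be called, assuming a CUDA device and made-up shapes, sequence lengths, and cache indices. x is packed as (num_tokens, dim), query_start_loc marks each sequence's start offset, conv_state_indices picks the cache line each sequence updates, and max_query_len bounds the per-sequence token count.

    import torch

    from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
        causal_conv1d_update)

    # Hypothetical sizes for illustration only.
    dim, width, state_len = 64, 4, 3   # state_len >= width - 1
    seq_lens = [3, 1, 2]               # per-sequence token counts in this step
    num_tokens = sum(seq_lens)

    # Packed varlen input: all tokens of all sequences, token-major.
    x = torch.randn(num_tokens, dim, device="cuda", dtype=torch.float16)
    weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
    bias = torch.randn(dim, device="cuda", dtype=torch.float16)

    # 8 cache lines; each sequence selects one via conv_state_indices.
    conv_state = torch.zeros(8, dim, state_len,
                             device="cuda", dtype=torch.float16)
    query_start_loc = torch.tensor([0, 3, 4, 6],
                                   device="cuda", dtype=torch.int32)
    conv_state_indices = torch.tensor([5, 2, 7],
                                      device="cuda", dtype=torch.int32)

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias=bias,
        activation="silu",
        conv_state_indices=conv_state_indices,
        query_start_loc=query_start_loc,
        max_query_len=max(seq_lens),
    )
    # out has the same (num_tokens, dim) layout as x; conv_state is updated
    # in place for the selected cache lines.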