Merged
Changes from all commits
108 commits
edddaf1
[Config][HybridModel] Enhance layer determination logic for hybrid mo…
qscqesze Mar 13, 2025
1719721
[Refactor][MiniMaxText] Update cache mapping reference in MiniMaxText…
qscqesze Mar 13, 2025
d61b446
[Refactor][AsyncLLM] Improve comments and clean up unused variables i…
qscqesze Mar 13, 2025
faa8c6c
[Refactor][MiniMaxText] Clean up imports and improve code formatting …
qscqesze Mar 13, 2025
7c65c03
[Refactor][Config] Improve formatting and error handling in ModelConf…
qscqesze Mar 13, 2025
43f0152
[Refactor][Config] Enhance layer counting logic in ModelConfig and im…
qscqesze Mar 13, 2025
d6e7798
[Refactor][Config] Improve formatting in VllmConfig for better readab…
qscqesze Mar 13, 2025
5504867
[Refactor][LightningAttn] Update grid configuration in _attention fun…
qscqesze Mar 13, 2025
6c3f08b
[Refactor][LightningAttn] Simplify grid configuration in _attention f…
qscqesze Mar 13, 2025
fad01e8
[Refactor][MiniMaxText] Enhance forward method in MiniMaxText01 model…
qscqesze Mar 13, 2025
f674279
[Refactor][MiniMaxText] Remove max_context_len parameter from MiniMax…
qscqesze Mar 13, 2025
cb7c074
[Refactor][MiniMaxText] Update forward method parameters in MiniMaxTe…
qscqesze Mar 13, 2025
d48d375
[Refactor][MiniMaxText] Add context_lens_tensor and slot_mapping to A…
qscqesze Mar 13, 2025
989b488
[Refactor][MiniMaxText] Remove unnecessary property methods from Atte…
qscqesze Mar 13, 2025
0d4822d
[Refactor][MiniMaxText] Simplify weight handling methods and improve …
qscqesze Mar 13, 2025
b682944
[Refactor][MiniMaxText] Clean up and optimize weight handling and par…
qscqesze Mar 13, 2025
9e9704a
[Refactor][MiniMaxText] Fix index handling in prefill loop of MiniMax…
qscqesze Mar 13, 2025
5de6b1b
[Refactor][MiniMaxText] Streamline handling of attn_metadata in forwa…
qscqesze Mar 13, 2025
96c6dff
[Refactor][MiniMaxText] Consolidate attn_metadata handling in MiniMax…
qscqesze Mar 13, 2025
fc6ab05
[Refactor][MiniMaxText] Remove kv_caches parameter from multiple meth…
qscqesze Mar 13, 2025
a7f2e3a
[Refactor][MiniMaxText] Enhance kv_cache handling in MiniMaxText01 mo…
qscqesze Mar 13, 2025
bc17ba9
[Refactor][MiniMaxText] Remove unused kv_caches parameter from _clear…
qscqesze Mar 13, 2025
2f873a9
[Refactor][MiniMaxText] Initialize kv_cache in multiple classes of Mi…
qscqesze Mar 13, 2025
152b430
[Refactor][MiniMaxText] Remove kv_cache initialization from MiniMaxTe…
qscqesze Mar 13, 2025
2e59aa7
[Refactor][MiniMaxText] Update forward method in MiniMaxText01 model …
qscqesze Mar 13, 2025
9ff34fc
[Refactor][MiniMaxText] Update forward method in MiniMaxText01 model …
qscqesze Mar 13, 2025
ed4ddca
[Refactor][MiniMaxText] Remove redundant closing parenthesis in MiniM…
qscqesze Mar 13, 2025
4bee45b
[Refactor][MiniMaxText] Set default number of hidden layers to 8 in M…
qscqesze Mar 13, 2025
e4fd74e
[Refactor][MiniMaxText] Remove hardcoded number of hidden layers in M…
qscqesze Mar 13, 2025
495a39a
[Refactor][MiniMaxText] Update forward method in MiniMaxText01 model …
qscqesze Mar 13, 2025
e0dec3a
[Refactor][MiniMaxText] Update forward method in MiniMaxText01 model …
qscqesze Mar 13, 2025
8f9891f
[Refactor][MiniMaxText] Initialize MinimaxCacheManager in MiniMaxText…
qscqesze Mar 13, 2025
1774c66
[Refactor][MiniMaxText] Simplify kv_cache handling in MiniMaxText01 m…
qscqesze Mar 13, 2025
88ec7c6
[Refactor][MiniMaxText] Reorder parameters in forward method of MiniM…
qscqesze Mar 13, 2025
2aa1c0d
[Refactor][MiniMaxText] Add attn_metadata parameter to forward method…
qscqesze Mar 13, 2025
37e7fec
[Refactor][MiniMaxText] Remove kv_cache initialization in MiniMaxText…
qscqesze Mar 13, 2025
f1c8fb6
[Refactor][MiniMaxText] Update forward method in MiniMaxText01 model …
qscqesze Mar 13, 2025
fc361d8
[Refactor][MiniMaxText] Correctly define NUM_FBLOCK as a constexpr in…
qscqesze Mar 13, 2025
be625bf
[Refactor][LightningAttention] Improve code readability and consisten…
qscqesze Mar 13, 2025
95bdd4a
[Refactor][MiniMaxText] Simplify forward method in MiniMaxText01 mode…
qscqesze Mar 13, 2025
5b619bb
[Refactor][MiniMaxText] Update kv_cache initialization in MiniMaxText…
qscqesze Mar 13, 2025
f46e997
[Refactor][LightningAttention] Enhance code readability in lightning_…
qscqesze Mar 13, 2025
fce7cae
[Refactor][LightningAttention] Optimize grid calculations in lightnin…
qscqesze Mar 13, 2025
f16f818
[Refactor][MiniMaxText] Remove unused weight2param_match and weight2p…
qscqesze Mar 18, 2025
09c9cea
[Refactor][MiniMaxText] Refactor layer initialization in MiniMaxText0…
qscqesze Mar 18, 2025
20d811a
[Update][SupportedModels] Add MiniMaxText01 model to the supported mo…
qscqesze Mar 18, 2025
aea72dc
[Refactor][MiniMaxText] Clean up formatting and improve readability i…
qscqesze Mar 18, 2025
925c02f
Merge remote-tracking branch 'origin/main' into qinggangying/vllm
qscqesze Mar 19, 2025
65c8274
[Model] Refactor layer block type handling in ModelConfig for improve…
qscqesze Mar 19, 2025
01c5f9e
Merge branch 'vllm-project:main' into qinggangying/vllm
qscqesze Mar 20, 2025
5a02fdf
Merge branch 'vllm-project:main' into qinggangying/vllm
qscqesze Mar 20, 2025
f0e54a7
Refactor MiniMaxText01 model: import make_layers utility and initiali…
qscqesze Mar 20, 2025
61b3820
Enhance error handling in model execution: return None for None hidde…
qscqesze Mar 20, 2025
727b572
Refactor MiniMaxText01 model: move None check for attn_metadata to af…
qscqesze Mar 20, 2025
09d044b
Refactor MiniMaxText01 model: replace direct access to attn_metadata.…
qscqesze Mar 20, 2025
01c008a
Merge branch 'vllm-project:main' into qinggangying/vllm
qscqesze Mar 25, 2025
078a836
[Enhancement][Tests] Add comprehensive tests for lightning attention …
qscqesze Mar 25, 2025
42dc9b8
[Refactor][GPU] Simplify dummy run and sampler execution in GPU model…
qscqesze Mar 25, 2025
8005212
[Refactor][Tests] Clean up formatting and comments in lightning atten…
qscqesze Mar 25, 2025
d30be90
[Refactor][Attention] Enhance kernel functions and parameter handling…
qscqesze Mar 25, 2025
c0581a3
[Refactor][Attention] Improve clarity and structure in lightning atte…
qscqesze Mar 25, 2025
4036f88
[Refactor][Tests] Update decay calculation in linear decode forward test
qscqesze Mar 25, 2025
44d828b
[Refactor][Tests] Update decay handling in lightning attention tests
qscqesze Mar 25, 2025
25353a6
[Refactor][Tests] Update lightning attention tests to skip incompatib…
qscqesze Mar 25, 2025
6147492
[Refactor][Tests] Enhance lightning attention tests to handle bfloat1…
qscqesze Mar 25, 2025
75fcabc
[Refactor][Tests] Remove bfloat16 handling and clean up lightning att…
qscqesze Mar 25, 2025
68d4549
[Refactor][Tests] Update decay tensor handling in lightning attention…
qscqesze Mar 25, 2025
1fdb4cc
[Refactor][Tests] Remove deprecated lightning attention tests
qscqesze Mar 25, 2025
358ba2d
[Refactor][Tests] Expand data type support in lightning attention tests
qscqesze Mar 25, 2025
703af1d
Fix variable name in lightning attention layer to correct tensor load…
qscqesze Mar 25, 2025
e8d5724
Refactor lightning attention integration in MiniMaxText01 model
qscqesze Mar 25, 2025
0c6a904
Add assertion for dimension divisibility in lightning attention
qscqesze Mar 25, 2025
8663e13
Add reference implementation for linear attention decoding in tests
qscqesze Mar 27, 2025
e61d6e3
Enhance lightning attention tests with reference implementation and a…
qscqesze Mar 28, 2025
57471b8
Fix typos and enhance data type consistency in lightning attention im…
qscqesze Mar 28, 2025
ddabd28
Enhance data type handling in linear decode function of lightning att…
qscqesze Mar 28, 2025
7f32996
Refactor linear attention decoding kernel for improved clarity and pe…
qscqesze Mar 28, 2025
2ed7f2d
Refactor and enhance lightning attention tests for clarity and functi…
qscqesze Mar 28, 2025
1107317
Refactor linear attention decoding kernel to improve efficiency and c…
qscqesze Mar 28, 2025
19ae251
Refactor linear attention decoding kernel and tests for improved clar…
qscqesze Mar 28, 2025
2f1bed0
Enhance linear decode tests by incorporating padding mask for accurat…
qscqesze Mar 28, 2025
7bffe30
Refactor linear attention decoding kernel to improve handling of padd…
qscqesze Mar 28, 2025
e791c9f
Add reference test for lightning attention consistency
qscqesze Mar 28, 2025
2bd8fcb
Fix typo in reference implementation comment and streamline tensor ha…
qscqesze Mar 28, 2025
5483d26
Refactor lightning attention test for improved clarity and consistency
qscqesze Mar 28, 2025
19b1264
Update lightning attention tests to relax tolerance levels and addres…
qscqesze Mar 28, 2025
ea80155
Update tolerance levels in lightning attention tests for improved acc…
qscqesze Mar 28, 2025
33eecfa
Refactor reference implementation of lightning attention for clarity …
qscqesze Mar 28, 2025
c2abab4
Refactor lightning attention test to improve error handling and data …
qscqesze Mar 28, 2025
2c04f99
Update data type handling in lightning attention test for consistency
qscqesze Mar 28, 2025
c134e79
Update data type in lightning attention test to float32 for consistency
qscqesze Mar 28, 2025
2850c68
Refactor lightning attention implementation for improved efficiency a…
qscqesze Mar 28, 2025
2ac5d73
Refactor lightning attention implementation for enhanced efficiency a…
qscqesze Mar 28, 2025
11c9b85
Enhance numerical stability and efficiency in lightning attention imp…
qscqesze Mar 28, 2025
0aaac31
Optimize lightning attention implementation for efficiency and clarity
qscqesze Mar 28, 2025
637ff5e
Refactor lightning attention test for improved resource management an…
qscqesze Mar 28, 2025
e4291f5
Enhance lightning attention implementation for improved numerical sta…
qscqesze Mar 28, 2025
84ef836
Refine lightning attention implementation to match output shape and e…
qscqesze Mar 28, 2025
cdf7ae6
Update lightning attention test parameters for simplification
qscqesze Mar 28, 2025
05b6ac6
Refactor lightning attention test for improved readability
qscqesze Mar 28, 2025
e61ac58
Refactor lightning attention tests to simplify tensor initialization
qscqesze Mar 31, 2025
4d9b75d
Fix formatting in lightning attention test by removing unnecessary wh…
qscqesze Mar 31, 2025
56a9f5d
Refactor ConstantSizeCache and MiniMaxText01 for improved clarity and…
qscqesze Mar 31, 2025
f252f56
Update tensor initialization in lightning attention tests to use rand…
qscqesze Apr 1, 2025
73fd424
Update lightning attention test to initialize KV cache with zeros and…
qscqesze Apr 1, 2025
1fb2336
Refactor tensor initialization in lightning attention tests to use sc…
qscqesze Apr 1, 2025
e5cec6f
Refactor formatting in lightning attention tests for improved readabi…
qscqesze Apr 1, 2025
c7d93c1
Merge branch 'vllm-project:main' into qinggangying/vllm
qscqesze Apr 1, 2025
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -482,6 +482,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
*
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
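
With this documentation entry (and the registry entry below) in place, the model should be loadable through the usual vLLM entry points. A minimal sketch, not taken from this PR: the engine arguments such as `tensor_parallel_size` are illustrative assumptions that depend on your hardware.

```python
from vllm import LLM, SamplingParams

# Hypothetical invocation: MiniMax-Text-01 is a very large model, so the
# parallelism setting below is a placeholder, not a value from this PR.
llm = LLM(
    model="MiniMaxAI/MiniMax-Text-01",
    trust_remote_code=True,   # required, per the registry entry in this PR
    tensor_parallel_size=8,   # assumption: adjust to the available GPUs
)

outputs = llm.generate(
    ["Explain linear attention in one sentence."],
    SamplingParams(max_tokens=64, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```
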
286 changes: 286 additions & 0 deletions tests/kernels/test_lightning_attn.py
@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.lightning_attn import (
linear_decode_forward_triton)
from vllm.platforms import current_platform

NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
"""Reference implementation of lightning attention core algorithm

The difference from the main implementation is that this processes
each step sequentially, instead of using parallelized triton kernels
"""
B, H, S, D = q.shape
E = v.shape[-1]
dtype = q.dtype
output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)

# Use clone() to ensure an independent copy
if kv_history is None:
kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
else:
kv_cache = kv_history.clone()

# More efficient implementation
# Convert decay factors to matrix form
if ed.dim() == 1:
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)

for b in range(B):
for step in range(S):
# Process all heads at once for this position
q_bs = q[b, :, step] # [H, D]
k_bs = k[b, :, step] # [H, D]
v_bs = v[b, :, step] # [H, E]

# Calculate KV outer products for all heads
for h in range(H):
# Calculate KV outer product
kv_outer = torch.outer(k_bs[h], v_bs[h])

# Update KV cache with decay
# Note: Using the same order as in the Triton kernel
kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer

# Calculate attention output
output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])

# Match the shape returned by the actual implementation
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
# where dimension 2 contains both KV and KV history
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
dim=2) # [B, H, 2, D, E]

return output, final_kv_cache


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
"""Reference implementation: linear attention decode function"""
B, H, _, D = q.shape
output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)

# Calculate decay factors once (more efficient)
decay = torch.exp(-slope_rate).view(-1, 1, 1) # [H, 1, 1]

# Process each batch
for b in range(B):
slot_id = slot_idx[b].item()

# Skip padding positions
if slot_id == -1:
continue

# Process all heads at once for this batch
q_b = q[b, :, 0] # [H, D]
k_b = k[b, :, 0] # [H, D]
v_b = v[b, :, 0] # [H, D]

# Process each attention head
for h in range(H):
# Get current query, key and value
q_bh = q_b[h]
k_bh = k_b[h]
v_bh = v_b[h]

# Get cache
kv_cache_old = kv_caches[b, h]

# Calculate new key-value outer product
kv_outer = torch.outer(k_bh, v_bh)

# Apply decay and update cache
kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old

# Calculate output
out_h = torch.matmul(q_bh, kv_new)

# Update output and cache
output[b, h * D:(h + 1) * D] = out_h
kv_caches[b, h] = kv_new

return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
batch_size: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")

kv_caches_copy = kv_caches.clone()

slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)

slot_idx = torch.arange(batch_size, device="cuda")

triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)

reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)
torch.testing.assert_close(triton_output,
reference_output,
rtol=1e-1,
atol=1e-1)
Comment on lines +159 to +160

Member:
1e-1 seems pretty high for rtol and atol here, so wondering if that's related to the above initialization

torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)

assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)

batch_size = 4
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

kv_caches = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")

kv_caches_copy = kv_caches.clone()

slope_rate = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)

slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
slope_rate, slot_idx)

reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
slope_rate, slot_idx)

padding_mask = (slot_idx
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)

triton_masked = triton_output[padding_mask]
reference_masked = reference_output[padding_mask]

atol, rtol = 1.5e-1, 1.5e-1

valid_indices = slot_idx != -1

for i in range(batch_size):
if valid_indices[i] > 0:
torch.testing.assert_close(kv_caches[i],
kv_caches_copy[i],
rtol=rtol,
atol=atol)

torch.testing.assert_close(triton_masked,
reference_masked,
rtol=rtol,
atol=atol)

assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
batch_size: int,
num_heads: int,
head_size: int,
seq_len: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42)

base = 0.01
q = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
k = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(
batch_size, num_heads, seq_len, head_size, dtype=dtype)

ed = torch.zeros(num_heads, device="cuda")
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)

kv_history = base * torch.randn(batch_size,
num_heads,
head_size,
head_size,
dtype=dtype,
device="cuda")

kv_history_clone = kv_history.clone()

ref_output, ref_kv_cache = reference_lightning_attention(
q, k, v, ed, 256, kv_history)

from vllm.model_executor.layers.lightning_attn import lightning_attention
actual_output, actual_kv_cache = lightning_attention(
q, k, v, ed, 256, kv_history_clone)

atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
torch.testing.assert_close(ref_kv_cache,
actual_kv_cache,
rtol=rtol,
atol=atol)

assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
assert ref_kv_cache.shape == actual_kv_cache.shape
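
For orientation, both the Triton kernels and the reference implementations above compute the same per-step linear-attention recurrence: the per-head KV state is decayed and updated with the outer product of key and value, then the query is read against that state. A condensed sketch of that recurrence (the function name and shapes are chosen to match the tests above; this is not the kernel itself):

```python
import torch


def linear_attention_step(q_t, k_t, v_t, kv_state, slope):
    """One decode step of the recurrence exercised by the tests above.

    q_t, k_t: [H, D]; v_t: [H, E]; kv_state: [H, D, E]; slope: [H]
    """
    decay = torch.exp(-slope).view(-1, 1, 1)                 # [H, 1, 1]
    # Decay the running KV state and add the new key-value outer product.
    kv_state = decay * kv_state + torch.einsum("hd,he->hde", k_t, v_t)
    # Read the state with the query to produce this step's output.
    out_t = torch.einsum("hd,hde->he", q_t, kv_state)        # [H, E]
    return out_t, kv_state


# Tiny smoke check with the same shapes the tests use (E == D == head_size).
H, D = 4, 64
q = torch.randn(H, D)
k = torch.randn(H, D)
v = torch.randn(H, D)
state = torch.zeros(H, D, D)
out, state = linear_attention_step(q, k, v, state, 0.1 * torch.arange(1, H + 1))
assert out.shape == (H, D) and state.shape == (H, D, D)
```
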
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -176,6 +176,8 @@ def check_available_online(
trust_remote_code=True),
"MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
"QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501
42 changes: 25 additions & 17 deletions vllm/config.py
@@ -971,26 +971,34 @@ def get_num_layers_by_block_type(
return sum(not bc.attention.no_op
for bc in block_configs[start:end])
else:
# Hybrid model
# Hybrid model Jamba
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)
if layers_block_type_value is None:
raise ValueError("The model is an hybrid without a "
"layers_block_type in the hf_config, "
"cannot determine the num of "
f"{block_type.value} layers")

if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)
return sum(t == block_type.value
for t in layers_block_type_value[start:end])

# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])

if layers_block_type_value is None and attn_type_list is None:
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
f"{block_type.value} layers")

return sum(t == block_type.value
for t in layers_block_type_value[start:end])
return sum(t == 1 for t in attn_type_list[start:end])

def get_multimodal_config(self) -> "MultiModalConfig":
"""
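
To make the new branch concrete: a MiniMax-style hf_config carries an `attn_type_list` instead of `layers_block_type`, and per the code above, entries equal to 1 are counted as attention layers within the slice of layers owned by the current rank. A minimal illustration (the "0 means linear attention" labeling is an assumption for the example; the code only checks `t == 1`):

```python
# Hypothetical MiniMax-style layer map: 1 = full attention, 0 = other
# (e.g. linear attention); only the `t == 1` check matters to the code above.
attn_type_list = [0, 0, 0, 1, 0, 0, 0, 1]

# Suppose this rank owns layers [start:end).
start, end = 2, 6

# Same counting expression as in get_num_layers_by_block_type above.
num_attention_layers = sum(t == 1 for t in attn_type_list[start:end])
print(num_attention_layers)  # -> 1 (only layer index 3 falls in this slice)
```
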
7 changes: 5 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -303,8 +303,11 @@ async def step_async(
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
if not scheduler_outputs.is_empty():
# this will cause mamba_cache/minimax_cache failed
# to release finished_requests_ids of the last steps
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
Comment on lines -306 to +310

Member:
Could you describe in a bit more detail the problem you hit here?

Contributor Author:
finished_requests_ids is normally consumed, and the corresponding mamba-cache entries released, during the next model inference step. But when scheduler_outputs is empty, no inference step runs, and this call still clears the list (self._finished_requests_ids = list() in vllm/core/scheduler.py), so the cache never gets the chance to release those requests; this can lead to a series of issues.


# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
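
To make the scenario from the thread above concrete, here is a self-contained toy sketch. None of these classes are vLLM code; the names are hypothetical stand-ins showing why consuming the finished IDs on an empty scheduler step loses the IDs the state cache still needs to free, and how gating the call on a non-empty step avoids it.

```python
class ToyScheduler:
    """Stand-in for the scheduler's finished-request bookkeeping."""

    def __init__(self):
        self._finished_requests_ids = []

    def finish(self, request_id):
        self._finished_requests_ids.append(request_id)

    def get_and_reset_finished_requests_ids(self):
        ids, self._finished_requests_ids = self._finished_requests_ids, []
        return ids


class ToyStateCache:
    """Stand-in for mamba_cache/minimax_cache: it only frees slots when a
    model-inference step actually runs and hands it the finished IDs."""

    def __init__(self):
        self.slots = {}

    def release(self, finished_requests_ids):
        for rid in finished_requests_ids:
            self.slots.pop(rid, None)


def run_step(scheduler, cache, step_is_empty, gate_on_nonempty):
    """One engine step: optionally consume finished IDs, then (only if the
    step is non-empty) run the 'model', which is where the cache releases."""
    if gate_on_nonempty and step_is_empty:
        return  # fixed behaviour: leave the IDs queued for the next real step
    finished = scheduler.get_and_reset_finished_requests_ids()
    if not step_is_empty:
        cache.release(finished)  # only a real model step reaches the cache


# Buggy order: an empty step consumes the IDs without ever reaching the cache.
scheduler, cache = ToyScheduler(), ToyStateCache()
cache.slots["req-0"] = "state"
scheduler.finish("req-0")
run_step(scheduler, cache, step_is_empty=True, gate_on_nonempty=False)
run_step(scheduler, cache, step_is_empty=False, gate_on_nonempty=False)
assert "req-0" in cache.slots  # leaked: the cache slot is never freed

# Fixed order (as in this PR): the IDs survive the empty step and are
# released on the next non-empty one.
scheduler, cache = ToyScheduler(), ToyStateCache()
cache.slots["req-0"] = "state"
scheduler.finish("req-0")
run_step(scheduler, cache, step_is_empty=True, gate_on_nonempty=True)
run_step(scheduler, cache, step_is_empty=False, gate_on_nonempty=True)
assert "req-0" not in cache.slots  # released
```
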