[Model][MiniMaxText01] Support MiniMaxText01 model inference #13454
Changes from all commits
@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.lightning_attn import (
    linear_decode_forward_triton)
from vllm.platforms import current_platform

NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Reference implementation of the lightning attention core algorithm.

    The difference from the main implementation is that this processes
    each step sequentially, instead of using parallelized Triton kernels.
    """
    B, H, S, D = q.shape
    E = v.shape[-1]
    dtype = q.dtype
    output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)

    # Use clone() to ensure an independent copy
    if kv_history is None:
        kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
    else:
        kv_cache = kv_history.clone()

    # Convert decay factors to matrix form
    if ed.dim() == 1:
        decay = torch.exp(-ed).view(1, -1, 1, 1)
    else:
        decay = torch.exp(-ed)

    for b in range(B):
        for step in range(S):
            # Process all heads at once for this position
            q_bs = q[b, :, step]  # [H, D]
            k_bs = k[b, :, step]  # [H, D]
            v_bs = v[b, :, step]  # [H, E]

            # Calculate KV outer products for all heads
            for h in range(H):
                # Calculate the KV outer product
                kv_outer = torch.outer(k_bs[h], v_bs[h])

                # Update the KV cache with decay
                # Note: using the same order as in the Triton kernel
                kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer

                # Calculate the attention output
                output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])

    # Match the shape returned by the actual implementation, which is a
    # tensor of shape [B, H, 2, D, E] where dimension 2 contains both
    # KV and KV history
    kv_reshaped = kv_cache.unsqueeze(2)  # [B, H, 1, D, E]
    final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
                               dim=2)  # [B, H, 2, D, E]

    return output, final_kv_cache


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation: linear attention decode function."""
    B, H, _, D = q.shape
    output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)

    # Calculate decay factors once (more efficient)
    decay = torch.exp(-slope_rate).view(-1, 1, 1)  # [H, 1, 1]

    # Process each batch
    for b in range(B):
        slot_id = slot_idx[b].item()

        # Skip padding positions
        if slot_id == -1:
            continue

        # Process all heads at once for this batch
        q_b = q[b, :, 0]  # [H, D]
        k_b = k[b, :, 0]  # [H, D]
        v_b = v[b, :, 0]  # [H, D]

        # Process each attention head
        for h in range(H):
            # Get the current query, key and value
            q_bh = q_b[h]
            k_bh = k_b[h]
            v_bh = v_b[h]

            # Get the cache
            kv_cache_old = kv_caches[b, h]

            # Calculate the new key-value outer product
            kv_outer = torch.outer(k_bh, v_bh)

            # Apply decay and update the cache
            kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old

            # Calculate the output
            out_h = torch.matmul(q_bh, kv_new)

            # Update the output and the cache
            output[b, h * D:(h + 1) * D] = out_h
            kv_caches[b, h] = kv_new

    return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
    batch_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.arange(batch_size, device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)
    torch.testing.assert_close(triton_output,
                               reference_output,
                               rtol=1e-1,
                               atol=1e-1)
Review comment on lines +159 to +160: 1e-1 seems pretty high for
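Following up on that concern, below is a minimal sketch (not part of this PR; the helper name and the specific bounds are assumptions) of how the tolerance could be selected from the tensor dtype rather than hard-coded at 1e-1:

```python
import torch

# Hypothetical helper: choose assert_close tolerances from the dtype instead
# of using a single loose 1e-1 bound for every precision.
_TOLERANCES = {
    torch.float32: dict(rtol=1e-3, atol=1e-3),
    torch.float16: dict(rtol=1e-2, atol=1e-2),
    torch.bfloat16: dict(rtol=2e-2, atol=2e-2),
}


def assert_close_by_dtype(actual: torch.Tensor, expected: torch.Tensor):
    tol = _TOLERANCES.get(actual.dtype, dict(rtol=1e-1, atol=1e-1))
    torch.testing.assert_close(actual, expected, **tol)
```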
    torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    batch_size = 4
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    padding_mask = (slot_idx
                    != -1).unsqueeze(1).expand(-1, num_heads * head_size)

    triton_masked = triton_output[padding_mask]
    reference_masked = reference_output[padding_mask]

    atol, rtol = 1.5e-1, 1.5e-1

    valid_indices = slot_idx != -1

    for i in range(batch_size):
        if valid_indices[i]:
            torch.testing.assert_close(kv_caches[i],
                                       kv_caches_copy[i],
                                       rtol=rtol,
                                       atol=atol)

    torch.testing.assert_close(triton_masked,
                               reference_masked,
                               rtol=rtol,
                               atol=atol)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
    batch_size: int,
    num_heads: int,
    head_size: int,
    seq_len: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    base = 0.01
    q = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)
    k = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)
    v = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)

    ed = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        ed[h] = 0.1 * (h + 1)

    kv_history = base * torch.randn(batch_size,
                                    num_heads,
                                    head_size,
                                    head_size,
                                    dtype=dtype,
                                    device="cuda")

    kv_history_clone = kv_history.clone()

    ref_output, ref_kv_cache = reference_lightning_attention(
        q, k, v, ed, 256, kv_history)

    from vllm.model_executor.layers.lightning_attn import lightning_attention
    actual_output, actual_kv_cache = lightning_attention(
        q, k, v, ed, 256, kv_history_clone)

    atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
    torch.testing.assert_close(ref_kv_cache,
                               actual_kv_cache,
                               rtol=rtol,
                               atol=atol)

    assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
    assert ref_kv_cache.shape == actual_kv_cache.shape
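For context on what these reference helpers compute: each decode step applies the linear-attention recurrence kv_t = exp(-slope) * kv_{t-1} + k_t ⊗ v_t and reads out out_t = q_t · kv_t. The snippet below is an illustrative single-head sketch of that update (not part of the PR; tensor sizes are arbitrary):

```python
import torch


def decode_step(q, k, v, kv, slope):
    """One single-head linear-attention decode step (illustrative only)."""
    decay = torch.exp(torch.tensor(-slope, dtype=kv.dtype))
    kv = decay * kv + torch.outer(k, v)  # decayed state plus new outer product
    out = q @ kv                         # per-token output of size E
    return out, kv


D, E = 4, 4
kv_state = torch.zeros(D, E)
q, k, v = torch.randn(D), torch.randn(D), torch.randn(E)
out, kv_state = decode_step(q, k, v, kv_state, slope=0.1)
print(out.shape)  # torch.Size([4])
```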
@@ -303,8 +303,11 @@ async def step_async(
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs

-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
+            if not scheduler_outputs.is_empty():
+                # Otherwise, mamba_cache/minimax_cache would fail to
+                # release the finished_requests_ids of the last steps
+                finished_requests_ids = self.scheduler[
+                    virtual_engine].get_and_reset_finished_requests_ids()
Review comment on lines -306 to +310: Could you describe in a bit more detail the problem you hit here?

Reply: Because
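To make the failure mode concrete, here is a hypothetical, heavily simplified sketch (the class and function names are invented for illustration, not vLLM's actual API): if the finished ids are fetched and reset on a step whose scheduler output is empty, the model runner, and with it the mamba/minimax cache manager, never receives them, so their cache slots are never released. Guarding the reset with `is_empty()` keeps the ids queued until a step that actually runs.

```python
from typing import List


class FakeScheduler:
    """Hypothetical stand-in for the scheduler's finished-id bookkeeping."""

    def __init__(self) -> None:
        self._finished: List[str] = []

    def mark_finished(self, request_id: str) -> None:
        self._finished.append(request_id)

    def get_and_reset_finished_requests_ids(self) -> List[str]:
        ids, self._finished = self._finished, []
        return ids


def step(scheduler: FakeScheduler, output_is_empty: bool) -> List[str]:
    # Mirror of the patched logic: only consume (and reset) the finished ids
    # when the step actually executes, so the cache manager receives them.
    if not output_is_empty:
        return scheduler.get_and_reset_finished_requests_ids()
    return []


sched = FakeScheduler()
sched.mark_finished("req-0")
assert step(sched, output_is_empty=True) == []           # ids retained
assert step(sched, output_is_empty=False) == ["req-0"]   # released on a real step
```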

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0: