Skip to content

Commit 66072b3

Browse files
authored
[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support (#24883)
Signed-off-by: asafg <[email protected]>
1 parent 3ed1ec4 commit 66072b3

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

tests/models/language/generation/test_hybrid.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -418,14 +418,17 @@ def test_full_cuda_graph(
418418
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
419419
@pytest.mark.parametrize("max_tokens", [64])
420420
@pytest.mark.parametrize("num_logprobs", [5])
421-
def test_fp32_state(
421+
@pytest.mark.parametrize("cache_dtype_param",
422+
["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
423+
def test_fp32_cache_state(
422424
hf_runner,
423425
vllm_runner,
424426
example_prompts,
425427
monkeypatch,
426428
model: str,
427429
max_tokens: int,
428430
num_logprobs: int,
431+
cache_dtype_param: str,
429432
) -> None:
430433

431434
try:
@@ -443,13 +446,13 @@ def test_fp32_state(
443446
m.setenv("VLLM_USE_V1", "0")
444447
with vllm_runner(model,
445448
max_num_seqs=MAX_NUM_SEQS,
446-
mamba_ssm_cache_dtype="float32") as vllm_model:
449+
**{cache_dtype_param: "float32"}) as vllm_model:
447450
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
448451
example_prompts, max_tokens, num_logprobs)
449452

450453
with vllm_runner(model,
451454
max_num_seqs=MAX_NUM_SEQS,
452-
mamba_ssm_cache_dtype="float32") as vllm_model:
455+
**{cache_dtype_param: "float32"}) as vllm_model:
453456
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
454457
example_prompts, max_tokens, num_logprobs)
455458

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -415,6 +415,9 @@ def causal_conv1d_fn(
415415
activation = "silu"
416416

417417
args = None
418+
# Store original dtype to cast back at the end
419+
original_x_dtype = x.dtype
420+
x = x.to(conv_states.dtype)
418421
out = torch.empty_like(x)
419422
if metadata is not None:
420423
cu_seqlen = metadata.cu_seqlen
@@ -613,7 +616,7 @@ def grid(META):
613616
BLOCK_N=256,
614617
num_stages=2,
615618
)
616-
return out
619+
return out.to(original_x_dtype)
617620

618621

619622
@triton.jit()
@@ -973,6 +976,9 @@ def causal_conv1d_update(
973976
activation = "silu" if activation is True else None
974977
elif activation is not None:
975978
assert activation in ["silu", "swish"]
979+
980+
original_x_dtype = x.dtype
981+
x = x.to(conv_state.dtype)
976982
unsqueeze = query_start_loc is None and x.dim() == 2
977983
if unsqueeze:
978984
# make it (batch, dim, seqlen) with seqlen == 1
@@ -1081,4 +1087,4 @@ def grid(META):
10811087
)
10821088
if unsqueeze:
10831089
out = out.squeeze(-1)
1084-
return out
1090+
return out.to(original_x_dtype)

0 commit comments

Comments
 (0)