[Model] Jamba support #4115
Changes from all commits
New file (Mamba dependencies):

@@ -0,0 +1,3 @@
+# Mamba dependencies
+mamba-ssm>=1.2.2
+causal-conv1d>=1.2.0
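Assuming this file is consumed like vLLM's other per-feature requirements files, the pins can also be installed directly ahead of time, e.g. pip install "mamba-ssm>=1.2.2" "causal-conv1d>=1.2.0".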
New test file (65 added lines):

@@ -0,0 +1,65 @@
import pytest

MODELS = ["ai21labs/Jamba-tiny-random"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # To pass the small model tests, we need full precision.
    assert dtype == "float"

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test verifies that the Jamba inner state is cleaned up between
    # steps. If it is not cleaned up, a ValueError is expected.
    try:
        with vllm_runner(model, dtype=dtype) as vllm_model:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail("Jamba inner state wasn't cleaned up between steps, "
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, dtype=dtype) as vllm_model:
        # This test verifies that the model's extra_repr
        # can be printed correctly.
        print(vllm_model.model.llm_engine.model_executor.driver_worker.
              model_runner.model)
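For context, here is a minimal, hypothetical sketch of the failure mode test_state_cleanup guards against: a per-request state cache that errors once it fills up because finished requests were never released. The StateCache class and CAPACITY constant are illustrative stand-ins, not vLLM's actual Mamba cache implementation.

from typing import Dict, List

CAPACITY = 100  # illustrative: max concurrent requests with live state


class StateCache:
    """Maps request ids to their per-request (conv/ssm) state."""

    def __init__(self) -> None:
        self._states: Dict[str, object] = {}

    def get_or_create(self, request_id: str) -> object:
        if request_id not in self._states:
            if len(self._states) >= CAPACITY:
                # The failure mode the test guards against: finished
                # requests were never released, so the cache fills up.
                raise ValueError("state cache exhausted; finished requests "
                                 "were not released")
            # A real implementation would allocate state tensors here.
            self._states[request_id] = object()
        return self._states[request_id]

    def release(self, finished_requests_ids: List[str]) -> None:
        # Called once per engine step with the ids the scheduler reports
        # as finished (cf. finished_requests_ids in this PR).
        for req_id in finished_requests_ids:
            self._states.pop(req_id, None)

Running 10 rounds of 100 prompts, as the test does, only stays under such a capacity limit if release() is actually invoked between steps.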
Changes to the engine step() method:

@@ -846,6 +846,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             "as performance will be severely degraded otherwise.")
         seq_group_metadata_list, scheduler_outputs = self.scheduler[
             0].schedule()
+        finished_requests_ids = self.scheduler[
+            0].get_and_reset_finished_requests_ids()

         if not scheduler_outputs.is_empty():
             execute_model_req = ExecuteModelRequest(
@@ -855,7 +857,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
                 num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                 running_queue_size=scheduler_outputs.running_queue_size,
-            )
+                finished_requests_ids=finished_requests_ids)

             output = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
         else:

Review thread on this hunk:

Reviewer: @mzusman If [rest of comment truncated in the page capture]

mzusman: That's right 👍 Nice catch, I'll open a PR to fix it #6266

Reviewer: I have now realized that this code path is not hit, it seems, at least in common circumstances, due to https://github.com/vllm-project/vllm/blob/main/vllm/engine/async_llm_engine.py#L563, but probably good to fix anyway!
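To make the new plumbing concrete, below is a minimal, hypothetical sketch of the drain-and-reset pattern behind get_and_reset_finished_requests_ids. The method name comes from this PR; the simplified Scheduler and ExecuteModelRequest bodies (and the mark_finished hook) are illustrative stand-ins, not vLLM's actual classes.

from dataclasses import dataclass, field
from typing import List


class Scheduler:
    """Illustrative stand-in: accumulates ids of requests that have
    finished since the last engine step."""

    def __init__(self) -> None:
        self._finished_requests_ids: List[str] = []

    def mark_finished(self, request_id: str) -> None:
        # Hypothetical hook, called when a request completes or is aborted.
        self._finished_requests_ids.append(request_id)

    def get_and_reset_finished_requests_ids(self) -> List[str]:
        # Drain and clear in one call so each finished id reaches the
        # model executor exactly once per engine step.
        finished = self._finished_requests_ids
        self._finished_requests_ids = []
        return finished


@dataclass
class ExecuteModelRequest:
    """Illustrative stand-in: the PR threads finished_requests_ids through
    this request object so the model runner can release the per-request
    Mamba state."""
    finished_requests_ids: List[str] = field(default_factory=list)


# Each engine step then mirrors the hunk above:
scheduler = Scheduler()
scheduler.mark_finished("req-0")
req = ExecuteModelRequest(
    finished_requests_ids=scheduler.get_and_reset_finished_requests_ids())
assert req.finished_requests_ids == ["req-0"]
assert scheduler.get_and_reset_finished_requests_ids() == []  # already drained

The drain-and-reset design matters for stateful models like Jamba: if the list were read without being cleared, the same ids would be released twice; if it were cleared without being read, finished state would leak, which is exactly what test_state_cleanup checks for.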