Skip to content

Commit 1248e85

Browse files
wenxcs, Your Name, zelinms, linzeqipku
authored
[Model] Adding support for MSFT Phi-3.5-MoE (#7729)
Co-authored-by: Your Name <[email protected]> Co-authored-by: Zeqi Lin <[email protected]> Co-authored-by: Zeqi Lin <[email protected]>
1 parent 2684efc commit 1248e85

13 files changed

+1255
-82
lines changed

docs/source/models/supported_models.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ Decoder-only Language Models
147147
- Phi-3-Small
148148
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
149149
-
150+
* - :code:`PhiMoEForCausalLM`
151+
- Phi-3.5-MoE
152+
- :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
153+
-
150154
* - :code:`PersimmonForCausalLM`
151155
- Persimmon
152156
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.

tests/models/test_phimoe.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Compare the outputs of HF and vLLM for moe models using greedy sampling.
2+
3+
Run `pytest tests/models/test_phimoe.py`.
4+
"""
5+
import pytest
6+
import torch
7+
8+
from vllm.utils import is_cpu
9+
10+
from .utils import check_logprobs_close
11+
12+
# Model identifiers fed to the parametrized HF-vs-vLLM comparison test.
MODELS = ["microsoft/Phi-3.5-MoE-instruct"]
15+
16+
17+
def test_phimoe_routing_function():
    """Check `phimoe_routing_function` against hand-written expectations.

    Every scenario passes the same (4, 2) float32 `hidden_states` tensor and
    a four-element `gating_output` with `topk=2` and `renormalize=False`,
    then compares the returned `topk_weights` and `topk_ids` with the
    expected tensors.
    """
    from vllm.model_executor.models.phimoe import phimoe_routing_function

    def _f32(values):
        # All floating-point fixtures are float32 tensors without grad.
        return torch.tensor(values, dtype=torch.float32, requires_grad=False)

    def _int64(values):
        # Expected ids are compared with torch.equal, so they are torch.long.
        return torch.tensor(values, dtype=torch.long, requires_grad=False)

    # Each row: (gating_output, expected topk_weights, expected topk_ids).
    scenarios = [
        ([0.1, 0.2, 0.3, 0.4], [1., 1.], [3, 2]),
        ([0.4, 0.2, 0.3, 0.4], [0.5, 1.], [0, 3]),
    ]

    for gating_values, expected_weights, expected_ids in scenarios:
        topk_weights, topk_ids = phimoe_routing_function(
            hidden_states=_f32([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2),
            gating_output=_f32(gating_values),
            topk=2,
            renormalize=False,
        )
        assert torch.allclose(topk_weights, _f32(expected_weights))
        assert torch.equal(topk_ids, _int64(expected_ids))
70+
71+
72+
def get_gpu_memory():
    """Return the current CUDA device's total memory divided by 1024**3.

    Falls back to 0 when the device cannot be queried for any reason, so the
    function is safe to call while pytest is still collecting tests.
    """
    try:
        device_index = torch.cuda.current_device()
        total_memory = torch.cuda.get_device_properties(
            device_index).total_memory
    except Exception:
        # Best-effort probe: any failure is treated as "no usable GPU".
        return 0
    return total_memory / (1024**3)
79+
80+
81+
@pytest.mark.skipif(condition=is_cpu(),
                    reason="This test takes a lot time to run on CPU, "
                    "and vllm CI's disk space is not enough for this model.")
@pytest.mark.skipif(condition=get_gpu_memory() < 100,
                    reason="Skip this test if GPU memory is insufficient.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy logprob outputs of the HF runner and the vLLM runner.

    The test is skipped when `is_cpu()` is true or when `get_gpu_memory()`
    reports less than 100. Both runners receive the same prompts, token
    budget and logprob count; the two result lists are then handed to
    `check_logprobs_close`.
    """
    # Reference side: the HF runner is only needed while generating.
    with hf_runner(model, dtype=dtype) as hf:
        reference_outputs = hf.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    # Candidate side: same prompts and limits through the vLLM runner.
    with vllm_runner(model, dtype=dtype) as engine:
        candidate_outputs = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=reference_outputs,
        outputs_1_lst=candidate_outputs,
        name_0="hf",
        name_1="vllm",
    )
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
{
2+
"3328": {
3+
"BLOCK_SIZE_M": 64,
4+
"BLOCK_SIZE_N": 256,
5+
"BLOCK_SIZE_K": 64,
6+
"GROUP_SIZE_M": 16,
7+
"num_warps": 4,
8+
"num_stages": 2
9+
},
10+
"1024": {
11+
"BLOCK_SIZE_M": 64,
12+
"BLOCK_SIZE_N": 256,
13+
"BLOCK_SIZE_K": 32,
14+
"GROUP_SIZE_M": 32,
15+
"num_warps": 4,
16+
"num_stages": 4
17+
},
18+
"3072": {
19+
"BLOCK_SIZE_M": 64,
20+
"BLOCK_SIZE_N": 256,
21+
"BLOCK_SIZE_K": 64,
22+
"GROUP_SIZE_M": 32,
23+
"num_warps": 4,
24+
"num_stages": 2
25+
},
26+
"256": {
27+
"BLOCK_SIZE_M": 32,
28+
"BLOCK_SIZE_N": 256,
29+
"BLOCK_SIZE_K": 128,
30+
"GROUP_SIZE_M": 8,
31+
"num_warps": 4,
32+
"num_stages": 4
33+
},
34+
"768": {
35+
"BLOCK_SIZE_M": 128,
36+
"BLOCK_SIZE_N": 128,
37+
"BLOCK_SIZE_K": 64,
38+
"GROUP_SIZE_M": 8,
39+
"num_warps": 4,
40+
"num_stages": 4
41+
},
42+
"1792": {
43+
"BLOCK_SIZE_M": 128,
44+
"BLOCK_SIZE_N": 128,
45+
"BLOCK_SIZE_K": 64,
46+
"GROUP_SIZE_M": 16,
47+
"num_warps": 4,
48+
"num_stages": 4
49+
},
50+
"2560": {
51+
"BLOCK_SIZE_M": 64,
52+
"BLOCK_SIZE_N": 256,
53+
"BLOCK_SIZE_K": 64,
54+
"GROUP_SIZE_M": 32,
55+
"num_warps": 4,
56+
"num_stages": 2
57+
},
58+
"2816": {
59+
"BLOCK_SIZE_M": 128,
60+
"BLOCK_SIZE_N": 128,
61+
"BLOCK_SIZE_K": 64,
62+
"GROUP_SIZE_M": 16,
63+
"num_warps": 4,
64+
"num_stages": 4
65+
},
66+
"3584": {
67+
"BLOCK_SIZE_M": 64,
68+
"BLOCK_SIZE_N": 256,
69+
"BLOCK_SIZE_K": 64,
70+
"GROUP_SIZE_M": 32,
71+
"num_warps": 4,
72+
"num_stages": 2
73+
},
74+
"1536": {
75+
"BLOCK_SIZE_M": 64,
76+
"BLOCK_SIZE_N": 256,
77+
"BLOCK_SIZE_K": 64,
78+
"GROUP_SIZE_M": 64,
79+
"num_warps": 4,
80+
"num_stages": 2
81+
},
82+
"2048": {
83+
"BLOCK_SIZE_M": 64,
84+
"BLOCK_SIZE_N": 256,
85+
"BLOCK_SIZE_K": 64,
86+
"GROUP_SIZE_M": 64,
87+
"num_warps": 4,
88+
"num_stages": 2
89+
},
90+
"512": {
91+
"BLOCK_SIZE_M": 64,
92+
"BLOCK_SIZE_N": 256,
93+
"BLOCK_SIZE_K": 64,
94+
"GROUP_SIZE_M": 8,
95+
"num_warps": 4,
96+
"num_stages": 4
97+
},
98+
"3840": {
99+
"BLOCK_SIZE_M": 128,
100+
"BLOCK_SIZE_N": 128,
101+
"BLOCK_SIZE_K": 64,
102+
"GROUP_SIZE_M": 16,
103+
"num_warps": 4,
104+
"num_stages": 4
105+
},
106+
"1280": {
107+
"BLOCK_SIZE_M": 64,
108+
"BLOCK_SIZE_N": 256,
109+
"BLOCK_SIZE_K": 64,
110+
"GROUP_SIZE_M": 64,
111+
"num_warps": 4,
112+
"num_stages": 2
113+
},
114+
"2304": {
115+
"BLOCK_SIZE_M": 64,
116+
"BLOCK_SIZE_N": 256,
117+
"BLOCK_SIZE_K": 64,
118+
"GROUP_SIZE_M": 32,
119+
"num_warps": 4,
120+
"num_stages": 2
121+
},
122+
"4096": {
123+
"BLOCK_SIZE_M": 64,
124+
"BLOCK_SIZE_N": 256,
125+
"BLOCK_SIZE_K": 64,
126+
"GROUP_SIZE_M": 32,
127+
"num_warps": 4,
128+
"num_stages": 2
129+
}
130+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
{
2+
"3840": {
3+
"BLOCK_SIZE_M": 128,
4+
"BLOCK_SIZE_N": 128,
5+
"BLOCK_SIZE_K": 64,
6+
"GROUP_SIZE_M": 8,
7+
"num_warps": 4,
8+
"num_stages": 4
9+
},
10+
"1792": {
11+
"BLOCK_SIZE_M": 128,
12+
"BLOCK_SIZE_N": 128,
13+
"BLOCK_SIZE_K": 64,
14+
"GROUP_SIZE_M": 8,
15+
"num_warps": 4,
16+
"num_stages": 4
17+
},
18+
"3584": {
19+
"BLOCK_SIZE_M": 64,
20+
"BLOCK_SIZE_N": 256,
21+
"BLOCK_SIZE_K": 64,
22+
"GROUP_SIZE_M": 16,
23+
"num_warps": 4,
24+
"num_stages": 2
25+
},
26+
"512": {
27+
"BLOCK_SIZE_M": 64,
28+
"BLOCK_SIZE_N": 256,
29+
"BLOCK_SIZE_K": 64,
30+
"GROUP_SIZE_M": 16,
31+
"num_warps": 4,
32+
"num_stages": 2
33+
},
34+
"3072": {
35+
"BLOCK_SIZE_M": 64,
36+
"BLOCK_SIZE_N": 256,
37+
"BLOCK_SIZE_K": 64,
38+
"GROUP_SIZE_M": 32,
39+
"num_warps": 4,
40+
"num_stages": 2
41+
},
42+
"2048": {
43+
"BLOCK_SIZE_M": 64,
44+
"BLOCK_SIZE_N": 256,
45+
"BLOCK_SIZE_K": 64,
46+
"GROUP_SIZE_M": 16,
47+
"num_warps": 4,
48+
"num_stages": 2
49+
},
50+
"2816": {
51+
"BLOCK_SIZE_M": 128,
52+
"BLOCK_SIZE_N": 256,
53+
"BLOCK_SIZE_K": 32,
54+
"GROUP_SIZE_M": 32,
55+
"num_warps": 8,
56+
"num_stages": 4
57+
},
58+
"1280": {
59+
"BLOCK_SIZE_M": 64,
60+
"BLOCK_SIZE_N": 256,
61+
"BLOCK_SIZE_K": 64,
62+
"GROUP_SIZE_M": 64,
63+
"num_warps": 4,
64+
"num_stages": 2
65+
},
66+
"768": {
67+
"BLOCK_SIZE_M": 128,
68+
"BLOCK_SIZE_N": 128,
69+
"BLOCK_SIZE_K": 64,
70+
"GROUP_SIZE_M": 1,
71+
"num_warps": 4,
72+
"num_stages": 4
73+
},
74+
"4096": {
75+
"BLOCK_SIZE_M": 128,
76+
"BLOCK_SIZE_N": 128,
77+
"BLOCK_SIZE_K": 64,
78+
"GROUP_SIZE_M": 8,
79+
"num_warps": 4,
80+
"num_stages": 4
81+
},
82+
"3328": {
83+
"BLOCK_SIZE_M": 64,
84+
"BLOCK_SIZE_N": 256,
85+
"BLOCK_SIZE_K": 64,
86+
"GROUP_SIZE_M": 32,
87+
"num_warps": 4,
88+
"num_stages": 2
89+
},
90+
"2560": {
91+
"BLOCK_SIZE_M": 128,
92+
"BLOCK_SIZE_N": 128,
93+
"BLOCK_SIZE_K": 64,
94+
"GROUP_SIZE_M": 8,
95+
"num_warps": 4,
96+
"num_stages": 4
97+
},
98+
"1024": {
99+
"BLOCK_SIZE_M": 64,
100+
"BLOCK_SIZE_N": 256,
101+
"BLOCK_SIZE_K": 32,
102+
"GROUP_SIZE_M": 8,
103+
"num_warps": 4,
104+
"num_stages": 4
105+
},
106+
"2304": {
107+
"BLOCK_SIZE_M": 64,
108+
"BLOCK_SIZE_N": 256,
109+
"BLOCK_SIZE_K": 64,
110+
"GROUP_SIZE_M": 16,
111+
"num_warps": 4,
112+
"num_stages": 2
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 64,
116+
"BLOCK_SIZE_N": 256,
117+
"BLOCK_SIZE_K": 64,
118+
"GROUP_SIZE_M": 32,
119+
"num_warps": 4,
120+
"num_stages": 2
121+
},
122+
"256": {
123+
"BLOCK_SIZE_M": 64,
124+
"BLOCK_SIZE_N": 256,
125+
"BLOCK_SIZE_K": 64,
126+
"GROUP_SIZE_M": 1,
127+
"num_warps": 4,
128+
"num_stages": 4
129+
}
130+
}

0 commit comments

Comments
 (0)