# [Misc] Add triton_kernels dependency #27370

This PR adds `triton_kernels` as a dependency and updates the fused-MoE tests to match the current `FusedMoEQuantConfig` API. The requirements change pins the package to the Triton v3.5.0 tag so it stays in step with the torch/Triton versions used elsewhere:

```diff
@@ -13,3 +13,5 @@ torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytor
 # xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
 flashinfer-python==0.4.1
+# Triton Kernels are needed for mxfp4 fused moe. (Should be updated alongside torch)
+triton_kernels @ git+https://github.com/triton-lang/[email protected]#subdirectory=python/triton_kernels
```
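
For a manually managed environment, the same pin can be installed with `pip install "triton_kernels @ git+https://github.com/triton-lang/[email protected]#subdirectory=python/triton_kernels"`. The snippet below is a sanity-check sketch rather than anything the PR adds; it only confirms that the pinned package resolves and that the comparison helper the MoE tests import is present.

```python
# Sanity-check sketch (not part of the PR): verify that the pinned
# triton_kernels package is importable and that the tolerance-aware
# comparison helper used by the fused-MoE tests is available.
import triton_kernels
from triton_kernels.testing import assert_close

print("triton_kernels imported from:", triton_kernels.__file__)
print("assert_close available:", callable(assert_close))
```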

The second file changed is the test for the `gpt_oss_triton_kernels_moe` path. Its first hunk trims the import block, dropping the symbols that were only needed by the batched-MoE test removed further down:

```diff
@@ -23,15 +23,9 @@
 from triton_kernels.testing import assert_close
 
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedPrepareAndFinalize,
-)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
-    BatchedOAITritonExperts,
     triton_kernel_moe_forward,
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.utils import shuffle_weight
-from vllm.utils import round_up
 
```

The second hunk updates `test_equiv`: the `PrecisionConfig` objects are now passed to `FusedMoEQuantConfig.make` as `w1_scale` / `w2_scale` rather than `w1_precision` / `w2_precision`:

```diff
@@ -302,8 +296,8 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     quant_config = FusedMoEQuantConfig.make(
         w1_bias=w1_bias_tri,
         w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
+        w1_scale=pc1,
+        w2_scale=pc2,
     )
 
     out_triton_monolithic = triton_kernel_moe_forward(
```
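
For reference, here is a minimal sketch of building the config with the renamed arguments. The shapes and the bias-only setup are illustrative assumptions, not values from the test (which obtains its `PrecisionConfig` objects from an `init_compute_data` helper), and `None` is assumed to fall back to the unquantized default.

```python
import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

# Illustrative shapes only: 8 experts, hidden size 256, intermediate size 512.
E, K, N = 8, 256, 512
w1_bias = torch.zeros(E, 2 * N)
w2_bias = torch.zeros(E, K)

# The objects that used to be passed as w1_precision / w2_precision are now
# passed as w1_scale / w2_scale; None is assumed to mean "no quantization".
quant_config = FusedMoEQuantConfig.make(
    w1_bias=w1_bias,
    w2_bias=w2_bias,
    w1_scale=None,
    w2_scale=None,
)
```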

The final hunk removes the batched-MoE variant of the test: the `batched_moe` helper (which wrapped `BatchedPrepareAndFinalize` and `BatchedOAITritonExperts` in a `FusedMoEModularKernel`) and the `test_triton_kernel_batched_moe` case that exercised it, leaving `test_equiv` and `test_unit_shuffle` in place:

```diff
@@ -329,115 +323,6 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
     assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005)
 
 
-def batched_moe(
-    a: torch.Tensor,
-    w1,
-    w2,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    w1_bias: torch.Tensor,
-    w2_bias: torch.Tensor,
-    w1_precision: PrecisionConfig,
-    w2_precision: PrecisionConfig,
-) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64)
-
-    quant_config = FusedMoEQuantConfig.make(
-        w1_precision=w1_precision,
-        w2_precision=w2_precision,
-        w1_bias=w1_bias,
-        w2_bias=w2_bias,
-    )
-
-    fused_experts = FusedMoEModularKernel(
-        BatchedPrepareAndFinalize(
-            max_num_tokens,
-            num_dispatchers=1,
-            num_local_experts=w1.shape[0],
-            rank=0,
-        ),
-        BatchedOAITritonExperts(
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=1,
-            quant_config=quant_config,
-        ),
-    )
-
-    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
-
-    return fused_experts(
-        a,
-        w1,
-        w2,
-        topk_weight,
-        topk_ids,
-    )
-
-
-@pytest.mark.parametrize(
-    ", ".join(f.name for f in fields(Case)),
-    [
-        tuple(getattr(case, f.name) for f in fields(Case))
-        for case in [
-            # Case(a_dtype="bf16", w_dtype="bf16"),
-            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
-            Case(a_dtype="bf16", w_dtype="mx4")
-        ]
-    ],
-)
-@pytest.mark.parametrize("num_token", [64])
-@pytest.mark.parametrize("ep", [1, 2, 4, 8])
-def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
-    M = num_token
-    E = ModelConfig.num_experts // ep
-    K = ModelConfig.hidden_size
-    N = ModelConfig.intermediate_size
-    topk = ModelConfig.experts_per_token
-
-    (
-        x,
-        w1,
-        w1_bias,
-        w2,
-        w2_bias,
-        exp_data,
-        x_tri,
-        w1_tri,
-        w2_tri,
-        exp_data_tri,
-        w1_bias_tri,
-        w2_bias_tri,
-        pc1,
-        pc2,
-    ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4)
-
-    out_tri = batched_moe(
-        a=x_tri,
-        w1=w1_tri,
-        w2=w2_tri,
-        gating_output=exp_data_tri,
-        topk=topk,
-        renormalize=True,
-        w1_bias=w1_bias_tri,
-        w2_bias=w2_bias_tri,
-        w1_precision=pc1,
-        w2_precision=pc2,
-    )
-    out_tri = out_tri[..., :K]
-
-    out_ref = oai_moe_forward(
-        hidden_states=x,
-        w1=w1,
-        w1_bias=w1_bias,
-        w2=w2,
-        w2_bias=w2_bias,
-        gating_output=exp_data,
-        topk=topk,
-    )
-    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
-
-
 def test_unit_shuffle():
     N = ModelConfig.intermediate_size
     K = ModelConfig.hidden_size
```