-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[Transform] [Quantization] Add QuTLASS support to vLLM #24440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request integrates the QuTLASS library to add support for 4-bit quantization kernels, including new custom ops, benchmarks, and tests. The changes are well-structured. I have two high-severity suggestions: one to improve build reproducibility by pinning the QuTLASS dependency to a specific version, and another to fix a bug in a new test file to prevent future issues.
| FetchContent_Declare( | ||
| qutlass | ||
| GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git | ||
| GIT_TAG main |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using main as the GIT_TAG can lead to non-reproducible builds and may break the build if there are incompatible changes in the QuTLASS repository's main branch. It is highly recommended to pin this to a specific commit hash or a release tag (like v0.1.0 as mentioned in the PR description) to ensure build stability and reproducibility.
GIT_TAG v0.1.0
| b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.) | ||
| out_ref = a_dq @ b_dq.transpose(-2, -1) | ||
|
|
||
| out = qutlass.matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The run_problem_ada function attempts to call qutlass.matmul_ada_mxf4_bf16_tn, but qutlass is not defined or imported. This will result in a NameError. Although this function is not currently called, it's best to fix it to prevent future issues.
To fix this, you should add matmul_ada_mxf4_bf16_tn to your imports at the top of the file:
from vllm._custom_ops import matmul_mxf4_bf16_tn, fusedQuantizeMx, matmul_ada_mxf4_bf16_tnAnd then update this line accordingly.
| out = qutlass.matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha) | |
| out = matmul_ada_mxf4_bf16_tn(a_e2m1, b_e2m1, a_e8m0, b_e8m0, alpha) |
|
@LopezCastroRoberto does this PR support gpt-oss on sm120 ? How to exactly test some mxfp4 models with this PR? Would love to test rtx 6000 pro on this |
| return torch.tensor( | ||
| hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use our hadamard utility for consistency?
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
| return torch.tensor( | |
| hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device | |
| ) | |
| deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) * group_size**-0.5 |
|
|
||
| def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device): | ||
| weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(b, forward_hadamard_matrix, device) | ||
| alpha = torch.Tensor([1.]).to("cuda") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| alpha = torch.Tensor([1.]).to("cuda") | |
| alpha = torch.Tensor([1.], device="cuda") |
|
|
||
| def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): | ||
| return torch.tensor( | ||
| hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, use our util
| 'Llama-3.1-70B': [(8192, 8192), (8192, 57344), (28672, 8192)] | ||
| } | ||
|
|
||
| for model, layers in MODELS.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please wrap in `if name == "main"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding some user arguments
| 'Llama-3.1-70B': [(8192, 8192), (8192, 57344), (28672, 8192)] | ||
| } | ||
|
|
||
| for model, layers in MODELS.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please wrap in `if name == "main"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider allowing users to specify arguments, that way you don't have to have commented code
vllm/_custom_ops.py
Outdated
|
|
||
| def fusedQuantizeMx(a: torch.Tensor, | ||
| b: torch.Tensor, | ||
| *, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the point of this *?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Means all arguments that come after the * must be passed by keyword, not by position. My point was to make the API clearer and less error-prone.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's fair!
vllm/_custom_ops.py
Outdated
| xh_e8m0 = torch.empty(padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device) | ||
|
|
||
| if method=="quest": | ||
| return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because these functions have a return value, you'll want to register a fake function so torch compile works right
if hasattr(torch.ops._C, "_qutlass_C"):
@register_fake("_C::_qutlass_C::fusedQuantizeMxQuest")
def fake_qutlass_mx_quest(a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return (torch.empty(...), torch.empty(...))
vllm/qutlass_utils/utils.py
Outdated
| output_block_stride, | ||
| BLOCK_ROWS: tl.constexpr, | ||
| BLOCK_COLS: tl.constexpr, | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this over-indented? I think we should standardize on 4 space indent
vllm/qutlass_utils/utils.py
Outdated
| return (a + b - 1) // b | ||
|
|
||
|
|
||
| def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just as a style thing, consider calling triton_mx_block_rearrange in cases where you want to use the triton kernel and to_blocked otherwise
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about keeping one to_blocked but making the backend explicit (e.g. backend="torch" | "triton" | "auto")?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both good!
| # Quantize activation on-the-fly | ||
| def run(): | ||
| input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(a, forward_hadamard_matrix, global_scale) | ||
| input_hf_scale_block = to_blocked(input_hf_e8m0, True).view(-1,K//16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will the triton jit affect benchmarked runtime? Ie, first time compile causes the first graph to take longer than normal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes—the very first time is slower, but after that it's cached
|
@voipmonitor This PR supports dense models only, and it's perfectly fine to use an RTX 6000 Pro. We will add usage examples to this PR soon. We’re actively working on MoE support in QuTLASS—stay tuned :) |
f9ca647 to
dce5334
Compare
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
Signed-off-by: LopezCastroRoberto <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: Dhruvil Bhatt <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: bbartels <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: 0xrushi <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: 0xrushi <[email protected]>
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]>
| cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}") | ||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS) | ||
|
|
||
| if(QUTLASS_ARCHS MATCHES "10\\.0a") | ||
| set(QUTLASS_TARGET_CC 100) | ||
| elseif(QUTLASS_ARCHS MATCHES "12\\.0a") | ||
| set(QUTLASS_TARGET_CC 120) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This excludes the DGX Spark (SM121) from using QuTLASS, since it is not 12.0a. If you adjust this to allow 12.0f or 12.1a, this will also get built for the spark. I built this on my Spark locally and all tests in test_mxfp4_qutlass.py and test_nvfp4_qutlass.py pass, but I'm not sure if that's all that's needed to verify this would work for that hardware.
Is there a way I can verify this is getting used in a running vLLM, outside of those tests?
|
There have been reports of NVFP4 QuTLASS failing on RTX 5090 and DGX Spark too. |
I tested locally loosening up the cmake CUDA_ARCHS matching and was able to run an example FP-Quant model fine on a DGX Spark building vLLM from source. I don't have an RTX 5090 to test, but on the surface I would have assumed it already matched this CUDA_ARCHS check. |
|
Can we just build for 10.0+PTX for instance? |
|
If I remember correctly, the kernel dispatch is static w.r.t. arch because the kernels were separately tuned for each architecture. Compiling for a non-supported arch would throw runtime errors. @LopezCastroRoberto would we have to re-tune the kernels for |
…24440) Signed-off-by: LopezCastroRoberto <[email protected]> Signed-off-by: Roberto L. Castro <[email protected]> Signed-off-by: Andrei Panferov <[email protected]> Co-authored-by: Andrei Panferov <[email protected]> Co-authored-by: Michael Goin <[email protected]>
Purpose
This pull request brings in the QuTLASS library: https://github.com/iST-DASLab/qutlass
QuTLASS is a high-performance library designed for low-precision kernel support in deep learning quantization, built on top of NVIDIA CUTLASS.
QuTLASS v0.1.0 introduces 4-bit microscaling routines tailored for Large Language Model (LLM) inference on NVIDIA Blackwell GPUs.
Microbenchmarking
benchmarks/kernels/bench_mxfp4_qutlass.pybenchmarks/kernels/bench_nvfp4_qutlass.pyQuTLASS performance on a single Qwen3-32B layer with NVIDIA RTX5090 GPU
QuTLASS performance on a single Llama-3.1-70B layer with NVIDIA B200 GPU
[WIP] End-to-end
python benchmarks/benchmark_latency.pydaslab-testing/Llama-3.3-70B-Instruct-FPQuant-GPTQ-MXFP4-hadamardmeta-llama/Llama-3.3-70B-InstructFP16
MXFP4
Testing
tests/kernels/quantization/test_mxfp4_qutlass.pytests/kernels/quantization/test_nvfp4_qutlass.py