[Quantization] Fp8 Channelwise Dynamic Per Token GroupedGEMM #15587
Merged
robertgshaw2-redhat merged 71 commits into vllm-project:main from robertgshaw2-redhat:dyn-per-token-grouped-gemm on Mar 27, 2025.
Commits (71, changes from all commits)
1825ef8  Cutlass grouped gemm files (ElizaWszola)
5fd48e5  runs, bad result (ElizaWszola)
d5942cf  A little closer to working (ElizaWszola)
c570c69  Working for identical sizes (ElizaWszola)
6ed63f2  Grouped gemm working (ElizaWszola)
e2b1fc0  Small cleanup (ElizaWszola)
dd163f5  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
acfd3ef  Benchmark grouped cutlass against bfloat16 torch.mm (ElizaWszola)
c6231b6  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
f1a5666  Start working on fused moe cutlass implementation (ElizaWszola)
6414e31  Working halfway (ElizaWszola)
67e2dd4  working mul test but the topk_weights are not yet included in kernel (ElizaWszola)
6523529  cleaned up cutlass moe test, fixes (ElizaWszola)
b302d98  benchmark fused (ElizaWszola)
342d1a4  pass input as one tensor with an array of offsets rather than a list … (ElizaWszola)
7549e3d  Using tensors rather than tensor lists works with test_cutlass test (ElizaWszola)
64c2a68  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
1ea7874  cleanup, add import (ElizaWszola)
d608164  working fused op (ElizaWszola)
286f6c8  benchmark, create strides directly on device, small name refactor (ElizaWszola)
b6867bb  works with cuda graphs (ElizaWszola)
df04bc0  move stride tensor creation outside c++ code, cleanup (ElizaWszola)
88c7134  cleanup benchmark (ElizaWszola)
02e1d4e  profile (ElizaWszola)
1d9c429  tuned shapes, fix (ElizaWszola)
b824ad2  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
ae90eee  Performance, add channelwise scales everywhere (ElizaWszola)
f191b35  name fix (ElizaWszola)
22d4f7b  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
51941ff  perf improvements in data preparation (ElizaWszola)
d3cf1db  Integrate with deepseek v2 (ElizaWszola)
175ecdd  cudagraphs fix (ElizaWszola)
3d7a487  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
ec0cb94  larger index type to support very large batches (ElizaWszola)
6dd6d48  update benchmarks (ElizaWszola)
716d8c0  Faster data preparation kernels, bring back correct benchmark shapes (ElizaWszola)
975ab5f  enable cutlass grouped gemm only on sm90 (ElizaWszola)
e83910e  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
89f2d1c  Move arch detection to CompressedTensorsMoEMethod, cleanup, bring bac… (ElizaWszola)
4d2f62f  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
8fddd4f  Fix merge, cleanup imports (ElizaWszola)
583f749  fix benchmark precommit hooks (ElizaWszola)
10f5a97  Various cleanups (ElizaWszola)
5e85587  precommit hook fix (ElizaWszola)
63f6733  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
8f5ac77  Post-merge fix, fallback to triton if not yet implemented features ar… (ElizaWszola)
3a01616  Lots of minor feedback changes, self-commenting names (ElizaWszola)
3159141  format (ElizaWszola)
baa503d  Decide whether to use cutlass or triton in compressed tensors method … (ElizaWszola)
ed673cb  Docs, remove redundant args (ElizaWszola)
5287681  Changed CUDA version error message, added tp TODO to benchmark (ElizaWszola)
42dc92c  Add tp argument to benchmarks (ElizaWszola)
53ab07a  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
d8de3c9  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
83f7084  Add bfloat16 type to the kernel (ElizaWszola)
be83180  Rename groups to num_experts in kernel, make group starts kernel more… (ElizaWszola)
e6481c8  format (ElizaWszola)
f0c2f06  format (ElizaWszola)
8d0e700  format 3 (ElizaWszola)
84dbc2a  Add hack for accepting int input in weak_ref_tensors (ElizaWszola)
5ad4b0b  Fixes (ElizaWszola)
41eb522  format utils.py (ElizaWszola)
f5b5c7d  Merge branch 'main' into grouped-gemm-with-group-id (ElizaWszola)
c6076b3  Make handling of both input scales consistent in the code (ElizaWszola)
c8f1567  Fix handling optional vals (ElizaWszola)
96296cb  feedback: version checks, file structure (ElizaWszola)
3977d67  Change cmake flag, remove unused code (ElizaWszola)
83ee170  update kernel run conditions in scaled_mm_entry.cu (ElizaWszola)
fbe2b80  added channelwise, dynamic per token (robertgshaw2-redhat)
e0af782  updated (robertgshaw2-redhat)
e0bae3c  updated (robertgshaw2-redhat)
Diff (changes from all commits):

```diff
@@ -268,14 +268,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
```

Review comment on lines +271 to +276 (the per_tensor / per_channel checks): "we can mix and match". Reply: "ok, will change this in a FUP".
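The check above accepts exactly two schemes: per-tensor weights with per-tensor activations, or channelwise weights with dynamic per-token activations. Below is a minimal standalone sketch of the same decision logic, assuming compressed-tensors-style strategy values; the `check_moe_fp8_scheme` helper is hypothetical and only for illustration.

```python
from enum import Enum


class QuantizationStrategy(str, Enum):
    # Subset of the compressed-tensors strategies relevant to this check.
    TENSOR = "tensor"
    CHANNEL = "channel"
    TOKEN = "token"


def check_moe_fp8_scheme(weight_strategy: QuantizationStrategy,
                         input_strategy: QuantizationStrategy,
                         input_dynamic: bool) -> str:
    """Hypothetical helper mirroring the validation in __init__ above."""
    per_tensor = (weight_strategy == QuantizationStrategy.TENSOR
                  and input_strategy == QuantizationStrategy.TENSOR)
    per_channel = (weight_strategy == QuantizationStrategy.CHANNEL
                   and input_strategy == QuantizationStrategy.TOKEN)
    if not (per_tensor or per_channel):
        raise ValueError("unsupported FP8 MoE quantization scheme")
    if per_channel and not input_dynamic:
        # Channelwise weights are only paired with dynamic per-token
        # activation scales; static input scales are rejected.
        raise ValueError("channelwise weights require dynamic per-token "
                         "activation quantization")
    return "per_tensor" if per_tensor else "channelwise_dynamic_per_token"


# Example: the scheme this PR adds.
print(check_moe_fp8_scheme(QuantizationStrategy.CHANNEL,
                           QuantizationStrategy.TOKEN,
                           input_dynamic=True))
```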
```diff
@@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
```
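For reference, a rough sketch of the scale shapes the two branches above allocate; the sizes below are made-up example values, not anything taken from this PR.

```python
import torch

# Hypothetical example sizes for illustration only.
num_experts = 8
intermediate_size_per_partition = 1536
hidden_size = 4096

# Per-tensor branch: one scale for each of the w1/w3 shards per expert, and
# one scale per expert for w2. The two w13 scales are later fused into one.
w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

# Channelwise branch: one scale per output channel (row) of each expert's
# weight, kept as a column vector so it broadcasts over the GEMM output.
w13_scale_channel = torch.ones(num_experts,
                               2 * intermediate_size_per_partition,
                               1,
                               dtype=torch.float32)
w2_scale_channel = torch.ones(num_experts, hidden_size, 1, dtype=torch.float32)

print(w13_scale_tensor.shape)   # torch.Size([8, 2])
print(w13_scale_channel.shape)  # torch.Size([8, 3072, 1])
print(w2_scale_channel.shape)   # torch.Size([8, 4096, 1])
```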
```diff
@@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
```
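The added assert encodes that static input scales only exist for the per-tensor scheme; with channelwise weights, each token gets a fresh activation scale at runtime. A minimal sketch of dynamic per-token FP8 quantization follows, assuming a torch build with float8_e4m3fn support; it is illustrative only and not the fused op vLLM actually calls.

```python
import torch


def dynamic_per_token_fp8_quant(x: torch.Tensor):
    """Quantize each row (token) of x to FP8 with its own scale.

    Illustrative only; vLLM performs this inside fused CUDA ops.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per token: max |value| in the row divided by the FP8 range.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_q = (x / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scales  # scales has shape [num_tokens, 1]


x = torch.randn(4, 16, dtype=torch.float32)
x_q, s = dynamic_per_token_fp8_quant(x)
print(x_q.shape, s.shape)  # torch.Size([4, 16]) torch.Size([4, 1])
```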
```diff
@@ -377,24 +403,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
```
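In the per-tensor branch above, w1 and w3 each load with their own scale, but the fused kernel expects a single w13 scale per expert, hence the max-then-requantize loop. A toy sketch of that pattern for one expert is below, using plain float stand-ins instead of real FP8 tensors; the `requantize_to_shared_scale` helper is hypothetical.

```python
import torch


def requantize_to_shared_scale(w1_q: torch.Tensor, s1: torch.Tensor,
                               w3_q: torch.Tensor, s3: torch.Tensor):
    """Re-express two quantized shards under one shared (max) scale.

    Toy float example of the max -> dequantize -> requantize pattern;
    the real code does this on FP8 weights via ops.scaled_fp8_quant.
    """
    shared_scale = torch.maximum(s1, s3)
    # Dequantize each shard with its own scale, then requantize with the
    # shared scale. The shard whose original scale was smaller loses a
    # little resolution, which is the accuracy cost of sharing one scale.
    w1_rq = torch.round(w1_q.float() * s1 / shared_scale)
    w3_rq = torch.round(w3_q.float() * s3 / shared_scale)
    return w1_rq, w3_rq, shared_scale


w1_q = torch.randint(-127, 128, (4, 8)).float()
w3_q = torch.randint(-127, 128, (4, 8)).float()
w1_rq, w3_rq, s = requantize_to_shared_scale(w1_q, torch.tensor(0.01),
                                             w3_q, torch.tensor(0.02))
print(s)  # tensor(0.0200)
```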
Review comment: this is dead code (it's not used in the codebase).