Merged
Changes from all commits
Commits
71 commits
1825ef8
Cutlass grouped gemm files
ElizaWszola Dec 6, 2024
5fd48e5
runs, bad result
ElizaWszola Dec 9, 2024
d5942cf
A little closer to working
ElizaWszola Dec 10, 2024
c570c69
Working for identical sizes
ElizaWszola Dec 11, 2024
6ed63f2
Grouped gemm working
ElizaWszola Dec 17, 2024
e2b1fc0
Small cleanup
ElizaWszola Dec 17, 2024
dd163f5
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 8, 2025
acfd3ef
Benchmark grouped cutlass against bfloat16 torch.mm
ElizaWszola Jan 13, 2025
c6231b6
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 13, 2025
f1a5666
Start working on fused moe cutlass implementation
ElizaWszola Jan 17, 2025
6414e31
Working halfway
ElizaWszola Jan 20, 2025
67e2dd4
working mul test but the topk_weights are not yet included in kernel
ElizaWszola Jan 23, 2025
6523529
cleaned up cutlass moe test, fixes
ElizaWszola Jan 23, 2025
b302d98
benchmark fused
ElizaWszola Jan 23, 2025
342d1a4
pass input as one tensor with an array of offsets rather than a list …
ElizaWszola Jan 24, 2025
7549e3d
Using tensors rather than tensor lists works with test_cutlass test
ElizaWszola Jan 28, 2025
64c2a68
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Jan 28, 2025
1ea7874
cleanup, add import
ElizaWszola Jan 28, 2025
d608164
working fused op
ElizaWszola Jan 29, 2025
286f6c8
benchmark, create strides directly on device, small name refactor
ElizaWszola Jan 29, 2025
b6867bb
works with cuda graphs
ElizaWszola Jan 31, 2025
df04bc0
move stride tensor creation outside c++ code, cleanup
ElizaWszola Jan 31, 2025
88c7134
cleanup benchmark
ElizaWszola Jan 31, 2025
02e1d4e
profile
ElizaWszola Feb 4, 2025
1d9c429
tuned shapes, fix
ElizaWszola Feb 14, 2025
b824ad2
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Feb 14, 2025
ae90eee
Performance, add channelwise scales everywhere
ElizaWszola Feb 18, 2025
f191b35
name fix
ElizaWszola Feb 20, 2025
22d4f7b
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Feb 20, 2025
51941ff
perf improvements in data preparation
ElizaWszola Feb 20, 2025
d3cf1db
Integrate with deepseek v2
ElizaWszola Feb 24, 2025
175ecdd
cudagraphs fix
ElizaWszola Feb 24, 2025
3d7a487
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Feb 25, 2025
ec0cb94
larger index type to support very large batches
ElizaWszola Feb 25, 2025
6dd6d48
update benchmarks
ElizaWszola Feb 25, 2025
716d8c0
Faster data preparation kernels, bring back correct benchmark shapes
ElizaWszola Feb 27, 2025
975ab5f
enable cutlass grouped gemm only on sm90
ElizaWszola Feb 28, 2025
e83910e
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 4, 2025
89f2d1c
Move arch detection to CompressedTensorsMoEMethod, cleanup, bring bac…
ElizaWszola Mar 5, 2025
4d2f62f
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 5, 2025
8fddd4f
Fix merge, cleanup imports
ElizaWszola Mar 5, 2025
583f749
fix benchmark precommit hooks
ElizaWszola Mar 5, 2025
10f5a97
Various cleanups
ElizaWszola Mar 5, 2025
5e85587
precommit hook fix
ElizaWszola Mar 5, 2025
63f6733
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 12, 2025
8f5ac77
Post-merge fix, fallback to triton if not yet implemented features ar…
ElizaWszola Mar 12, 2025
3a01616
Lots of minor feedback changes, self-commenting names
ElizaWszola Mar 17, 2025
3159141
format
ElizaWszola Mar 17, 2025
baa503d
Decide whether to use cutlass or triton in compressed tensors method …
ElizaWszola Mar 17, 2025
ed673cb
Docs, remove redundant args
ElizaWszola Mar 18, 2025
5287681
Changed CUDA version error message, added tp TODO to benchmark
ElizaWszola Mar 18, 2025
42dc92c
Add tp argument to benchmarks
ElizaWszola Mar 18, 2025
53ab07a
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 18, 2025
d8de3c9
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 18, 2025
83f7084
Add bfloat16 type to the kernel
ElizaWszola Mar 18, 2025
be83180
Rename groups to num_experts in kernel, make group starts kernel more…
ElizaWszola Mar 19, 2025
e6481c8
format
ElizaWszola Mar 19, 2025
f0c2f06
format
ElizaWszola Mar 19, 2025
8d0e700
format 3
ElizaWszola Mar 19, 2025
84dbc2a
Add hack for accepting int input in weak_ref_tensors
ElizaWszola Mar 21, 2025
5ad4b0b
Fixes
ElizaWszola Mar 24, 2025
41eb522
format utils.py
ElizaWszola Mar 24, 2025
f5b5c7d
Merge branch 'main' into grouped-gemm-with-group-id
ElizaWszola Mar 24, 2025
c6076b3
Make handling of both input scales consistent in the code
ElizaWszola Mar 26, 2025
c8f1567
Fix handling optional vals
ElizaWszola Mar 26, 2025
96296cb
feedback: version checks, file structure
ElizaWszola Mar 26, 2025
3977d67
Change cmake flag, remove unused code
ElizaWszola Mar 26, 2025
83ee170
update kernel run conditions in scaled_mm_entry.cu
ElizaWszola Mar 26, 2025
fbe2b80
added channelwise, dynamic per token
robertgshaw2-redhat Mar 27, 2025
e0af782
updated
robertgshaw2-redhat Mar 27, 2025
e0bae3c
updated
robertgshaw2-redhat Mar 27, 2025
26 changes: 0 additions & 26 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -884,32 +884,6 @@ def make_expert_params_mapping(
]
]

def _load_fp8_scale(self, param: torch.nn.Parameter,
@robertgshaw2-redhat (Collaborator, Author) commented on Mar 27, 2025:

this is dead code (it's not used in the codebase)
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data

# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight

def extra_repr(self) -> str:

s = (
@@ -268,14 +268,23 @@ def __init__(
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")

if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
== QuantizationStrategy.TENSOR)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
Comment on lines +271 to +276

Member: we can mix and match

Collaborator (Author): ok, will change this in a FUP (follow-up)

if not (per_tensor or per_channel):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales "
"for weights and activations are supported. Found "
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")

self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
set_weight_attrs(w2_weight, extra_weight_attrs)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)

w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# INPUT_SCALES
if self.static_input_scales:
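To make the two branches above concrete, a hedged sketch of the weight-scale shapes that end up registered on the layer (parameter names mirror the diff; the sizes are purely illustrative):

```python
import torch

# Illustrative sizes only.
num_experts = 8
intermediate_size_per_partition = 1024
hidden_size = 4096

# Per-tensor strategy: one scale per (expert, shard) for w13, one per expert
# for w2. The two w13 scales are folded into a single scale after loading.
w13_weight_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
w2_weight_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

# Per-channel strategy: one scale per output channel, with a trailing
# singleton dim so it broadcasts against the [out, in] weight shards.
w13_weight_scale_channel = torch.ones(num_experts,
                                      2 * intermediate_size_per_partition,
                                      1,
                                      dtype=torch.float32)
w2_weight_scale_channel = torch.ones(num_experts, hidden_size, 1,
                                     dtype=torch.float32)

print(w13_weight_scale_tensor.shape)   # torch.Size([8, 2])
print(w13_weight_scale_channel.shape)  # torch.Size([8, 2048, 1])
```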
@@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
@@ -377,24 +403,25 @@
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)

# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size

layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)

def apply(
self,
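For the per-tensor branch above, a minimal sketch of the dequantize-then-requantize step applied to one w1/w3 shard. It uses plain PyTorch stand-ins for `per_tensor_dequantize` and `ops.scaled_fp8_quant` (not the vLLM ops themselves) and assumes a PyTorch build with `float8_e4m3fn` support:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def requantize_shard_to_max_scale(w_fp8: torch.Tensor,
                                  old_scale: torch.Tensor,
                                  max_scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a w1/w3 shard with its own scale, then requantize it with
    the per-expert max scale so the fused kernel can use one w13 scale."""
    # Dequantize: fp8 -> fp32 using the shard's original per-tensor scale.
    dq_weight = w_fp8.to(torch.float32) * old_scale
    # Requantize with the shared (max) scale, clamping to the fp8 range.
    q = (dq_weight / max_scale).clamp(min=-FP8_MAX, max=FP8_MAX)
    return q.to(torch.float8_e4m3fn)


# Toy example: two shards quantized with different scales end up sharing one.
w1 = torch.randn(16, 32).to(torch.float8_e4m3fn)
scales = torch.tensor([0.02, 0.05])
max_scale = scales.max()
w1_requant = requantize_shard_to_max_scale(w1, scales[0], max_scale)
print(w1_requant.dtype, w1_requant.shape)  # torch.float8_e4m3fn torch.Size([16, 32])
```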