
Commit 2574c70

robertgshaw2-redhat, ElizaWszola, and LucasWilkinson authored and committed
[Quantization] Fp8 Channelwise Dynamic Per Token GroupedGEMM (vllm-project#15587)
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: ElizaWszola <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Co-authored-by: ElizaWszola <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: ElizaWszola <[email protected]>
1 parent 577585e commit 2574c70
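For context on the commit title: "channelwise, dynamic per token" pairs one FP8 scale per weight output channel with one activation scale per token, computed at runtime rather than loaded from the checkpoint. The snippet below is a minimal plain-PyTorch sketch of that scaling math, not vLLM code; the helper names and the FP8_MAX constant are illustrative assumptions.

import torch

# Max representable magnitude of float8_e4m3fn (448.0); used as the quant range.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def quantize_weight_channelwise(w: torch.Tensor):
    # One scale per output channel (row of w): shape [out_features, 1].
    scale = w.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    w_q = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return w_q, scale


def quantize_activation_per_token(x: torch.Tensor):
    # One scale per token (row of x), computed dynamically at inference time.
    scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale

# A scaled GEMM then computes roughly (x_q @ w_q.T) * x_scale * w_scale.T in
# higher precision, which is what the grouped FP8 MoE kernels fuse.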

File tree

2 files changed: +67 -66 lines changed


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 0 additions & 26 deletions
@@ -885,32 +885,6 @@ def make_expert_params_mapping(
             ]
         ]
 
-    def _load_fp8_scale(self, param: torch.nn.Parameter,
-                        loaded_weight: torch.Tensor, weight_name: str,
-                        shard_id: str, expert_id: int) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            else:
-                param_data[expert_id] = loaded_weight
-
     def extra_repr(self) -> str:
 
         s = (
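The deleted _load_fp8_scale helper kept separate w1 and w3 scales per expert because, as its own comment notes, the two shards are requantized to a single shared scale after weight loading (the logic the TENSOR branch of process_weights_after_loading in the second file retains). As a rough standalone illustration, assuming torch and hypothetical names rather than vLLM's API, that merge amounts to:

import torch


def requantize_w13_to_shared_scale(w13_q: torch.Tensor,
                                   scales: torch.Tensor):
    """Hypothetical sketch: merge two per-shard FP8 scales into one.

    w13_q:  [2 * shard_size, hidden] FP8 weight; the first shard_size rows
            are w1, the rest are w3, each quantized with its own scale.
    scales: [2] per-shard scales.
    """
    shard_size = w13_q.shape[0] // 2
    max_scale = scales.max()
    shards = []
    for i in range(2):
        rows = w13_q[i * shard_size:(i + 1) * shard_size].to(torch.float32)
        # Dequantize with the shard's own scale, then rescale to the shared max.
        shards.append(rows * scales[i] / max_scale)
    return torch.cat(shards, dim=0).to(torch.float8_e4m3fn), max_scale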
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 67 additions & 40 deletions
@@ -268,14 +268,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -377,24 +403,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,

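To make the scale shapes allocated in the new CHANNEL branch of create_weights concrete, here is a small sketch with illustrative sizes; the dimension values below are assumptions for the example, not taken from any particular model, and the torch.ones initializers simply mirror the diff rather than real checkpoint values.

import torch

num_experts = 8                         # E: local experts
intermediate_size_per_partition = 1408  # N: per-partition intermediate size
hidden_size = 2048                      # K: model hidden size

# Per-tensor strategy: one scale per (expert, shard) for w13, one per expert for w2.
w13_scale_tensor = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(num_experts, dtype=torch.float32)

# Channelwise strategy: one scale per output channel of each expert weight.
# w13 fuses the gate (w1) and up (w3) projections, hence 2 * N rows.
w13_scale_channel = torch.ones(num_experts,
                               2 * intermediate_size_per_partition,
                               1,
                               dtype=torch.float32)
w2_scale_channel = torch.ones(num_experts, hidden_size, 1, dtype=torch.float32)

print(w13_scale_channel.shape)  # torch.Size([8, 2816, 1])
print(w2_scale_channel.shape)   # torch.Size([8, 2048, 1])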