@@ -2,7 +2,7 @@

import enum
from enum import Enum
from typing import Callable, List, Optional
from typing import Callable, Optional

import torch
from compressed_tensors import CompressionFormat
@@ -14,9 +14,12 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@@ -54,18 +57,19 @@ def get_moe_method(
"input_activations")

if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# Prefer to use the non-marlin kernel when:
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
# 2. Non-FP16 dtype (MarlinMoE only supports FP16)
# 3. Actorder is not group/dynamic (g_idx is unsupported)
# 4. Scales are grouped (channelwise is unsupported)
if ((layer.local_num_experts >= 16
or layer.params_dtype != torch.float16) and
weight_quant.actorder not in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)
and weight_quant.strategy in QuantizationStrategy.GROUP):
# Prefer to use the MarlinMoE kernel when it is supported.
if not check_moe_marlin_supports_layer(layer,
weight_quant.group_size):
if (weight_quant.strategy in QuantizationStrategy.GROUP and
weight_quant.actorder in (ActivationOrdering.GROUP,
ActivationOrdering.DYNAMIC)):
raise ValueError(
"WNA16MoE is not supported with actorder=group/dynamic."
)
Comment on lines +66 to +68
Member


Are there cases where we raise a ValueError now where we didn't before?

Member Author


Not really. The limitations of the Marlin and Triton kernels have not changed under the hood; however, we are now checking them at a higher level, so this just fails faster.
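
For clarity, a condensed sketch of the selection logic this hunk introduces (a paraphrase of the diff with the logging calls omitted, not additional behavior): prefer the Marlin kernel whenever check_moe_marlin_supports_layer accepts the layer, otherwise fall back to the non-Marlin WNA16 method, raising early for the actorder=group/dynamic case that the fallback cannot handle.

    # Condensed paraphrase of the new dispatch (logging omitted).
    if not check_moe_marlin_supports_layer(layer, weight_quant.group_size):
        if (weight_quant.strategy in QuantizationStrategy.GROUP
                and weight_quant.actorder in (ActivationOrdering.GROUP,
                                              ActivationOrdering.DYNAMIC)):
            # The non-Marlin path cannot apply g_idx reordering, so fail
            # here instead of deep inside the kernel.
            raise ValueError(
                "WNA16MoE is not supported with actorder=group/dynamic.")
        return CompressedTensorsWNA16MoEMethod(quant_config)
    return CompressedTensorsWNA16MarlinMoEMethod(quant_config)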

logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
and layer.activation == "silu"):
@@ -705,15 +709,12 @@ def __init__(
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):

assert params_dtype == torch.float16, (
"float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
)

intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")

@@ -837,50 +838,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.marlin_state = GPTQMarlinState.REPACK

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t

def get_scale_perms(num_bits: int):
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single

def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:,
scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s

def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
size_n: int, group_size: int,
num_bits: int):
num_experts = s.shape[0]
output = torch.empty((num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n,
group_size, num_bits)
return output

size_k2 = layer.w2_weight_packed.shape[2]
size_k13 = layer.w13_weight_packed.shape[2]

num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device
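
An aside on the helper removed above: the old local replace_tensor used resize_() followed by copy_() so that the tensor object already bound to the layer is mutated in place rather than rebound (the diff swaps these calls for the shared replace_parameter utility). A minimal, self-contained sketch of that in-place pattern, using hypothetical tensor names:

    import torch

    # Mutate an existing tensor in place: resize_() reshapes/reuses its
    # buffer and copy_() fills it with the repacked values, so existing
    # references to the tensor object remain valid.
    w = torch.zeros(4, 8)
    repacked = torch.randn(2, 4)
    w.resize_(repacked.shape)
    w.copy_(repacked)
    assert w.shape == (2, 4)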

@@ -938,33 +895,33 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
layer.w13_weight_packed.shape[2],
self.num_bits,
)
replace_tensor("w13_weight_packed", marlin_w13_qweight)
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2],
self.num_bits,
)
replace_tensor("w2_weight_packed", marlin_w2_qweight)
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_weight_scale,
size_k13,
layer.w13_weight_scale.shape[2],
self.group_size,
self.num_bits,
s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size,
)
replace_tensor("w13_weight_scale", marlin_w13_scales)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_weight_scale,
layer.w2_weight_scale.shape[1] *
s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1] *
(self.group_size if self.group_size != -1 else self.packed_factor),
size_k2,
self.group_size,
self.num_bits,
size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size,
)
replace_tensor("w2_weight_scale", marlin_w2_scales)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

layer.workspace = marlin_make_workspace_new(device, 4)

def apply(
self,
@@ -985,10 +942,6 @@ def apply(
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
@@ -1015,11 +968,14 @@
router_logits,
topk_weights,
topk_ids,
quant_type_id=self.quant_type.id,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
workspace=layer.workspace,
is_k_full=self.is_k_full)


@@ -1203,7 +1159,7 @@ def apply(
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -1223,6 +1179,7 @@
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=self.num_bits == 4,
use_int8_w8a16=self.num_bits == 8,
global_num_experts=global_num_experts,