From 94fe48bdf2386d1ef69284faf4d4d050a9113c39 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Wed, 5 Feb 2025 02:33:57 +0000
Subject: [PATCH 1/3] clean-up

Signed-off-by: Dipika
---
 vllm/model_executor/layers/fused_moe/layer.py |  4 ++--
 .../layers/quantization/gptq_marlin.py        | 17 +++++++++++------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3c7ef5e0080..f18c0313355 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -302,8 +302,8 @@ def __init__(
             "weight_loader": self.weight_loader,
         }
         # need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__ ==
-                "CompressedTensorsWNA16MoEMethod"):
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 99ab299958b..b22d10a2299 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -323,13 +323,18 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = (intermediate_size_per_partition //
-                            self.quant_config.group_size)
+            w2_scales_size = (intermediate_size_full
+                              if self.quant_config.desc_act else
+                              intermediate_size_per_partition)
+            scales_size2 = (w2_scales_size // self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
@@ -575,4 +580,4 @@ def apply(
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.quant_config.quant_type.size_bits,
-        ).to(orig_dtype)
+            is_k_full=self.is_k_full).to(orig_dtype)

From 4c99c79b0ea9dd8880a7462449857de5ae820b7c Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Wed, 5 Feb 2025 08:06:35 +0000
Subject: [PATCH 2/3] set scales and zeros attribute to load full w2

Signed-off-by: ElizaWszola
---
 vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index b22d10a2299..39c24d532e6 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -326,6 +326,8 @@ def create_weights(
         intermediate_size_full = extra_weight_attrs.pop(
             "intermediate_size_full")
 
+        load_full_w2 = self.quant_config.desc_act
+
         self.is_k_full = (not self.quant_config.desc_act) or (
             intermediate_size_per_partition == intermediate_size_full)
 
@@ -390,6 +392,7 @@ def create_weights(
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
+        set_weight_attrs(w2_scales, {"load_full_w2": load_full_w2})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
@@ -411,6 +414,7 @@ def create_weights(
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        set_weight_attrs(w2_qzeros, {"load_full_w2": load_full_w2})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,

From 591b231feb3bbe693ce4d0332e19cc2e48b66a6e Mon Sep 17 00:00:00 2001
From: Dipika
Date: Thu, 6 Feb 2025 02:20:33 +0000
Subject: [PATCH 3/3] clean

Signed-off-by: Dipika
---
 tests/weight_loading/models-large.txt                  |  2 ++
 vllm/model_executor/layers/quantization/gptq_marlin.py | 10 ++++++----
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt
index 8ab7f05d7d1..9c1c11da572 100644
--- a/tests/weight_loading/models-large.txt
+++ b/tests/weight_loading/models-large.txt
@@ -1,5 +1,7 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
 awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
\ No newline at end of file
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 39c24d532e6..84c53b2c16d 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -326,8 +326,6 @@ def create_weights(
         intermediate_size_full = extra_weight_attrs.pop(
             "intermediate_size_full")
 
-        load_full_w2 = self.quant_config.desc_act
-
         self.is_k_full = (not self.quant_config.desc_act) or (
             intermediate_size_per_partition == intermediate_size_full)
 
@@ -392,7 +390,9 @@ def create_weights(
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
-        set_weight_attrs(w2_scales, {"load_full_w2": load_full_w2})
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales,
+                         {"load_full_w2": self.quant_config.desc_act})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
@@ -414,7 +414,9 @@ def create_weights(
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
-        set_weight_attrs(w2_qzeros, {"load_full_w2": load_full_w2})
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros,
+                         {"load_full_w2": self.quant_config.desc_act})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
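The patch series above hinges on two related decisions made in `GPTQMarlinMoEMethod.create_weights`: whether the Marlin kernel may treat the reduction dimension as full (`is_k_full`) and whether the w2 group scales/qzeros must be loaded unsharded (`load_full_w2`, which after PATCH 3/3 is simply `self.quant_config.desc_act`). The sketch below restates that decision logic as a standalone function; the function name, the dataclass, and the example sizes are illustrative only and are not part of vLLM — just the boolean expressions from the diffs above.

```python
# Illustrative sketch (not vLLM code): mirrors the decision logic the patches
# add to GPTQMarlinMoEMethod.create_weights. Helper name, dataclass, and the
# example numbers are hypothetical; the boolean expressions come from the diff.
from dataclasses import dataclass


@dataclass
class MoEShardingDecision:
    is_k_full: bool      # may the Marlin kernel treat K as un-split?
    load_full_w2: bool   # should w2 scales/qzeros be loaded unsharded?


def decide_w2_loading(desc_act: bool,
                      intermediate_size_per_partition: int,
                      intermediate_size_full: int) -> MoEShardingDecision:
    # Without act order (desc_act=False), each TP rank owns a contiguous slice
    # of K, so its partition can be treated as "full". With act order, K is
    # only full when there is no sharding at all.
    is_k_full = (not desc_act) or (
        intermediate_size_per_partition == intermediate_size_full)
    # With act order the w2 group scales/zeros cover the full intermediate
    # size (per the "dont shard the w2 scales" comment in PATCH 3/3), so they
    # are loaded unsharded.
    return MoEShardingDecision(is_k_full=is_k_full, load_full_w2=desc_act)


# Example: 2-way tensor parallelism on a 14336-wide MLP with act order enabled.
print(decide_w2_loading(desc_act=True,
                        intermediate_size_per_partition=7168,
                        intermediate_size_full=14336))
# -> MoEShardingDecision(is_k_full=False, load_full_w2=True)
```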