@@ -13,6 +13,7 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.platforms.interface import CpuArchEnum
 
 if current_platform.is_cuda_alike():
     from .fused_moe import fused_experts
@@ -83,6 +84,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+        if current_platform.is_cpu():
+            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
+                import intel_extension_for_pytorch as ipex
+                layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    use_prepack=True,
+                )
+            else:
+                raise NotImplementedError("CPU MOE only supports x86 arch.")
+
     def apply(
         self,
         layer: torch.nn.Module,
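The new process_weights_after_loading hook runs once per layer after all checkpoint tensors have been copied in; on x86 CPUs it wraps the expert weights in IPEX's GatedMLPMOE module so they can be prepacked for the fused kernel. A minimal, hypothetical sketch of how such a hook is typically driven by a model loader (names here are illustrative, not vLLM's actual loader code):

import torch

def finalize_weights(model: torch.nn.Module) -> None:
    # Illustrative only: after weight loading, visit each submodule and give
    # its quant method a chance to repack weights for the target platform
    # (on CPU this is where the IPEX prepacking above would happen).
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)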
@@ -142,9 +157,29 @@ def forward_cuda(
                              topk_ids=topk_ids,
                              inplace=True)
 
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError(
-            "The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        **kwargs,
+    ):
+        assert custom_routing_function is None
+        return layer.ipex_fusion(
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+        )
 
     def forward_tpu(
         self,
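For reference, the fused layer.ipex_fusion(...) call on the new CPU path computes a standard top-k, SiLU-gated MoE. The sketch below is a rough, unfused PyTorch equivalent of that math, not the IPEX implementation; it assumes vLLM's usual weight layout (w13_weight of shape [num_experts, 2 * intermediate, hidden] with gate and up projections concatenated, w2_weight of shape [num_experts, hidden, intermediate]), renormalizes the top-k weights unconditionally, and ignores grouped top-k and custom routing:

import torch
import torch.nn.functional as F

def moe_reference(x, w13_weight, w2_weight, router_logits, top_k):
    # x: [tokens, hidden]; router_logits: [tokens, num_experts]
    probs = F.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(w13_weight.shape[0]):
        # Tokens that routed to expert e, and the top-k slot they used.
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13_weight[e].t()      # fused gate/up projection
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                     # SwiGLU activation
        h = h @ w2_weight[e].t()                  # down projection
        out[token_idx] += topk_weights[token_idx, slot].unsqueeze(-1) * h
    return out

The IPEX module performs the same routing and per-expert computation on prepacked weights in fused form, which is why forward_cpu only has to forward the routing arguments.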