@@ -268,14 +268,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")

-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")

         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
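
The channelwise path only admits dynamic per-token activation quantization because each token's scale is computed at runtime rather than loaded from the checkpoint. A minimal standalone sketch of that idea in plain PyTorch (not the fused kernel path; quant_fp8_per_token is a hypothetical helper name):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quant_fp8_per_token(x: torch.Tensor):
    """Quantize activations row by row: one scale per token."""
    # x: [num_tokens, hidden_size] in fp16/bf16/fp32
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax.float() / FP8_MAX                       # [num_tokens, 1]
    x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale
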
@@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
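
As a sanity check on the shapes registered above, a toy-sized sketch (E, INTER, HID are illustrative stand-ins for num_experts, intermediate_size_per_partition, and hidden_size, not values from any real config):

import torch

E, INTER, HID = 4, 128, 256  # illustrative expert count / intermediate / hidden sizes

# Per-tensor layout: one scale each for the w1 and w3 shards, one for w2.
w13_scale_tensor = torch.ones(E, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(E, dtype=torch.float32)

# Channelwise layout: one scale per output row, with a trailing singleton
# dim so it broadcasts against each expert's [rows, cols] fp8 weight.
w13_scale_channel = torch.ones(E, 2 * INTER, 1, dtype=torch.float32)
w2_scale_channel = torch.ones(E, HID, 1, dtype=torch.float32)

assert w13_scale_channel.shape == (E, 2 * INTER, 1)
assert w2_scale_channel.shape == (E, HID, 1)
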
@@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -377,24 +403,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)

-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
                                                     shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)

     def apply(
         self,
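
For reference, a self-contained approximation of the per-tensor max-then-requant step above, using plain torch in place of vLLM's per_tensor_dequantize and ops.scaled_fp8_quant (requant_w13_to_single_scale is a made-up name for this sketch, not part of the codebase):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def requant_w13_to_single_scale(w13_fp8: torch.Tensor, w13_scale: torch.Tensor):
    # w13_fp8: [2 * shard_size, hidden] fp8 weight of one expert
    # w13_scale: [2] per-tensor scales for the w1 and w3 shards
    shard_size = w13_fp8.shape[0] // 2
    max_scale = w13_scale.max()
    out = torch.empty_like(w13_fp8)
    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        # Dequantize the shard with its original scale...
        dq = w13_fp8[rows].to(torch.float32) * w13_scale[shard_id]
        # ...and requantize with the single max scale the kernel expects.
        out[rows] = (dq / max_scale).clamp(-FP8_MAX, FP8_MAX).to(
            torch.float8_e4m3fn)
    return out, max_scale
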