@@ -479,12 +479,19 @@ def from_float(cls, weight):

 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
     """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+

 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -500,7 +507,7 @@ def from_float(cls, weight):
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G64WeightOnlyQuantizedLinearWeight
 ]

 def _change_linears_to_autoquantizable(model, **kwargs):
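For context, here is a minimal sketch of how the new float8 weight-only class could be exercised on its own, using the from_float and _quantized_linear_op entry points added in the diff above. The torchao.quantization.autoquant import path, the tensor shapes, and the CUDA/bfloat16 setup are illustrative assumptions, not part of this change:

import torch
# Assumed import path for the class added in this diff.
from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

# Hypothetical shapes/device for illustration; float8 support on the device is assumed.
weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")  # (out_features, in_features)
act = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
bias = None

# Per-row float8_e4m3fn quantization: block_size = (1, in_features), as in from_float above.
qweight = AQFloat8WeightOnlyQuantizedLinearWeight.from_float(weight)

# Weight-only path: dequantize the weight and run a regular linear, as in _quantized_linear_op above.
out = AQFloat8WeightOnlyQuantizedLinearWeight._quantized_linear_op(act, qweight, bias)
print(out.shape)  # torch.Size([8, 256])

In the full autoquant flow, this class would typically be included in a qtensor class list passed to torchao.autoquant (e.g. via its qtensor_class_list argument), which benchmarks each candidate per linear layer and keeps the fastest one.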