Skip to content

Commit e88c01d

Browse files
committed
Float8 autoquant weight only
1 parent b8ab4ee commit e88c01d

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

scripts/hf_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def all_linear(mod, name):
8989
with torch.no_grad():
9090
result = evaluate(
9191
HFLM(
92-
pretrained=model,
92+
pretrained=model.to(device),
9393
tokenizer=tokenizer,
9494
batch_size=batch_size,
9595
max_length=max_length),

torchao/quantization/autoquant.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,12 +479,19 @@ def from_float(cls, weight):
479479

480480
class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
481481
"""
482-
AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
482+
AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
483483
"""
484+
target_dtype: torch.dtype = torch.float8_e4m3fn
485+
486+
@staticmethod
487+
def _quantized_linear_op(act_mat, w_qtensor, bias):
488+
return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
489+
484490
@classmethod
485491
def from_float(cls, weight):
486492
block_size = (1, weight.shape[1])
487-
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
493+
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
494+
488495

489496
# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
490497
DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -500,7 +507,7 @@ def from_float(cls, weight):
500507
DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
501508
AQFloatLinearWeight,
502509
AQInt8DynamicallyQuantizedLinearWeight,
503-
AQInt4G64WeightOnlyQuantizedLinearWeight,
510+
AQInt4G64WeightOnlyQuantizedLinearWeight
504511
]
505512

506513
def _change_linears_to_autoquantizable(model, **kwargs):

0 commit comments

Comments (0)