Commit ec317fc

[int8 woq] make the scale type the same as input for bf16 autocast (#534)
1 parent: d477c0e

1 file changed: 4 additions, 6 deletions

torchao/dtypes/affine_quantized_tensor.py

@@ -795,17 +795,15 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t()
         scale = weight_qtensor.layout_tensor.scale
         orig_dtype = input_tensor.dtype
-        y = (
-            torch.mm(
+        m = torch.mm(
             input_tensor.reshape(-1, input_tensor.shape[-1]),
             w_vals_int8_t.to(input_tensor.dtype),
         )
-            * scale
-        )
+        y = m * scale.to(m.dtype)
         y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
         if bias is not None:
-            y += bias
-        return y.to(orig_dtype)
+            y += bias.to(m.dtype)
+        return y

         # is_cpu and is_mps only, some issue with is_contiguous() currently
         # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
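The motivation for the cast: the int8 weight is upcast to the activation dtype before torch.mm, so under bf16 autocast the product m is bfloat16, while the stored scale (and bias) may still be float32, and multiplying without a cast promotes the output back to float32. A minimal standalone sketch of that promotion behavior (illustrative shapes and values, not code from this repository):

import torch

# bf16 activations, as a torch.autocast(dtype=torch.bfloat16) region would produce
x = torch.randn(4, 8, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
scale = torch.randn(16)  # per-channel scale stored in float32

m = torch.mm(x, w_int8.to(x.dtype))  # bf16 x bf16 -> bf16

y_before = m * scale             # bf16 * fp32 promotes the result to float32
y_after = m * scale.to(m.dtype)  # cast the scale first: result stays bfloat16

print(y_before.dtype, y_after.dtype)  # torch.float32 torch.bfloat16

Before this change the promoted result was converted back at the end with y.to(orig_dtype); casting scale (and bias) to m.dtype instead keeps the whole epilogue in the input dtype, making the final conversion unnecessary.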
