
Commit 3b90cc4

add bias handling for a_1_128_w_128_128 float8 scaling

Summary: As titled, adds support for bias and a unit test.

Test Plan:
```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -k fp8_linear_variants
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 2a5eb86
ghstack-comment-id: 3463298384
Pull-Request: #3259
1 parent b49178c commit 3b90cc4

File tree

2 files changed: +20 -5 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 18 additions & 5 deletions
@@ -40,10 +40,10 @@
 
 
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias):
         super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -104,6 +104,8 @@ def setUp(self):
             ((32, 128), 256, 512),
         ],
     )
+    @common_utils.parametrize("bias", [False, True])
+    @torch.no_grad()
     def test_fp8_linear_variants(
         self,
         dtype: torch.dtype,
@@ -112,6 +114,7 @@ def test_fp8_linear_variants(
         granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
+        bias: bool,
     ):
         if isinstance(granularity, PerTensor):
             if kernel_preference is KernelPreference.FBGEMM:
@@ -132,6 +135,16 @@
         ):
             return unittest.skip("unimplemented")
 
+        if bias is True:
+            sizes_to_keep = ((128,), 256, 128)
+            if (
+                sizes != sizes_to_keep
+                or kernel_preference is not KernelPreference.TORCH
+            ):
+                return unittest.skip(
+                    "cut down on number of options to save test time"
+                )
+
         error_message = None
         if isinstance(granularity, PerRow):
             if mode == "dynamic" and dtype != torch.bfloat16:
@@ -160,7 +173,7 @@
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
 
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")
 
         quantized_model = copy.deepcopy(model)
 
@@ -362,7 +375,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         dtype = torch.bfloat16
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")
 
         # reference kernel preference and results
         # we are using KerenelPreference.TORCH as the reference
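
For quick reference, here is the updated toy model assembled as a standalone, runnable snippet. It is a minimal sketch based on the hunks above; the tail of `forward` (the `linear2` call and return) is not visible in the diff and is filled in as an assumption.

```python
import torch

class ToyLinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features, bias):
        super().__init__()
        # Both layers now honor the new `bias` flag instead of hard-coding bias=False.
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
        # Assumed continuation: the hunk only shows the first layer call.
        x = self.linear2(x)
        return x

# Constructed in the test as: ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")
```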

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 2 additions & 0 deletions
@@ -370,6 +370,8 @@ def _(func, types, args, kwargs):
             w_scale.t(),
             block_size=128,
         )
+        if bias is not None:
+            res = res + bias
     else:
         res = addmm_float8_unwrapped_inference(
             inpt_data,
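
The change lands right after the 1x128-activation / 128x128-weight scaled matmul: the matmul itself stays bias-free, and the bias is added to its result in the output dtype. The snippet below is a minimal numerical sketch of that behavior using plain dequantize-then-matmul; the function name, argument layout, and scale shapes are illustrative assumptions, not the actual torchao kernel path.

```python
import torch

def emulated_blockwise_fp8_linear(x_fp8, x_scale, w_fp8, w_scale, bias=None):
    """Emulate a_1_128_w_128_128 float8 linear: x_fp8 is (M, K) with (1, 128)
    activation scales, w_fp8 is (N, K) with (128, 128) weight scales."""
    # Dequantize activations: each scale covers a 1x128 block along K.
    x = x_fp8.to(torch.float32) * x_scale.repeat_interleave(128, dim=-1)
    # Dequantize weights: each scale covers a 128x128 block.
    w = w_fp8.to(torch.float32) * w_scale.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
    res = x @ w.t()
    # The two lines added in this commit: apply bias after the scaled matmul.
    if bias is not None:
        res = res + bias
    return res
```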
