Skip to content

Commit db72466

Browse files
committed
add bias handling for a_1_128_w_128_128 float8 scaling
Summary: As titled, adds support for bias and a unit test Test Plan: ``` pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -k fp8_linear_variants ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 07b623e ghstack-comment-id: 3463298384 Pull-Request: #3259
1 parent 7e8f7b7 commit db72466

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939

4040

4141
class ToyLinearModel(torch.nn.Module):
42-
def __init__(self, in_features, out_features):
42+
def __init__(self, in_features, out_features, bias):
4343
super().__init__()
44-
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
45-
self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
44+
self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
45+
self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)
4646

4747
def forward(self, x):
4848
x = self.linear1(x)
@@ -81,6 +81,8 @@ def setUp(self):
8181
((32, 128), 256, 512),
8282
],
8383
)
84+
@common_utils.parametrize("bias", [False, True])
85+
@torch.no_grad()
8486
def test_fp8_linear_variants(
8587
self,
8688
dtype: torch.dtype,
@@ -89,6 +91,7 @@ def test_fp8_linear_variants(
8991
granularity,
9092
kernel_preference: KernelPreference,
9193
sizes: Tuple,
94+
bias: bool,
9295
):
9396
if isinstance(granularity, PerTensor):
9497
if kernel_preference is KernelPreference.FBGEMM:
@@ -106,6 +109,16 @@ def test_fp8_linear_variants(
106109
elif kernel_preference is KernelPreference.FBGEMM:
107110
return unittest.skip("unimplemented")
108111

112+
if bias is True:
113+
if (
114+
sizes != (128,),
115+
256,
116+
128,
117+
) or kernel_preference is not KernelPreference.TORCH:
118+
return unittest.skip(
119+
"cut down on number of options to save test time"
120+
)
121+
109122
error_message = None
110123
if isinstance(granularity, PerRow):
111124
if mode == "dynamic" and dtype != torch.bfloat16:
@@ -134,7 +147,7 @@ def test_fp8_linear_variants(
134147
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
135148

136149
# Create a linear layer with bfloat16 dtype
137-
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
150+
model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")
138151

139152
quantized_model = copy.deepcopy(model)
140153

@@ -257,7 +270,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
257270
dtype = torch.bfloat16
258271
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
259272
# Create a linear layer with bfloat16 dtype
260-
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
273+
model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")
261274

262275
# reference kernel preference and results
263276
# we are using KerenelPreference.TORCH as the reference

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ def _(func, types, args, kwargs):
370370
w_scale.t(),
371371
block_size=128,
372372
)
373+
if bias is not None:
374+
res = res + bias
373375
else:
374376
res = addmm_float8_unwrapped_inference(
375377
inpt_data,

0 commit comments

Comments
 (0)