
Commit 3f904ed

add bias handling for a_1_128_w_128_128 float8 scaling
Summary: As titled, adds support for bias and a unit test.

Test Plan:
```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -k fp8_linear_variants
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 50f7a26
ghstack-comment-id: 3463298384
Pull-Request: #3259
1 parent 839d480 commit 3f904ed

File tree

2 files changed: +20 −6 lines


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 18 additions & 5 deletions
@@ -39,10 +39,10 @@


 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias):
         super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

     def forward(self, x):
         x = self.linear1(x)
@@ -81,6 +81,8 @@ def setUp(self):
             ((32, 128), 256, 512),
         ],
     )
+    @common_utils.parametrize("bias", [False, True])
+    @torch.no_grad()
     def test_fp8_linear_variants(
         self,
         dtype: torch.dtype,
@@ -89,6 +91,7 @@ def test_fp8_linear_variants(
         granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
+        bias: bool,
     ):
         if isinstance(granularity, PerTensor):
             if kernel_preference is KernelPreference.FBGEMM:
@@ -106,6 +109,16 @@ def test_fp8_linear_variants(
             elif kernel_preference is KernelPreference.FBGEMM:
                 return unittest.skip("unimplemented")

+        if bias is True:
+            sizes_to_keep = ((128,), 256, 128)
+            if (
+                sizes != sizes_to_keep
+                or kernel_preference is not KernelPreference.TORCH
+            ):
+                return unittest.skip(
+                    "cut down on number of options to save test time"
+                )
+
         error_message = None
         if isinstance(granularity, PerRow):
             if mode == "dynamic" and dtype != torch.bfloat16:
@@ -134,7 +147,7 @@ def test_fp8_linear_variants(
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")

         quantized_model = copy.deepcopy(model)

@@ -257,7 +270,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         dtype = torch.bfloat16
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")

         # reference kernel preference and results
         # we are using KerenelPreference.TORCH as the reference
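With the new parametrization, the toy model is built with real bias parameters for one size/kernel combination. The snippet below is a minimal sketch of the scenario that case covers, written outside the test harness; `Float8DynamicActivationFloat8WeightConfig` and `PerRow` are stand-ins for whatever the test's parametrization selects (the blockwise 1x128/128x128 path this commit targets is configured by the workflow itself), and the tolerance is purely illustrative.

```python
# Minimal sketch (not code from this commit) of the bias=True scenario:
# quantize a small model whose Linear layers have bias, then compare
# against the bf16 baseline. Requires a CUDA GPU with fp8 support.
import copy

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

K, N = 256, 128
model = (
    torch.nn.Sequential(
        torch.nn.Linear(K, N, bias=True),
        torch.nn.Linear(N, K, bias=True),
    )
    .eval()
    .to(torch.bfloat16)
    .to("cuda")
)
x = torch.randn(128, K, dtype=torch.bfloat16, device="cuda")
ref = model(x)

quantized_model = copy.deepcopy(model)
# PerRow granularity is an assumption here; the test also covers other
# granularities and kernel preferences.
quantize_(quantized_model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
out = quantized_model(x)

# Loose, illustrative tolerance: fp8 quantization error dominates.
torch.testing.assert_close(out, ref, atol=0.3, rtol=0.3)
```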

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 2 additions & 1 deletion
@@ -362,14 +362,15 @@ def _(func, types, args, kwargs):
         # TODO(future PR): add fbgemm_gpu_genai path if available
         # TODO(before land): proper out_dtype handling
         assert _is_1_128_scaled(input_tensor), "unsupported"
-        # breakpoint()
         res = blockwise_fp8_gemm(
             inpt_data,
             input_scale,
             w_data.t(),
             w_scale.t(),
             block_size=128,
         )
+        if bias is not None:
+            res = res + bias
     else:
         res = addmm_float8_unwrapped_inference(
             inpt_data,
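The kernel-side change is just a bias add on the high-precision result of the blockwise GEMM, so the quantized path now matches `F.linear(x, w, bias)` up to quantization error. A small self-contained sketch of that equivalence, with a plain matmul standing in for `blockwise_fp8_gemm`:

```python
# Sketch of the semantics this commit adds: after the blockwise fp8 GEMM
# produces `res`, the layer bias is added, matching F.linear(x, w, bias).
# The matmul below is only a stand-in for blockwise_fp8_gemm(...), so there
# is no quantization error in this toy check.
import torch
import torch.nn.functional as F

x = torch.randn(32, 256)
w = torch.randn(128, 256)
bias = torch.randn(128)

res = x @ w.t()       # stands in for blockwise_fp8_gemm(inpt_data, ...)
if bias is not None:  # the new branch added in float8_tensor.py
    res = res + bias

torch.testing.assert_close(res, F.linear(x, w, bias))
```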
