diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index a20f8a05bc..9f77ad020d 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -40,10 +40,10 @@ class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias):
         super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
@@ -104,6 +104,8 @@ def setUp(self):
            ((32, 128), 256, 512),
        ],
    )
+    @common_utils.parametrize("bias", [False, True])
+    @torch.no_grad()
    def test_fp8_linear_variants(
        self,
        dtype: torch.dtype,
@@ -112,6 +114,7 @@ def test_fp8_linear_variants(
        granularity,
        kernel_preference: KernelPreference,
        sizes: Tuple,
+        bias: bool,
    ):
        if isinstance(granularity, PerTensor):
            if kernel_preference is KernelPreference.FBGEMM:
@@ -132,6 +135,16 @@ def test_fp8_linear_variants(
            ):
                return unittest.skip("unimplemented")

+        if bias is True:
+            sizes_to_keep = ((128,), 256, 128)
+            if (
+                sizes != sizes_to_keep
+                or kernel_preference is not KernelPreference.TORCH
+            ):
+                return unittest.skip(
+                    "cut down on number of options to save test time"
+                )
+
        error_message = None
        if isinstance(granularity, PerRow):
            if mode == "dynamic" and dtype != torch.bfloat16:
@@ -160,7 +173,7 @@ def test_fp8_linear_variants(
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
        # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")

        quantized_model = copy.deepcopy(model)

@@ -362,7 +375,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
        dtype = torch.bfloat16
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
        # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")

        # reference kernel preference and results
        # we are using KerenelPreference.TORCH as the reference
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 7bbcd81e1e..3581cb619c 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -370,6 +370,8 @@ def _(func, types, args, kwargs):
                w_scale.t(),
                block_size=128,
            )
+            if bias is not None:
+                res = res + bias
        else:
            res = addmm_float8_unwrapped_inference(
                inpt_data,