Merged
test/quantization/quantize_/workflows/float8/test_float8_tensor.py (23 changes: 18 additions & 5 deletions)
@@ -40,10 +40,10 @@


 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, bias):
Contributor:
nit: feels like `bias` is a bit confusing (since it can be a flag vs. a Tensor); even though it's the official argument name in nn.Linear, maybe use `has_bias` as the other tests do?

Contributor Author:
sure, let me do that in a future PR
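
For reference, a minimal sketch of the `has_bias` spelling the reviewer suggests (hypothetical, not part of this PR; the forward body is assumed to chain the two layers, matching the model shown in the diff):

```python
import torch


class ToyLinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features, has_bias: bool):
        super().__init__()
        # has_bias is a plain flag forwarded to nn.Linear's bias argument,
        # so it cannot be confused with a bias Tensor.
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=has_bias)
        self.linear2 = torch.nn.Linear(out_features, in_features, bias=has_bias)

    def forward(self, x):
        # assumed: the two layers are applied in sequence
        x = self.linear1(x)
        x = self.linear2(x)
        return x
```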

         super().__init__()
-        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
-        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

     def forward(self, x):
         x = self.linear1(x)
@@ -104,6 +104,8 @@ def setUp(self):
             ((32, 128), 256, 512),
         ],
     )
+    @common_utils.parametrize("bias", [False, True])
+    @torch.no_grad()
     def test_fp8_linear_variants(
         self,
         dtype: torch.dtype,
@@ -112,6 +114,7 @@ def test_fp8_linear_variants(
         granularity,
         kernel_preference: KernelPreference,
         sizes: Tuple,
+        bias: bool,
     ):
         if isinstance(granularity, PerTensor):
             if kernel_preference is KernelPreference.FBGEMM:
@@ -132,6 +135,16 @@
        ):
            return unittest.skip("unimplemented")

+        if bias is True:
+            sizes_to_keep = ((128,), 256, 128)
+            if (
+                sizes != sizes_to_keep
+                or kernel_preference is not KernelPreference.TORCH
+            ):
+                return unittest.skip(
+                    "cut down on number of options to save test time"
+                )
+
         error_message = None
         if isinstance(granularity, PerRow):
             if mode == "dynamic" and dtype != torch.bfloat16:
@@ -160,7 +173,7 @@
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")

         quantized_model = copy.deepcopy(model)

@@ -362,7 +375,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         dtype = torch.bfloat16
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")

         # reference kernel preference and results
         # we are using KernelPreference.TORCH as the reference
(second changed file)
@@ -370,6 +370,8 @@ def _(func, types, args, kwargs):
                 w_scale.t(),
                 block_size=128,
             )
+            if bias is not None:
+                res = res + bias
         else:
             res = addmm_float8_unwrapped_inference(
                 inpt_data,
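
For illustration only, a minimal standalone sketch (not from this PR or torchao) of the pattern the hunk above implements: when the low-precision GEMM call does not fuse the bias, the bias is added to the matmul output afterwards; the diff does the same after a block-scaled kernel call (block_size=128). This sketch assumes an fp8-capable CUDA GPU and a recent PyTorch with torch._scaled_mm; all names and shapes are illustrative.

```python
import torch

# Hypothetical shapes; multiples of 16 keep _scaled_mm happy.
M, K, N = 128, 256, 128
x = torch.randn(M, K, device="cuda")
w = torch.randn(N, K, device="cuda")
bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)

# Per-tensor scales: quantize as x / scale, dequantize via scale_a * scale_b.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = x.abs().amax() / fp8_max
w_scale = w.abs().amax() / fp8_max
x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

# _scaled_mm expects the second operand in column-major layout, hence the transpose.
res = torch._scaled_mm(
    x_fp8,
    w_fp8.t(),
    scale_a=x_scale,
    scale_b=w_scale,
    out_dtype=torch.bfloat16,
)
# Bias is not fused into this call, so add it to the output, mirroring the diff.
if bias is not None:
    res = res + bias
```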