test/quantization/quantize_/workflows/float8/test_float8_tensor.py (85 changes: 71 additions & 14 deletions)
@@ -17,6 +17,7 @@

 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8Tensor,
     Float8WeightOnlyConfig,
     Granularity,
     PerBlock,
@@ -25,7 +26,6 @@
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
-from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
@@ -329,14 +329,13 @@ def _test_fp8_matmul_model(
     @unittest.skipIf(
         not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
     )
+    @unittest.skipIf(
+        not _is_fbgemm_gpu_genai_available(),
+        "Requires fbgemm_gpu_genai to be installed",
+    )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor()])
     @common_utils.parametrize("inference_mode", [True, False])
-    @common_utils.parametrize(
-        "kernel_preference",
-        [KernelPreference.AUTO],
-    )
     # only test for 3D conv for now
     # Inputs are (N, C_in, C_out, D, H, W)
     @common_utils.parametrize(
@@ -349,19 +348,14 @@ def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
-        granularity,
         inference_mode: bool,
         kernel_preference: KernelPreference,
[Inline review comment from the Contributor Author on `kernel_preference: KernelPreference,`]: sorry, just noticed that this is not removed. The test is skipped in CI so it should not be run for now; will remove in the next PR.
         sizes: Tuple,
     ):
-        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
-            return unittest.skip(
-                "Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
-                "fbgemm kernel preference test"
-            )
-
-        dim = 3
+        granularity = PerTensor()
+        kernel_preference = KernelPreference.AUTO
         N, C_in, C_out, D, H, W = sizes
+        dim = 3
         kernel_size = 3

         # Note: this is channel last memory format
@@ -404,6 +398,69 @@
f"Quantization error is too high got a SQNR of {error}"
)

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @unittest.skipIf(
+        not _is_fbgemm_gpu_genai_available(),
+        "Requires fbgemm_gpu_genai to be installed",
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    # only test for 3D conv for now
+    # Inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 12, 64, 32, 32, 32),
+            (4, 16, 12, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_skip_quant(
+        self,
+        dtype: torch.dtype,
+        sizes: Tuple,
+    ):
+        """Some shapes are not supported, so we won't quantize the module.
+        Specifically, we skip quantization when C_in or C_out is not a multiple of 16.
+        """
+        granularity = PerTensor()
+        kernel_preference = KernelPreference.AUTO
+        N, C_in, C_out, D, H, W = sizes
+        dim = 3
+        kernel_size = 3
+
+        # Note: this is channel last memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        # Create a 3D conv model with the given dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        assert not isinstance(quantized_model.conv.weight, Float8Tensor)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+        self.assertEqual(output_original, output_quantized)
+
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(
not is_sm_at_least_90(),
Expand Down
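For readers skimming the diff, here is a minimal standalone sketch of the behavior the new test pins down. It assumes a CUDA device with compute capability >= 10.0 and fbgemm_gpu_genai installed (per the skip decorators above), and uses a bare Conv3d instead of the test's ToyConvModel:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8Tensor,
    PerTensor,
    quantize_,
)

# C_in=12 is not a multiple of 16, so quantization should be skipped for this conv
conv = torch.nn.Conv3d(
    12, 64, kernel_size=3, bias=False, dtype=torch.bfloat16, device="cuda"
)
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
quantize_(conv, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Conv3d))

# the weight stays a plain tensor rather than becoming a Float8Tensor subclass
assert not isinstance(conv.weight, Float8Tensor)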
torchao/quantization/quant_api.py (7 changes: 7 additions & 0 deletions)
@@ -1824,6 +1824,13 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
         ), "5D tensor only supports per tensor activation and weight quantization"
+
+        # weight dim: (C_out, C_in, K1, K2, K3)
+        # skip quantization when either C_out or C_in
+        # is not a multiple of 16
+        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+            return weight
+
     elif not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
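To make the new gate concrete, a small sketch (supports_fp8_conv_quant is a hypothetical name for illustration, not a torchao API) that applies the same multiple-of-16 check to the conv3d weight layout (C_out, C_in, K1, K2, K3), using the shapes from the new test:

import torch

def supports_fp8_conv_quant(weight: torch.Tensor) -> bool:
    # mirrors the inline check in the hunk above:
    # weight layout is (C_out, C_in, K1, K2, K3), and both
    # channel dims must be multiples of 16 to quantize
    C_out, C_in = weight.shape[0], weight.shape[1]
    return C_out % 16 == 0 and C_in % 16 == 0

# sizes (4, 12, 64, ...) from the test give weight (64, 12, 3, 3, 3): C_in=12 -> skip
assert not supports_fp8_conv_quant(torch.empty(64, 12, 3, 3, 3))
# sizes (4, 16, 12, ...) give weight (12, 16, 3, 3, 3): C_out=12 -> skip
assert not supports_fp8_conv_quant(torch.empty(12, 16, 3, 3, 3))
# both channel dims are multiples of 16 -> quantization proceeds
assert supports_fp8_conv_quant(torch.empty(64, 16, 3, 3, 3))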