Commit 3339fd7

Add per tensor fp8 quantization support for conv3d

Summary: As titled, we added support for quantizing conv3d weights with the `Float8DynamicActivationFloat8WeightConfig` API:

```
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerTensor(),
)
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
quantize_(quantized_model, config, filter_fn=_is_conv3d)
```

Test Plan:

```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
```

Reviewers:
Subscribers:
Tasks:
Tags:

1 parent 7e5d907 commit 3339fd7
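For readers who want to try this end to end, here is a slightly fuller, self-contained variant of the snippet above: it builds a small `Conv3d` module in `channels_last_3d` memory format (mirroring the new test added below), applies the per-tensor fp8 config, and runs a forward pass. Import paths, the CUDA device, and the fbgemm_gpu_genai / SM 10.0 requirement are assumptions taken from this diff rather than a verified recipe.

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

# Small bf16 conv3d model, kept in channels_last_3d as the new test does.
model = torch.nn.Sequential(
    torch.nn.Conv3d(16, 64, kernel_size=3, bias=False, dtype=torch.bfloat16, device="cuda")
).to(memory_format=torch.channels_last_3d).eval()

# Per-tensor fp8 dynamic activation + weight quantization, applied only to Conv3d modules.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
quantize_(model, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Conv3d))

# Activations must also be in channels_last_3d for the fbgemm conv kernel.
x = torch.randn(4, 16, 32, 32, 32, dtype=torch.bfloat16, device="cuda")
x = x.to(memory_format=torch.channels_last_3d)
with torch.inference_mode():
    out = model(x)  # (4, 64, 30, 30, 30)
```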

File tree

4 files changed: +226 -3 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py
torchao/quantization/quant_api.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py
torchao/utils.py

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 103 additions & 1 deletion

```diff
@@ -30,6 +30,7 @@
     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )

@@ -49,10 +50,32 @@ def forward(self, x):
         return x


+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
+        if dim == 3:
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+# @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
 class TestFloat8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
@@ -148,6 +171,85 @@ def test_fp8_linear_variants(
             f"Quantization error is too high got a SQNR of {error}"
         )

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize("inference_mode", [True, False])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test for 3D conv for now
+    # Inputs are (N, C_in, C_out, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        inference_mode: bool,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
+            return unittest.skip(
+                "Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
+                "fbgemm kernel preference test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # Note: this is channels last memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # Create a conv model
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext()
+        with inference_mode_ctx:
+            output_original = model(input_tensor)
+            output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert compute_error(output_original, output_quantized) > 20, (
+            f"Quantization error is too high got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
```
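The new test passes when the SQNR between the bf16/fp32 reference output and the quantized output exceeds 20 dB. For orientation, a minimal sketch of an equivalent check, assuming `compute_error` returns SQNR in decibels (the actual torchao helper may differ in detail):

```python
import torch

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - approx||).
    return 20 * torch.log10(
        torch.linalg.vector_norm(ref.float()) / torch.linalg.vector_norm((ref - approx).float())
    )

ref = torch.randn(4, 64, 30, 30, 30)
approx = ref + 0.01 * torch.randn_like(ref)  # stand-in for the quantized output
assert sqnr_db(ref, approx) > 20  # same threshold as test_fp8_conv_variants
```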

torchao/quantization/quant_api.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -1813,7 +1813,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if not _fp8_mm_compat(weight):
+    if weight.dim() == 5:
+        # weights for conv3d
+        assert isinstance(activation_granularity, PerTensor) and \
+            isinstance(weight_granularity, PerTensor), "5D tensor only supports per tensor activation and weight quantization"
+    elif not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight
```
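The new branch only admits per-tensor granularity when the weight is 5D (conv3d); any other granularity trips the assertion above instead of silently falling back. A short illustrative sketch of both cases, with import paths assumed to match the commit message:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    PerTensor,
    quantize_,
)

conv = torch.nn.Conv3d(16, 64, 3, bias=False, dtype=torch.bfloat16, device="cuda")
is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)

# OK: per-tensor activation and weight quantization on a 5D conv3d weight.
quantize_(conv, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()), filter_fn=is_conv3d)

# Not supported: per-row granularity on a 5D weight would trip the new assertion,
# "5D tensor only supports per tensor activation and weight quantization".
# quantize_(conv, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), filter_fn=is_conv3d)
```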

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 117 additions & 1 deletion

```diff
@@ -39,6 +39,7 @@
     _is_fbgemm_gpu_genai_available,
     fill_defaults,
     is_sm_at_least_90,
+    is_sm_at_least_100,
 )

 __all__ = [
@@ -261,7 +262,7 @@ def _(func, types, args, kwargs):
     )

     act_quant_kwargs = weight_tensor.act_quant_kwargs
-    # quantizing activation, if `act_quant_kwargs` is specified
+    # quantize activation, if `act_quant_kwargs` is specified
     if act_quant_kwargs is not None:
         input_tensor = _choose_quant_func_and_quantize_tensor(
             input_tensor, act_quant_kwargs
@@ -418,6 +419,121 @@ def _(func, types, args, kwargs):
     return res


+def _quantize_and_scaled_conv3d(
+    input_tensor,
+    weight_tensor,
+    bias,
+    stride,
+    padding,
+    dilation,
+):
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, (
+        "Only support 3D conv currently"
+    )
+    assert _is_fbgemm_gpu_genai_available(), (
+        "quantized fp8 conv3d requires fbgemm_gpu_genai to be available"
+    )
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantize activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    if isinstance(input_tensor, Float8Tensor):
+        kernel_choice = None
+        if weight_tensor.kernel_preference == KernelPreference.AUTO:
+            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_100():
+                kernel_choice = "fbgemm"
+            else:
+                raise NotImplementedError(f"No available kernel choice for {weight_tensor.kernel_preference}")
+        elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
+            kernel_choice = "fbgemm"
+        else:
+            raise NotImplementedError(f"No available kernel choice for {weight_tensor.kernel_preference}")
+
+        assert kernel_choice == "fbgemm", "Only fbgemm kernel choice is supported currently"
+        # move C_in to last dim
+        # after permute: (N, D, H, W, C_in)
+        act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1])
+
+        # move C_in to last dim
+        # after permute: (C_out, K1, K2, K3, C_in)
+        weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])
+
+        assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
+            "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+        )
+
+        act_scale = input_tensor.scale
+        weight_scale = weight_tensor.scale
+        output = torch.ops.fbgemm.f8f8bf16_conv(
+            act_qdata,
+            weight_qdata,
+            act_scale * weight_scale,
+            padding,
+            stride,
+            dilation,
+        )
+        # output shape after permute: N, C_out, D_out, H_out, W_out
+        output = output.permute([0, 4, 1, 2, 3])
+        return output
+
+
+@implements(aten.convolution.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    ) = args
+    assert not transposed, "transposed conv is not supported currently"
+    assert tuple(output_padding) == (0, 0, 0), (
+        f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
+    )
+    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+    return _quantize_and_scaled_conv3d(
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+    )
+
+
+@implements(aten.conv3d.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
+    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+    return _quantize_and_scaled_conv3d(
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+    )
+
+
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     """Supports slicing for 1d, 2d, and 3d tensors
```

torchao/utils.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -32,6 +32,7 @@
     "is_MI300",
     "is_sm_at_least_89",
     "is_sm_at_least_90",
+    "is_sm_at_least_100",
     "is_package_at_least",
     "DummyModule",
     # Deprecated
```
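`is_sm_at_least_100` already exists in `torchao/utils.py`; this hunk only adds it to `__all__`. For orientation, a plausible sketch of such a capability check, modeled on the existing `is_sm_at_least_89`/`is_sm_at_least_90` helpers; the actual implementation in torchao may differ:

```python
import torch

def is_sm_at_least_100() -> bool:
    # True when a CUDA GPU with compute capability >= 10.0 is available.
    return (
        torch.cuda.is_available()
        and torch.version.cuda is not None
        and torch.cuda.get_device_capability() >= (10, 0)
    )
```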
