Skip to content

Commit 4a29159

Browse files
committed
Update
[ghstack-poisoned]
2 parents 526b741 + 76671f9 commit 4a29159

File tree

5 files changed

+231
-3
lines changed

5 files changed

+231
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ If you believe there's other CUDA kernels we should be taking a closer look at p
254254

255255
TorchAO is integrated into some of the leading open-source libraries including:
256256

257-
* Unsloth for QAT, blog post coming soon!
257+
* Unsloth now supports QAT: [Read blog](https://docs.unsloth.ai/new/quantization-aware-training-qat) and [guide](https://docs.unsloth.ai/new/quantization-aware-training-qat#qat--lora-finetuning).
258258
* HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
259259
* HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
260260
* vLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html)

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_is_fbgemm_gpu_genai_available,
3232
is_sm_at_least_89,
3333
is_sm_at_least_90,
34+
is_sm_at_least_100,
3435
torch_version_at_least,
3536
)
3637

@@ -50,6 +51,28 @@ def forward(self, x):
5051
return x
5152

5253

54+
class ToyConvModel(torch.nn.Module):
55+
def __init__(
56+
self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
57+
):
58+
super().__init__()
59+
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
60+
self.conv = convs[dim](
61+
in_channels,
62+
out_channels,
63+
kernel_size,
64+
bias=bias,
65+
padding=padding,
66+
dtype=dtype,
67+
device=device,
68+
)
69+
if dim == 3:
70+
self.conv = self.conv.to(memory_format=torch.channels_last_3d)
71+
72+
def forward(self, x):
73+
return self.conv(x)
74+
75+
5376
# TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
5477
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
5578
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -190,6 +213,85 @@ def test_fp8_linear_variants(
190213
f"Quantization error is too high got a SQNR of {error}"
191214
)
192215

216+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
217+
@unittest.skipIf(
218+
not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
219+
)
220+
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
221+
@common_utils.parametrize("compile", [True, False])
222+
@common_utils.parametrize("granularity", [PerTensor()])
223+
@common_utils.parametrize("inference_mode", [True, False])
224+
@common_utils.parametrize(
225+
"kernel_preference",
226+
[KernelPreference.AUTO],
227+
)
228+
# only test for 3D conv for now
229+
# Inputs are (N, C_in, C_out, D, H, W)
230+
@common_utils.parametrize(
231+
"sizes",
232+
[
233+
(4, 16, 64, 32, 32, 32),
234+
],
235+
)
236+
def test_fp8_conv_variants(
237+
self,
238+
dtype: torch.dtype,
239+
compile: bool,
240+
granularity,
241+
inference_mode: bool,
242+
kernel_preference: KernelPreference,
243+
sizes: Tuple,
244+
):
245+
if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()):
246+
return unittest.skip(
247+
"Requires fbgemm_gpu_genai and sm version >= 10.0 to run "
248+
"fbgemm kernel preference test"
249+
)
250+
251+
dim = 3
252+
N, C_in, C_out, D, H, W = sizes
253+
kernel_size = 3
254+
255+
# Note: this is channel last memory format
256+
input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
257+
input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
258+
259+
# Create a linear layer with bfloat16 dtype
260+
model = ToyConvModel(
261+
dim,
262+
C_in,
263+
C_out,
264+
kernel_size,
265+
bias=False,
266+
padding=0,
267+
dtype=dtype,
268+
device="cuda",
269+
).eval()
270+
271+
quantized_model = copy.deepcopy(model)
272+
273+
config = Float8DynamicActivationFloat8WeightConfig(
274+
granularity=granularity,
275+
kernel_preference=kernel_preference,
276+
)
277+
278+
_is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
279+
280+
quantize_(quantized_model, config, filter_fn=_is_conv3d)
281+
282+
if compile:
283+
quantized_model = torch.compile(quantized_model, fullgraph=True)
284+
285+
inference_mode_ctx = torch.inference_mode() if inference_mode else nullcontext()
286+
with inference_mode_ctx:
287+
output_original = model(input_tensor)
288+
output_quantized = quantized_model(input_tensor)
289+
290+
error = compute_error(output_original, output_quantized)
291+
assert compute_error(output_original, output_quantized) > 20, (
292+
f"Quantization error is too high got a SQNR of {error}"
293+
)
294+
193295
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
194296
@unittest.skipIf(
195297
not is_sm_at_least_90(),

torchao/quantization/quant_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1811,7 +1811,12 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
18111811
_check_hardware_support(granularity)
18121812
activation_granularity, weight_granularity = granularity
18131813

1814-
if not _fp8_mm_compat(weight):
1814+
if weight.dim() == 5:
1815+
# weights for conv3d
1816+
assert isinstance(activation_granularity, PerTensor) and isinstance(
1817+
weight_granularity, PerTensor
1818+
), "5D tensor only supports per tensor activation and weight quantization"
1819+
elif not _fp8_mm_compat(weight):
18151820
# TODO(future PR): this should really throw an exception instead of silently
18161821
# not doing what the user asked
18171822
return weight

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
_is_fbgemm_gpu_genai_available,
4545
fill_defaults,
4646
is_sm_at_least_90,
47+
is_sm_at_least_100,
4748
)
4849

4950
__all__ = [
@@ -266,7 +267,7 @@ def _(func, types, args, kwargs):
266267
)
267268

268269
act_quant_kwargs = weight_tensor.act_quant_kwargs
269-
# quantizing activation, if `act_quant_kwargs` is specified
270+
# quantize activation, if `act_quant_kwargs` is specified
270271
if act_quant_kwargs is not None:
271272
input_tensor = _choose_quant_func_and_quantize_tensor(
272273
input_tensor, act_quant_kwargs
@@ -450,6 +451,125 @@ def _(func, types, args, kwargs):
450451
return res
451452

452453

454+
def _quantize_and_scaled_conv3d(
455+
input_tensor,
456+
weight_tensor,
457+
bias,
458+
stride,
459+
padding,
460+
dilation,
461+
):
462+
assert isinstance(weight_tensor, Float8Tensor), (
463+
f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
464+
)
465+
466+
assert input_tensor.dim() == 5 and weight_tensor.dim() == 5, (
467+
"Only support 3D conv currently"
468+
)
469+
assert _is_fbgemm_gpu_genai_available(), (
470+
"quantized fp8 conv3d requires fbgemm_gpu_genai to be available"
471+
)
472+
act_quant_kwargs = weight_tensor.act_quant_kwargs
473+
# quantize activation, if `act_quant_kwargs` is specified
474+
if act_quant_kwargs is not None:
475+
input_tensor = _choose_quant_func_and_quantize_tensor(
476+
input_tensor, act_quant_kwargs
477+
)
478+
479+
if isinstance(input_tensor, Float8Tensor):
480+
kernel_choice = None
481+
if weight_tensor.kernel_preference == KernelPreference.AUTO:
482+
if _is_fbgemm_gpu_genai_available() and is_sm_at_least_100():
483+
kernel_choice = "fbgemm"
484+
else:
485+
raise NotImplementedError(
486+
f"No available kernel choice for {weight_tensor.kernel_preference}"
487+
)
488+
elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
489+
kernel_choice = "fbgemm"
490+
else:
491+
raise NotImplementedError(
492+
f"No available kernel choice for {weight_tensor.kernel_preference}"
493+
)
494+
495+
assert kernel_choice == "fbgemm", "Only fbgemm kernel choice is supported currently"
496+
# move C_in to last dim
497+
# after permute: (N, D, H, W, C_in)
498+
act_qdata = input_tensor.qdata.permute([0, 2, 3, 4, 1])
499+
500+
# move C_in to last dim
501+
# after permute: (C_out, K1, K2, K3, C_in)
502+
weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])
503+
504+
assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
505+
"Please make sure both activation and weights are in the `channels_last_3d` memory_format"
506+
)
507+
508+
act_scale = input_tensor.scale
509+
weight_scale = weight_tensor.scale
510+
output = torch.ops.fbgemm.f8f8bf16_conv(
511+
act_qdata,
512+
weight_qdata,
513+
act_scale * weight_scale,
514+
padding,
515+
stride,
516+
dilation,
517+
)
518+
# output shape after permute: N, C_out, D_out, H_out, W_out
519+
output = output.permute([0, 4, 1, 2, 3])
520+
return output
521+
522+
523+
@implements(aten.convolution.default)
524+
def _(func, types, args, kwargs):
525+
(
526+
input_tensor,
527+
weight_tensor,
528+
bias,
529+
stride,
530+
padding,
531+
dilation,
532+
transposed,
533+
output_padding,
534+
groups,
535+
) = args
536+
assert not transposed, "transposed conv is not supported currently"
537+
assert tuple(output_padding) == (0, 0, 0), (
538+
f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
539+
)
540+
assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
541+
return _quantize_and_scaled_conv3d(
542+
input_tensor,
543+
weight_tensor,
544+
bias,
545+
stride,
546+
padding,
547+
dilation,
548+
)
549+
550+
551+
@implements(aten.conv3d.default)
552+
def _(func, types, args, kwargs):
553+
(
554+
input_tensor,
555+
weight_tensor,
556+
bias,
557+
stride,
558+
padding,
559+
dilation,
560+
groups,
561+
) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
562+
assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
563+
return _quantize_and_scaled_conv3d(
564+
input_tensor,
565+
weight_tensor,
566+
bias,
567+
stride,
568+
padding,
569+
dilation,
570+
)
571+
572+
453573
@implements(aten.slice.Tensor)
454574
def _(func, types, args, kwargs):
455575
"""Supports slicing for 1d, 2d, and 3d tensors

torchao/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"is_MI300",
3333
"is_sm_at_least_89",
3434
"is_sm_at_least_90",
35+
"is_sm_at_least_100",
3536
"is_package_at_least",
3637
"DummyModule",
3738
# Deprecated

0 commit comments

Comments
 (0)