
Commit 16da39f

mxtensor: add pre-swizzle support
Summary:

Adds the ability to pre-swizzle scales for `MXTensor`, and turns it on for the inference workflow.

For activations, this is a no-op for now, but if we write a fused kernel we'll hook into the pre-swizzled path. For weights, this is a performance win in this PR, as we now swizzle ahead of time. Rough magnitude of the weight pre-swizzling win: for M, K, N == 4096, 4096, 4096, the inference forward speedup on mxfp8 increases from 1.24x to 1.30x.

Test Plan:

```bash
# correctness
CUDA_VISIBLE_DEVICES=5 pytest test/prototype/mx_formats/ -s

# performance
CUDA_VISIBLE_DEVICES=5 python benchmarks/float8/float8_inference_roofline.py ~/local/tmp/20251017_test.csv --recipe_name mxfp8_cublas --shape_gen_name pow2_extended

# before: https://www.internalfb.com/phabricator/paste/view/P1996942931
# after: https://www.internalfb.com/phabricator/paste/view/P1996941798
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 46b8d23
ghstack-comment-id: 3415966576
Pull-Request: #3200
1 parent bbfd981 · commit 16da39f
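For readers unfamiliar with scale swizzling, the minimal sketch below (not part of this commit) illustrates the idea using the `to_blocked` / `from_blocked` helpers from `torchao.prototype.mx_formats.utils` that the new test imports. Pre-swizzling just means the per-block scales are stored in this blocked layout at quantization time, instead of being re-swizzled on every GEMM call; the scale dtype and the exact `to_blocked` call signature are assumptions here.

```python
# Hedged sketch: store MX scales in the blocked ("swizzled") layout up front,
# rather than converting them on every forward pass. Helper names come from
# torchao.prototype.mx_formats.utils; dtype and argument details are assumptions.
import torch

from torchao.prototype.mx_formats.utils import from_blocked, to_blocked

M, K, block_size = 128, 64, 32
# One e8m0 scale per 1x32 block, represented here as raw uint8 for illustration.
scale_rowmajor = torch.randint(
    0, 255, (M, K // block_size), dtype=torch.uint8, device="cuda"
)

# Swizzle once, at quantization time ("pre-swizzle") ...
scale_swizzled = to_blocked(scale_rowmajor)

# ... and recover the row-major view only when needed (as the new test does).
scale_roundtrip = from_blocked(scale_swizzled, M, K // block_size)
torch.testing.assert_close(scale_roundtrip, scale_rowmajor, atol=0, rtol=0)
```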


5 files changed, +361 -240 lines changed


test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 69 additions & 2 deletions
@@ -5,6 +5,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
+
 import pytest
 import torch
 from torch._inductor.utils import run_and_get_code
@@ -22,6 +24,7 @@
     ScaleCalculationMode,
     to_dtype,
 )
+from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     is_sm_at_least_89,
@@ -388,6 +391,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         MXGemmKernelChoice.EMULATED,
         pack_fp6,
         None,
+        False,
     )
     tensor_hp = tensor_mx.dequantize(torch.float)
     assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
@@ -645,8 +649,6 @@ def to_f8(x):
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
-    from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
-
     rows, cols = shape
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -716,3 +718,68 @@ def test_scale_shape_matches_qdata(transpose, shape):
     assert expected_padded_k == actual_padded_k, (
         f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    (
+        (128, 64),
+        (1, 128, 64),
+    ),
+)
+def test_swizzle(elem_dtype, transpose, shape):
+    if len(shape) == 3 and transpose:
+        pytest.skip("transpose not yet implemented for 3D MXTensor")
+
+    block_size = 32
+
+    x_hp = torch.randn(*shape, device="cuda")
+    x = MXTensor.to_mx(
+        x_hp,
+        elem_dtype,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+    )
+
+    xs = MXTensor.to_mx(
+        x_hp,
+        elem_dtype,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+        is_swizzled_scales=True,
+    )
+
+    if transpose:
+        x = x.t()
+        xs = xs.t()
+
+    torch.testing.assert_close(x.qdata, xs.qdata, atol=0, rtol=0)
+
+    if transpose:
+        leading_dims, M, K = x.shape[:-2], x.shape[-1], x.shape[-2]
+        xs_scale_unblocked = from_blocked(
+            xs.scale.t(), math.prod(leading_dims) * M, K // block_size
+        )
+        xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)
+        xs_scale_unblocked = xs_scale_unblocked.t()
+    else:
+        leading_dims, M, K = x.shape[:-2], x.shape[-2], x.shape[-1]
+        xs_scale_unblocked = from_blocked(
+            xs.scale, math.prod(leading_dims) * M, K // block_size
+        )
+        xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)
+
+    torch.testing.assert_close(
+        x.scale,
+        xs_scale_unblocked,
+        atol=0,
+        rtol=0,
+    )
+
+    x_dq = x.dequantize(x.dtype)
+    xs_dq = xs.dequantize(xs.dtype)
+    torch.testing.assert_close(x_dq, xs_dq, atol=0, rtol=0)

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 2 additions & 0 deletions
@@ -111,6 +111,7 @@ def _mx_inference_linear_transform(
         block_size=config.block_size,
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,
+        is_swizzled_scales=True,
     )
 
     # Convert weight to MX Tensor
@@ -121,6 +122,7 @@ def _mx_inference_linear_transform(
         gemm_kernel_choice=config.gemm_kernel_choice,
         pack_fp6=False,  # TODO
         act_quant_kwargs=act_quant_kwargs,
+        is_swizzled_scales=True,
     )
 
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)