@@ -5,6 +5,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
+
 import pytest
 import torch
 from torch._inductor.utils import run_and_get_code
@@ -22,6 +24,7 @@
     ScaleCalculationMode,
     to_dtype,
 )
+from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     is_sm_at_least_89,
@@ -388,6 +391,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         MXGemmKernelChoice.EMULATED,
         pack_fp6,
         None,
+        False,
     )
     tensor_hp = tensor_mx.dequantize(torch.float)
     assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
@@ -645,8 +649,6 @@ def to_f8(x):
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
-    from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
-
     rows, cols = shape
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -716,3 +718,68 @@ def test_scale_shape_matches_qdata(transpose, shape):
     assert expected_padded_k == actual_padded_k, (
         f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    (
+        (128, 64),
+        (1, 128, 64),
+    ),
+)
+def test_swizzle(elem_dtype, transpose, shape):
+    if len(shape) == 3 and transpose:
+        pytest.skip("transpose not yet implemented for 3D MXTensor")
+
+    block_size = 32
+
+    x_hp = torch.randn(*shape, device="cuda")
+    x = MXTensor.to_mx(
+        x_hp,
+        elem_dtype,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+    )
+
+    xs = MXTensor.to_mx(
+        x_hp,
+        elem_dtype,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+        is_swizzled_scales=True,
+    )
+
+    if transpose:
+        x = x.t()
+        xs = xs.t()
+
+    torch.testing.assert_close(x.qdata, xs.qdata, atol=0, rtol=0)
+
+    if transpose:
+        leading_dims, M, K = x.shape[:-2], x.shape[-1], x.shape[-2]
+        xs_scale_unblocked = from_blocked(
+            xs.scale.t(), math.prod(leading_dims) * M, K // block_size
+        )
+        xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)
+        xs_scale_unblocked = xs_scale_unblocked.t()
+    else:
+        leading_dims, M, K = x.shape[:-2], x.shape[-2], x.shape[-1]
+        xs_scale_unblocked = from_blocked(
+            xs.scale, math.prod(leading_dims) * M, K // block_size
+        )
+        xs_scale_unblocked = xs_scale_unblocked.view(*leading_dims, M, K // block_size)
+
+    torch.testing.assert_close(
+        x.scale,
+        xs_scale_unblocked,
+        atol=0,
+        rtol=0,
+    )
+
+    x_dq = x.dequantize(x.dtype)
+    xs_dq = xs.dequantize(xs.dtype)
+    torch.testing.assert_close(x_dq, xs_dq, atol=0, rtol=0)
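
For context on the swizzled scale layout exercised by `test_swizzle`, here is a minimal standalone sketch of the `to_blocked`/`from_blocked` round trip, assuming the call signatures used in the tests above (`to_blocked(scales, use_triton_kernel=...)` and `from_blocked(blocked, rows, cols)`); the shapes and the uint8 stand-in for e8m0 scales are illustrative assumptions, not part of this diff.

```python
# Minimal sketch (illustrative, not part of this diff): round-trip a per-block
# scale matrix through the swizzled ("blocked") layout, mirroring the calls in
# test_to_blocked_from_blocked_roundtrip and test_swizzle above.
import torch

from torchao.prototype.mx_formats.utils import from_blocked, to_blocked

M, K, block_size = 256, 128, 32
rows, cols = M, K // block_size  # one scale per 32-element block along K
device = "cuda" if torch.cuda.is_available() else "cpu"

# uint8 stands in for the e8m0 scale dtype purely for illustration.
scales = torch.randint(0, 255, (rows, cols), dtype=torch.uint8, device=device)

# use_triton_kernel=False is assumed here to stay on the eager path, matching
# the parameter name test_to_blocked_from_blocked_roundtrip is parametrized over.
blocked = to_blocked(scales, use_triton_kernel=False)
recovered = from_blocked(blocked, rows, cols)

# The round trip should be exact: from_blocked undoes the padding/rearrangement.
torch.testing.assert_close(scales, recovered, atol=0, rtol=0)
```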