
Commit acfd12d

[not4land] Some fixes for MXFP8

Author: Jerry
1 parent ff16308

2 files changed: +42 -13 lines


test/prototype/mx_formats/test_inference_workflow.py
20 additions & 8 deletions
@@ -67,10 +67,14 @@ def cuda_kernel_profiler(kernel_pattern):
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
-@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
+# @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+# @pytest.mark.parametrize("bias", [True, False])
+# @pytest.mark.parametrize("compile", [True, False])
+# @pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("bias", [False])
+@pytest.mark.parametrize("compile", [False])
+@pytest.mark.parametrize("emulate", [False])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
@@ -93,7 +97,11 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
         # TODO(future PR): investigate and fix this
         pytest.skip("mxfp4 + compile currently does not work, low SQNR")
 
-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    # M, N, K = 16, 3072, 4096
+    # M, N, K = 1920, 3072, 256
+    M, N, K = 1920, 18432, 3072
+    # m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
 
     if emulate:
@@ -108,18 +116,22 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
         gemm_kernel_choice=kernel_choice,
     )
     quantize_(m_mx, config=config)
+    print("m_mx:", m_mx)
+
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
 
-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
-    y_ref = m(x)
-    y_mx = m_mx(x)
+    with torch.inference_mode():
+        x = torch.randn(1, M, K, device="cuda", dtype=torch.bfloat16)
+        y_ref = m(x)
+        y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
     SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
    )
 
+    raise Exception("stop")
     # serialization
     with tempfile.NamedTemporaryFile() as f:
         torch.save(m_mx.state_dict(), f)
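Note on the assertion above: compute_error returns the SQNR in dB between the bfloat16 reference output and the MXFP8 output. A minimal sketch of the metric, assuming compute_error follows the standard signal-to-quantization-noise definition (the helper name sqnr_db is illustrative, not part of the diff):

    import torch

    def sqnr_db(y_ref: torch.Tensor, y_quant: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB: 20 * log10(|signal| / |error|).
        return 20 * torch.log10(
            torch.linalg.norm(y_ref) / torch.linalg.norm(y_ref - y_quant)
        )

With the parametrization narrowed to float8_e4m3fn, the test requires this value to be at least 25.0 dB.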

torchao/prototype/mx_formats/mx_tensor.py
22 additions & 5 deletions
@@ -665,7 +665,10 @@ def _addmm_mx_dispatch(
     The only difference is whether bias is None or not.
     """
 
+    out_starting_shape = a.shape[:-1]
+
     if not isinstance(a, MXTensor):
+        a = a.reshape(-1, a.shape[-1])
         assert b.act_quant_kwargs is not None, "weight-only quant not yet supported"
         k = b.act_quant_kwargs
         a = MXTensor.to_mx(
@@ -698,6 +701,7 @@ def _addmm_mx_dispatch(
                 "CUBLAS is the only supported kernel choice for MX FP8 operations"
             )
 
+            print(f"scaled_mm info: {a.qdata.shape=}, {b.qdata.shape=}, {a_scale_block.shape=}, {b_scale_block.shape=}, a.qdata contig: {a.qdata.is_contiguous()} b.qdata contig: {b.qdata.is_contiguous()} a_scale_block contig: {a_scale_block.is_contiguous()}, b_scale_block contig: {b_scale_block.is_contiguous()}")
             res = torch._scaled_mm(
                 a.qdata,
                 b.qdata,
@@ -706,6 +710,7 @@ def _addmm_mx_dispatch(
                 bias=bias,
                 out_dtype=torch.bfloat16,
             )
+            print(f"after scaled mm {a.qdata.shape=}, {b.qdata.shape=}")
         else:
             assert a._elem_dtype == torch.float4_e2m1fn_x2
             assert b._elem_dtype == torch.float4_e2m1fn_x2
@@ -717,7 +722,6 @@ def _addmm_mx_dispatch(
             # TODO add optional bias to kernel
             if bias is not None:
                 res = res + bias
-
     else:
         # emulated MX gemm
         a_hp = a.dequantize(a._orig_dtype)
@@ -726,12 +730,17 @@ def _addmm_mx_dispatch(
         assert a_hp.is_contiguous()
         assert b_hp.t().is_contiguous()
 
-        # Call appropriate aten_op based on whether bias is provided
-        if bias is not None:
-            res = aten_op(bias, a_hp, b_hp)  # addmm
+        if aten_op == aten.linear.default:
+            res = aten_op(a_hp, b_hp.t(), bias)
         else:
-            res = aten_op(a_hp, b_hp)  # mm
+            # Call appropriate aten_op based on whether bias is provided
+            if bias is not None:
+                res = aten_op(bias, a_hp, b_hp)  # addmm
+            else:
+                res = aten_op(a_hp, b_hp)  # mm
 
+    res = res.reshape(*out_starting_shape, res.shape[-1])
+
     return res
 
 
@@ -752,6 +761,14 @@ def mx_addmm(func, types, args, kwargs):
     b = args[2]
     return _addmm_mx_dispatch(a, b, func, bias=bias)
 
+@implements([aten.linear.default])
+def mx_linear(func, types, args, kwargs):
+    assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
+    a = args[0]
+    b = args[1]
+    bias = args[2] if len(args) > 2 else None
+    return _addmm_mx_dispatch(a, b.t(), func, bias=bias)
+
 
 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):
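The functional core of the mx_tensor.py change is routing aten.linear.default through the existing 2D mm/addmm dispatch: flatten the leading activation dimensions before the matmul, then restore them on the output. A standalone sketch of that reshape pattern, with plain torch.mm standing in for the MX scaled gemm and the function name linear_via_2d_mm chosen for illustration only:

    import torch

    def linear_via_2d_mm(a: torch.Tensor, w: torch.Tensor, bias=None) -> torch.Tensor:
        # Same pattern as the diff: remember the leading activation dims,
        # run a 2D matmul, then reshape the result back.
        out_starting_shape = a.shape[:-1]
        a_2d = a.reshape(-1, a.shape[-1])      # (B, M, K) -> (B * M, K)
        res = torch.mm(a_2d, w.t())            # 2D matmul stands in for the MX scaled gemm
        if bias is not None:
            res = res + bias
        return res.reshape(*out_starting_shape, res.shape[-1])  # (B * M, N) -> (B, M, N)

    # Small shapes for illustration; the test above uses M, N, K = 1920, 18432, 3072.
    x = torch.randn(2, 5, 8)    # (B, M, K) activation
    w = torch.randn(6, 8)       # (N, K) weight, as in nn.Linear
    assert linear_via_2d_mm(x, w).shape == (2, 5, 6)

This keeps torch._scaled_mm and the emulated path strictly 2D while still letting nn.Linear accept batched (3D) activations.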
