
Commit 28612d0

mxfp4 and nvfp4: align fp4 packing to PyTorch Core definition (#3123)
Update [ghstack-poisoned]
1 parent 2fe0ca0 commit 28612d0

File tree

3 files changed: +35 -15 lines changed


test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 2 additions & 2 deletions
@@ -70,9 +70,9 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
     elif elem_dtype == torch.float4_e2m1fn_x2:
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
-        elif emulate and compile:
+        elif compile:
             # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work, low SQNR")

     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)

test/prototype/mx_formats/test_kernels.py

Lines changed: 20 additions & 0 deletions
@@ -320,6 +320,26 @@ def test_fp4_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 0.5, 4.0, -0.0], [-0.0, 1.0, -6.0, 3.0]])
     orig_vals_f4_unpacked = f32_to_f4_unpacked(orig_vals)
     orig_vals_f4_packed = pack_uint4(orig_vals_f4_unpacked)
+
+    # ensure packing is
+    #
+    # 7654:3210
+    # val1:val0
+    expected_f4_packed = torch.tensor(
+        [
+            [
+                0b00010000,
+                0b10000110,
+            ],
+            [
+                0b00101000,
+                0b01011111,
+            ],
+        ],
+        dtype=torch.uint8,
+    )
+
+    assert torch.all(orig_vals_f4_packed == expected_f4_packed)
     assert orig_vals_f4_packed.numel() == (orig_vals.numel() / 2)
     orig_vals_f4_packed_unpacked = unpack_uint4(orig_vals_f4_packed)
     orig_vals_dq = f4_unpacked_to_f32(orig_vals_f4_packed_unpacked)
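
For reference, the expected bytes in the new assertion follow from the usual e2m1 (fp4) code points combined with the low-nibble-first layout this commit adopts. The sketch below is a standalone, hand-rolled derivation; the lookup table and helper names are illustrative only and are not part of the test file or the torchao API:

import math

# usual e2m1 magnitude code points (bit 3 is the sign)
_E2M1_MAGNITUDES = {0.0: 0b000, 0.5: 0b001, 1.0: 0b010, 1.5: 0b011,
                    2.0: 0b100, 3.0: 0b101, 4.0: 0b110, 6.0: 0b111}

def f32_to_e2m1_code(v: float) -> int:
    # -0.0 still carries the sign bit, so it maps to 0b1000
    sign = 0b1000 if math.copysign(1.0, v) < 0 else 0b0000
    return sign | _E2M1_MAGNITUDES[abs(v)]

def pack_pair(val0: float, val1: float) -> int:
    # val0 -> bits 0-3 (low nibble), val1 -> bits 4-7 (high nibble)
    return f32_to_e2m1_code(val0) | (f32_to_e2m1_code(val1) << 4)

# reproduces expected_f4_packed for the orig_vals used in the test above
got = [
    [pack_pair(0.0, 0.5), pack_pair(4.0, -0.0)],
    [pack_pair(-0.0, 1.0), pack_pair(-6.0, 3.0)],
]
assert got == [[0b00010000, 0b10000110], [0b00101000, 0b01011111]]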

torchao/prototype/mx_formats/kernels.py

Lines changed: 13 additions & 13 deletions
@@ -142,10 +142,10 @@ def _fp4_packed_to_bf16(
     Output: a tensor of bfloat16 values
     """

-    # low-bits: original location 0:3
-    # high-bits: original location 4:7
-    x_low_bits = x_packed >> 4
-    x_high_bits = x_packed & 0xF
+    # high-bits: original location 0:3
+    # low-bits: original location 4:7
+    x_high_bits = x_packed >> 4
+    x_low_bits = x_packed & 0xF
     x = tl.interleave(x_low_bits, x_high_bits)

     # cast logic below

@@ -735,8 +735,8 @@ def unpack_uint4(uint8_data) -> torch.Tensor:
     # verified that we get a single triton kernel, but that is even slower
     # than the two kernels before this PR
     # * TODO add a microbenchmark of just the cast and profile this
-    first_elements = (uint8_data >> 4).to(torch.uint8)
-    second_elements = (uint8_data & 0b1111).to(torch.uint8)
+    first_elements = (uint8_data & 0b1111).to(torch.uint8)
+    second_elements = (uint8_data >> 4).to(torch.uint8)
     unpacked = torch.stack([first_elements, second_elements], dim=-1).view(
         up_size(shape)
     )

@@ -758,7 +758,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
     shape = uint8_data.shape
     assert shape[-1] % 2 == 0
     uint8_data = uint8_data.contiguous().view(-1)
-    return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape))
+    return (uint8_data[::2] | uint8_data[1::2] << 4).view(down_size(shape))


 # PyTorch implementation of fp6 packing for reference purposes
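
Taken together, pack_uint4 and unpack_uint4 now agree that the even-index element of each pair lives in the low nibble and the odd-index element in the high nibble, matching the PyTorch Core packed-uint4 layout named in the commit title. A minimal round-trip sketch of that convention on a flat 1-D tensor (helper names are illustrative, and the up_size/down_size shape bookkeeping of the real functions is omitted):

import torch

def pack_uint4_flat(nibbles: torch.Tensor) -> torch.Tensor:
    # even-index element -> low nibble (bits 0-3), odd-index -> high nibble (bits 4-7)
    assert nibbles.dtype == torch.uint8 and nibbles.numel() % 2 == 0
    return nibbles[::2] | (nibbles[1::2] << 4)

def unpack_uint4_flat(packed: torch.Tensor) -> torch.Tensor:
    # first element of each pair comes from the low nibble, second from the high nibble
    first_elements = packed & 0b1111
    second_elements = packed >> 4
    return torch.stack([first_elements, second_elements], dim=-1).reshape(-1)

nibbles = torch.tensor([0b0000, 0b0001, 0b0110, 0b1000], dtype=torch.uint8)
packed = pack_uint4_flat(nibbles)  # tensor([0b00010000, 0b10000110])
assert torch.equal(unpack_uint4_flat(packed), nibbles)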
@@ -1250,8 +1250,8 @@ def convert_fp32_to_fp4_packed(x_pairs):
     Returns:
         Packed tensor with shape [...] (last dimension removed) where each
         element is an int8 containing 2 FP4 values:
-        - First value of pair → high nibble (bits 4-7)
-        - Second value of pair → low nibble (bits 0-3)
+        - First value of pair → low nibble (bits 0-3)
+        - Second value of pair → high nibble (bits 4-7)

     Example:
         Input: [128, 32, 2] containing FP32 pairs

@@ -1263,10 +1263,10 @@ def convert_fp32_to_fp4_packed(x_pairs):
         asm="""
         {
             .reg .b8 byte0, byte1, byte2, byte3;
-            cvt.rn.satfinite.e2m1x2.f32 byte0, $1, $5;
-            cvt.rn.satfinite.e2m1x2.f32 byte1, $2, $6;
-            cvt.rn.satfinite.e2m1x2.f32 byte2, $3, $7;
-            cvt.rn.satfinite.e2m1x2.f32 byte3, $4, $8;
+            cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
+            cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
+            cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
+            cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
             mov.b32 $0, {byte0, byte1, byte2, byte3};
         }
         """,
