Skip to content

Commit c901108

Browse files
[mxfp8] fix test nan != nan issue
1 parent 1e473ed commit c901108

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_some_zeros(elem_dtype):
117117

118118

119119
# TODO(future PR): fix and reenable this test
120-
@pytest.mark.skip(reason="does not pass on B200 yet")
120+
# @pytest.mark.skip(reason="does not pass on B200 yet")
121121
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
122122
def test_to_mx_rceil():
123123
# nan
@@ -131,11 +131,7 @@ def test_to_mx_rceil():
131131
],
132132
dtype=torch.uint32,
133133
).view(torch.float32)
134-
# fmt: on
135-
ground_truth_scale = torch.tensor([255], dtype=torch.uint8).view(
136-
torch.float8_e8m0fnu
137-
)
138-
# fmt: off
134+
139135
ground_truth_fp8 = torch.tensor(
140136
[
141137
127, 0, 0, 0, 0, 0, 0, 0,
@@ -149,7 +145,7 @@ def test_to_mx_rceil():
149145
data_mx = MXTensor.to_mx(
150146
data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
151147
)
152-
torch.testing.assert_close(data_mx.scale, ground_truth_scale)
148+
assert torch.isnan(data_mx.scale)
153149
assert torch.isnan(data_mx.qdata[0])
154150
assert torch.all(data_mx.qdata[1:] == 0)
155151
# fp32 denorm

0 commit comments

Comments
 (0)