Skip to content

Improves F8 conversion #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
27 changes: 26 additions & 1 deletion _unittests/ut_validation/test_f8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,32 @@ def test_fe4m3fn_to_float32_bug(self):
continue
raise AssertionError(f"Unexpected value for pt={pt}.")

def test_inf(self):
for x, e in [(numpy.float32(numpy.inf), 126), (numpy.float32(-numpy.inf), 254)]:
f8 = float32_to_fe4m3(x)
self.assertEqual(e, f8)

def test_nan(self):
expected = 127
values = [
(
None,
int.from_bytes(struct.pack("<f", numpy.float32(numpy.nan)), "little"),
numpy.float32(numpy.nan),
expected,
)
]
for i in range(0, 23):
v = 0x7F800000 | (1 << i)
f = numpy.uint32(v).view(numpy.float32)
values.append((i, v, f, expected))
values.append((i, v, -f, expected | 128))

for i, v, x, e in values:
with self.subTest(x=x, e=e, h=hex(v), i=i):
f8 = float32_to_fe4m3(x)
self.assertEqual(e, f8)


if __name__ == "__main__":
TestF8().test_fe4m3fn_to_float32_bug()
unittest.main(verbosity=2)
32 changes: 20 additions & 12 deletions onnx_array_api/validation/f8.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,15 +503,18 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
"""
if not fn:
raise NotImplementedError("fn=False is not implemented.")
b = int.from_bytes(struct.pack("<f", numpy.float32(x)), "little")
if not isinstance(x, numpy.float32):
x = numpy.float32(x)
b = int.from_bytes(struct.pack("<f", x), "little")
ret = (b & 0x80000000) >> 24 # sign
if uz:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x80
if numpy.isinf(x):
if (b & 0x7FFFFFFF) == 0x7F800000:
# infinity
if saturate:
return ret | 127
return 0x80
if (b & 0x7F800000) == 0x7F800000:
return 0x80
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand Down Expand Up @@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
ret = 0
return int(ret)
else:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x7F | ret
if numpy.isinf(x):
if (b & 0x7FFFFFFF) == 0x7F800000:
# infinity
if saturate:
return ret | 126
return 0x7F | ret
if (b & 0x7F800000) == 0x7F800000:
# non
return 0x7F | ret
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand Down Expand Up @@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
ret = (b & 0x80000000) >> 24 # sign

if fn and uz:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x80
if (b & 0x7FFFFFFF) == 0x7F800000:
# inf
if saturate:
return ret | 0x7F
return 0x80
if (b & 0x7F800000) == 0x7F800000:
# nan
return 0x80
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand Down Expand Up @@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
ret = 0
return int(ret)
elif not fn and not uz:
if (b & 0x7FC00000) == 0x7FC00000:
return 0x7F | ret
if numpy.isinf(x):
if (b & 0x7FFFFFFF) == 0x7F800000:
# inf
if saturate:
return 0x7B | ret
return 0x7C | ret
if (b & 0x7F800000) == 0x7F800000:
# nan
return 0x7F | ret
e = (b & 0x7F800000) >> 23 # exponent
m = b & 0x007FFFFF # mantissa

Expand Down