diff --git a/_unittests/ut_validation/temp_search_float32_into_fe4m3fn.xlsx b/_unittests/ut_validation/temp_search_float32_into_fe4m3fn.xlsx new file mode 100644 index 0000000..2d59c02 Binary files /dev/null and b/_unittests/ut_validation/temp_search_float32_into_fe4m3fn.xlsx differ diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py index 639319e..b44683f 100644 --- a/_unittests/ut_validation/test_f8.py +++ b/_unittests/ut_validation/test_f8.py @@ -1220,7 +1220,32 @@ def test_fe4m3fn_to_float32_bug(self): continue raise AssertionError(f"Unexpected value for pt={pt}.") + def test_inf(self): + for x, e in [(numpy.float32(numpy.inf), 126), (numpy.float32(-numpy.inf), 254)]: + f8 = float32_to_fe4m3(x) + self.assertEqual(e, f8) + + def test_nan(self): + expected = 127 + values = [ + ( + None, + int.from_bytes(struct.pack("> 24 # sign if uz: - if (b & 0x7FC00000) == 0x7FC00000: - return 0x80 - if numpy.isinf(x): + if (b & 0x7FFFFFFF) == 0x7F800000: + # infinity if saturate: return ret | 127 return 0x80 + if (b & 0x7F800000) == 0x7F800000: + return 0x80 e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True ret = 0 return int(ret) else: - if (b & 0x7FC00000) == 0x7FC00000: - return 0x7F | ret - if numpy.isinf(x): + if (b & 0x7FFFFFFF) == 0x7F800000: + # infinity if saturate: return ret | 126 return 0x7F | ret + if (b & 0x7F800000) == 0x7F800000: + # non + return 0x7F | ret e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru ret = (b & 0x80000000) >> 24 # sign if fn and uz: - if (b & 0x7FC00000) == 0x7FC00000: - return 0x80 if (b & 0x7FFFFFFF) == 0x7F800000: # inf if saturate: return ret | 0x7F return 0x80 + if (b & 0x7F800000) == 0x7F800000: + # nan + return 0x80 e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru ret = 0 return int(ret) elif not fn and not uz: - if (b & 0x7FC00000) == 0x7FC00000: - return 0x7F | ret - if numpy.isinf(x): + if (b & 0x7FFFFFFF) == 0x7F800000: + # inf if saturate: return 0x7B | ret return 0x7C | ret + if (b & 0x7F800000) == 0x7F800000: + # nan + return 0x7F | ret e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa