diff --git a/array_api_tests/test_constants.py b/array_api_tests/test_constants.py index 45c039e0..ce14a128 100644 --- a/array_api_tests/test_constants.py +++ b/array_api_tests/test_constants.py @@ -1,9 +1,13 @@ +import math + from ._array_module import (e, inf, nan, pi, equal, isnan, abs, full, float32, float64, less, isinf, greater, all) from .array_helpers import one def test_e(): - # Check that e acts as a scalar + # Check that e is a Python scalar + assert isinstance(e, float), "e is not a Python scalar" + E = full((1,), e, dtype=float64) # We don't require any accuracy. This is just a smoke test to check that @@ -11,7 +15,9 @@ def test_e(): assert all(less(abs(E - 2.71), one((1,), dtype=float64))), "e is not the constant e" def test_pi(): - # Check that pi acts as a scalar + # Check that pi is a Python scalar + assert isinstance(pi, float), "pi is not a Python scalar" + PI = full((1,), pi, dtype=float64) # We don't require any accuracy. This is just a smoke test to check that @@ -19,21 +25,25 @@ def test_pi(): assert all(less(abs(PI - 3.14), one((1,), dtype=float64))), "pi is not the constant π" def test_inf(): - # Check that inf acts as a scalar + # Check that inf is a Python scalar + assert isinstance(inf, float), "inf is not a Python scalar" + INF = full((1,), inf, dtype=float64) zero = full((1,), 0.0, dtype=float64) - assert all(isinf(inf)), "inf is not infinity" + assert math.isinf(inf), "inf is not infinity" assert all(isinf(INF)), "inf is not infinity" - assert all(greater(inf, zero)), "inf is not positive" + assert inf > 0, "inf is not positive" assert all(greater(INF, zero)), "inf is not positive" def test_nan(): - # Check that nan acts as a scalar + # Check that nan is a Python scalar + assert isinstance(nan, float), "nan is not a Python scalar" + NAN = full((1,), nan, dtype=float64) - assert all(isnan(nan)), "nan is not Not a Number" + assert math.isnan(nan), "nan is not Not a Number" assert all(isnan(NAN)), "nan is not Not a Number" - assert not all(equal(nan, nan)), "nan should be unequal to itself" + assert nan != nan, "nan should be unequal to itself" assert not all(equal(NAN, NAN)), "nan should be unequal to itself"