Skip to content

Remove marks broken_complex from all tests #1475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dpctl/tests/elementwise/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def test_exp_complex_strided(dtype):
)


@pytest.mark.broken_complex
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_exp_complex_special_cases(dtype):
q = get_queue_or_skip()
Expand Down
33 changes: 26 additions & 7 deletions dpctl/tests/elementwise/test_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import itertools
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -270,7 +271,28 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)


@pytest.mark.broken_complex
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_hyper_complex_special_cases_conj_property(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

tol = 50 * dpt.finfo(dtype).resolution
Y = dpt_call(Xc)
Yc = dpt_call(dpt.conj(Xc))

dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)


@pytest.mark.skipif(
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
)
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
Expand All @@ -287,9 +309,6 @@ def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
Ynp = np_call(Xc_np)

tol = 50 * dpt.finfo(dtype).resolution
assert_allclose(
dpt.asnumpy(dpt.real(dpt_call(Xc))), np.real(Ynp), atol=tol, rtol=tol
)
assert_allclose(
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
)
Y = dpt_call(Xc)
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)
14 changes: 12 additions & 2 deletions dpctl/tests/elementwise/test_sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def test_sqrt_real_fp_special_values(dtype):
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)


@pytest.mark.broken_complex
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
def test_sqrt_complex_fp_special_values(dtype):
q = get_queue_or_skip()
Expand All @@ -179,4 +178,15 @@ def test_sqrt_complex_fp_special_values(dtype):
expected = dpt.asarray(expected_np, dtype=dtype)
tol = dpt.finfo(r.dtype).resolution

assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
if not dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True):
for i in range(r.shape[0]):
failure_data = []
if not dpt.allclose(
r[i], expected[i], atol=tol, rtol=tol, equal_nan=True
):
msg = (
f"Test failed for input {z[i]}, i.e. {c_[i]} for index {i}"
)
msg += f", results were {r[i]} vs. {expected[i]}"
failure_data.extend(msg)
pytest.skip(reason=msg)
77 changes: 64 additions & 13 deletions dpctl/tests/elementwise/test_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import itertools
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -93,15 +94,25 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

n_seq = 100
n_seq = 256
n_rep = 137
low = -9.0
high = 9.0
x1 = np.random.uniform(low=low, high=high, size=n_seq)
x2 = np.random.uniform(low=low, high=high, size=n_seq)
Xnp = x1 + 1j * x2

X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q)
# stay away from poles and branch lines
modulus = np.abs(Xnp)
sel = np.logical_or(
modulus < 0.9,
np.logical_and(
modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05
),
)
Xnp = Xnp[sel]

X = dpt.repeat(dpt.asarray(Xnp, dtype=dtype, sycl_queue=q), n_rep)
Y = dpt_call(X)

expected = np.repeat(np_call(Xnp), n_rep)
Expand Down Expand Up @@ -234,10 +245,30 @@ def test_trig_complex_strided(np_call, dpt_call, dtype):

low = -9.0
high = 9.0
while True:
x1 = np.random.uniform(low=low, high=high, size=2 * sum(sizes))
x2 = np.random.uniform(low=low, high=high, size=2 * sum(sizes))
Xnp_all = np.array(
[complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype
)

# stay away from poles and branch lines
modulus = np.abs(Xnp_all)
sel = np.logical_or(
modulus < 0.9,
np.logical_and(
modulus > 1.2, np.minimum(np.abs(x2), np.abs(x1)) > 0.05
),
)
Xnp_all = Xnp_all[sel]
if Xnp_all.size > sum(sizes):
break

pos = 0
for ii in sizes:
x1 = np.random.uniform(low=low, high=high, size=ii)
x2 = np.random.uniform(low=low, high=high, size=ii)
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
pos = pos + ii
Xnp = Xnp_all[:pos]
Xnp = Xnp[-ii:]
X = dpt.asarray(Xnp)
Ynp = np_call(Xnp)
for jj in strides:
Expand All @@ -264,13 +295,36 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
Y_np = np_call(xf)

tol = 8 * dpt.finfo(dtype).resolution
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)
Y = dpt_call(yf)
assert_allclose(dpt.asnumpy(Y), Y_np, atol=tol, rtol=tol)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_trig_complex_special_cases_conj_property(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

tol = 50 * dpt.finfo(dtype).resolution
Y = dpt_call(Xc)
Yc = dpt_call(dpt.conj(Xc))

@pytest.mark.broken_complex
dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)


@pytest.mark.skipif(
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
)
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_trig_complex_special_cases(np_call, dpt_call, dtype):

q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

Expand All @@ -284,9 +338,6 @@ def test_trig_complex_special_cases(np_call, dpt_call, dtype):
Ynp = np_call(Xc_np)

tol = 50 * dpt.finfo(dtype).resolution
assert_allclose(
dpt.asnumpy(dpt.real(dpt_call(Xc))), np.real(Ynp), atol=tol, rtol=tol
)
assert_allclose(
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
)
Y = dpt_call(Xc)
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)