Skip to content

Commit 29cffab

Browse files
committed
Improve outer method
1 parent 8d54d7c commit 29cffab

File tree

4 files changed

+96
-14
lines changed

4 files changed

+96
-14
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"""
4141

4242

43+
from dpnp.dpnp_array import dpnp_array
4344
from dpnp.dpnp_algo import *
4445
from dpnp.dpnp_utils import *
4546
import dpnp
@@ -312,10 +313,15 @@ def outer(x1, x2, **kwargs):
312313
313314
"""
314315

315-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
316-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
317-
if x1_desc and x2_desc and not kwargs:
318-
return dpnp_outer(x1_desc, x2_desc).get_pyobj()
316+
if not kwargs:
317+
if isinstance(x1, dpnp_array) and isinstance(x2, dpnp_array):
318+
ravel = lambda x: x.flatten() if x.ndim > 1 else x
319+
return ravel(x1)[:, None] * ravel(x2)[None, :]
320+
321+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
322+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
323+
if x1_desc and x2_desc:
324+
return dpnp_outer(x1_desc, x2_desc).get_pyobj()
319325

320326
return call_origin(numpy.outer, x1, x2, **kwargs)
321327

tests/skipped_tests_gpu.tbl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumError::test_too_fe
286286
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumError::test_too_many_dimension3
287287
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_dim_mismatch3
288288
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_too_many_dims3
289-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_outer
289+
290290
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_outer
291291
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot
292292
tests/third_party/cupy/manipulation_tests/test_basic.py::TestCopytoFromScalar_param_7_{dst_shape=(0,), src=3.2}::test_copyto_where
@@ -347,8 +347,6 @@ tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_2_{extern
347347
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCov::test_cov_empty
348348
tests/third_party/cupy/statistics_tests/test_meanvar.py::TestMeanVar::test_external_mean_axis
349349

350-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_outer
351-
352350
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_axis
353351
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_negative_axis
354352
tests/third_party/cupy/sorting_tests/test_sort.py::TestPartition_param_3_{external=True, length=20000}::test_partition_none_axis

tests/test_outer.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import unittest
2+
from tests.third_party.cupy import testing
3+
4+
import dpnp as dp
5+
import numpy as np
6+
7+
from numpy.testing import assert_raises
8+
9+
10+
class TestOuter(unittest.TestCase):
11+
12+
@testing.for_all_dtypes()
13+
@testing.numpy_cupy_allclose()
14+
def test_two_vectors(self, xp, dtype):
15+
a = xp.ones((10, ), dtype=dtype)
16+
b = xp.linspace(-2, 2, 5, dtype=dtype)
17+
18+
return xp.outer(a, b)
19+
20+
@testing.for_all_dtypes()
21+
@testing.numpy_cupy_allclose()
22+
def test_two_matrix(self, xp, dtype):
23+
a = xp.ones((10, 10, 10), dtype=dtype)
24+
b = xp.full(shape=(3, 7), fill_value=42, dtype=dtype)
25+
26+
return xp.outer(a, b)
27+
28+
@testing.for_all_dtypes()
29+
@testing.numpy_cupy_allclose()
30+
def test_the_same_vector(self, xp, dtype):
31+
a = xp.full(shape=(100, ), fill_value=7, dtype=dtype)
32+
return xp.outer(a, a)
33+
34+
@testing.for_all_dtypes()
35+
@testing.numpy_cupy_allclose()
36+
def test_the_same_matrix(self, xp, dtype):
37+
a = xp.arange(27, dtype=dtype).reshape(3, 3, 3)
38+
return xp.outer(a, a)
39+
40+
41+
class TestScalarOuter(unittest.TestCase):
42+
43+
@unittest.skip("A scalar isn't currently supported as input")
44+
@testing.for_all_dtypes()
45+
@testing.numpy_cupy_allclose()
46+
def test_first_is_scalar(self, xp, dtype):
47+
scalar = xp.int64(4)
48+
a = xp.arange(5**3, dtype=dtype).reshape(5, 5, 5)
49+
return xp.outer(scalar, a)
50+
51+
@unittest.skip("A scalar isn't currently supported as input")
52+
@testing.for_all_dtypes()
53+
@testing.numpy_cupy_allclose()
54+
def test_second_is_scalar(self, xp, dtype):
55+
scalar = xp.int32(7)
56+
a = xp.arange(5**3, dtype=dtype).reshape(5, 5, 5)
57+
return xp.outer(a, scalar)
58+
59+
@unittest.skip("A scalar isn't currently supported as input")
60+
@testing.numpy_cupy_array_equal()
61+
def test_both_inputs_as_scalar(self, xp):
62+
a = xp.int64(4)
63+
b = xp.int32(17)
64+
return xp.outer(a, b)
65+
66+
67+
class TestListOuter(unittest.TestCase):
68+
69+
def test_list(self):
70+
a = np.arange(27).reshape(3, 3, 3)
71+
b: list[list[list[int]]] = a.tolist()
72+
dp_a = dp.array(a)
73+
74+
with assert_raises(NotImplementedError):
75+
dp.outer(b, dp_a)
76+
dp.outer(dp_a, b)
77+
dp.outer(b, b)

tests/test_sycl_queue.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ def test_1in_1out(func, data, device):
180180
pytest.param("fmod",
181181
[-3., -2., -1., 1., 2., 3.],
182182
[2., 2., 2., 2., 2., 2.]),
183+
pytest.param("matmul",
184+
[[1., 0.], [0., 1.]],
185+
[[4., 1.], [1., 2.]]),
183186
pytest.param("maximum",
184187
[2., 3., 4.],
185188
[1., 5., 2.]),
@@ -189,6 +192,9 @@ def test_1in_1out(func, data, device):
189192
pytest.param("multiply",
190193
[0., 1., 2., 3., 4., 5., 6., 7., 8.],
191194
[0., 1., 2., 0., 1., 2., 0., 1., 2.]),
195+
pytest.param("outer",
196+
[0., 1., 2., 3., 4., 5.],
197+
[0., 1., 2., 0.]),
192198
pytest.param("power",
193199
[0., 1., 2., 3., 4., 5.],
194200
[1., 2., 3., 3., 2., 1.]),
@@ -198,9 +204,6 @@ def test_1in_1out(func, data, device):
198204
pytest.param("subtract",
199205
[0., 1., 2., 3., 4., 5., 6., 7., 8.],
200206
[0., 1., 2., 0., 1., 2., 0., 1., 2.]),
201-
pytest.param("matmul",
202-
[[1., 0.], [0., 1.]],
203-
[[4., 1.], [1., 2.]]),
204207
],
205208
)
206209
@pytest.mark.parametrize("device",
@@ -217,10 +220,8 @@ def test_2in_1out(func, data1, data2, device):
217220

218221
numpy.testing.assert_array_equal(result, expected)
219222

220-
expected_queue = x1.get_array().sycl_queue
221-
result_queue = result.get_array().sycl_queue
222-
223-
assert_sycl_queue_equal(result_queue, expected_queue)
223+
assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue)
224+
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)
224225

225226

226227
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)