Skip to content

Commit 0a45a54

Browse files
authored
Leveraged dpctl.tensor.stack() implementation (#1509)
* Leveraged dpctl.tensor.stack() implementation * Relaxed check in a test of SYCL queue to account the error of floating operations
1 parent 6538397 commit 0a45a54

File tree

6 files changed

+338
-89
lines changed

6 files changed

+338
-89
lines changed

.github/workflows/conda-package.yml

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ env:
1212
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'
1313
TEST_SCOPE: >-
1414
test_arraycreation.py
15+
test_arraymanipulation.py
1516
test_dot.py
1617
test_dparray.py
1718
test_fft.py
@@ -23,6 +24,7 @@ env:
2324
test_umath.py
2425
test_usm_type.py
2526
third_party/cupy/linalg_tests/test_product.py
27+
third_party/cupy/manipulation_tests/test_join.py
2628
third_party/cupy/math_tests/test_explog.py
2729
third_party/cupy/math_tests/test_misc.py
2830
third_party/cupy/math_tests/test_trigonometric.py

dpnp/dpnp_iface_manipulation.py

+66-5
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def broadcast_to(x, /, shape, subok=False):
237237
return call_origin(numpy.broadcast_to, x, shape=shape, subok=subok)
238238

239239

240-
def concatenate(arrays, *, axis=0, out=None, dtype=None, **kwargs):
240+
def concatenate(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
241241
"""
242242
Join a sequence of arrays along an existing axis.
243243
@@ -253,8 +253,7 @@ def concatenate(arrays, *, axis=0, out=None, dtype=None, **kwargs):
253253
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
254254
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exeption
255255
will be raised.
256-
Parameter `out` is supported with default value.
257-
Parameter `dtype` is supported with default value.
256+
Parameters `out` and `dtype are supported with default value.
258257
Keyword argument ``kwargs`` is currently unsupported.
259258
Otherwise the function will be executed sequentially on CPU.
260259
@@ -834,15 +833,77 @@ def squeeze(x, /, axis=None):
834833
return call_origin(numpy.squeeze, x, axis)
835834

836835

837-
def stack(arrays, axis=0, out=None):
836+
def stack(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
838837
"""
839838
Join a sequence of arrays along a new axis.
840839
841840
For full documentation refer to :obj:`numpy.stack`.
842841
842+
Returns
843+
-------
844+
out : dpnp.ndarray
845+
The stacked array which has one more dimension than the input arrays.
846+
847+
Limitations
848+
-----------
849+
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
850+
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exeption
851+
will be raised.
852+
Parameters `out` and `dtype are supported with default value.
853+
Keyword argument ``kwargs`` is currently unsupported.
854+
Otherwise the function will be executed sequentially on CPU.
855+
856+
See Also
857+
--------
858+
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
859+
:obj:`dpnp.block` : Assemble an nd-array from nested lists of blocks.
860+
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal size.
861+
862+
Examples
863+
--------
864+
>>> import dpnp as np
865+
>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
866+
>>> np.stack(arrays, axis=0).shape
867+
(10, 3, 4)
868+
869+
>>> np.stack(arrays, axis=1).shape
870+
(3, 10, 4)
871+
872+
>>> np.stack(arrays, axis=2).shape
873+
(3, 4, 10)
874+
875+
>>> a = np.array([1, 2, 3])
876+
>>> b = np.array([4, 5, 6])
877+
>>> np.stack((a, b))
878+
array([[1, 2, 3],
879+
[4, 5, 6]])
880+
881+
>>> np.stack((a, b), axis=-1)
882+
array([[1, 4],
883+
[2, 5],
884+
[3, 6]])
885+
843886
"""
844887

845-
return call_origin(numpy.stack, arrays, axis, out)
888+
if kwargs:
889+
pass
890+
elif out is not None:
891+
pass
892+
elif dtype is not None:
893+
pass
894+
else:
895+
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
896+
usm_res = dpt.stack(usm_arrays, axis=axis)
897+
return dpnp_array._create_from_usm_ndarray(usm_res)
898+
899+
return call_origin(
900+
numpy.stack,
901+
arrays,
902+
axis=axis,
903+
out=out,
904+
dtype=dtype,
905+
**kwargs,
906+
)
846907

847908

848909
def swapaxes(x1, axis1, axis2):

tests/conftest.py

+8
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ def allow_fall_back_on_numpy(monkeypatch):
8383
)
8484

8585

86+
@pytest.fixture
87+
def suppress_complex_warning():
88+
sup = numpy.testing.suppress_warnings("always")
89+
sup.filter(numpy.ComplexWarning)
90+
with sup:
91+
yield
92+
93+
8694
@pytest.fixture
8795
def suppress_divide_numpy_warnings():
8896
# divide: treatment for division by zero (infinite result obtained from finite numbers)

0 commit comments

Comments
 (0)