Skip to content

Commit 57e7359

Browse files
Leverage dpctl.tensor.expand_dims()/swapaxes() implementation (#1532)
* Leverage dpctl.tensor.expand_dims impl * Leverage dpctl.tensor.swapaxes impl * Align args names to numpy * Remove call_origin for dpnp.moveaxis * Remove call_origin for dpnp.squeeze
1 parent 79ac2b0 commit 57e7359

File tree

8 files changed

+131
-241
lines changed

8 files changed

+131
-241
lines changed

dpnp/dpnp_algo/dpnp_algo_manipulation.pxi

-45
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ and the rest of the library
3838
__all__ += [
3939
"dpnp_atleast_2d",
4040
"dpnp_atleast_3d",
41-
"dpnp_expand_dims",
4241
"dpnp_repeat",
43-
"dpnp_reshape",
4442
]
4543

4644

@@ -104,35 +102,6 @@ cpdef utils.dpnp_descriptor dpnp_atleast_3d(utils.dpnp_descriptor arr):
104102
return arr
105103

106104

107-
cpdef utils.dpnp_descriptor dpnp_expand_dims(utils.dpnp_descriptor in_array, axis):
108-
axis_tuple = utils._object_to_tuple(axis)
109-
result_ndim = len(axis_tuple) + in_array.ndim
110-
111-
if len(axis_tuple) == 0:
112-
axis_ndim = 0
113-
else:
114-
axis_ndim = max(-min(0, min(axis_tuple)), max(0, max(axis_tuple))) + 1
115-
116-
axis_norm = utils._object_to_tuple(utils.normalize_axis(axis_tuple, result_ndim))
117-
118-
if axis_ndim - len(axis_norm) > in_array.ndim:
119-
utils.checker_throw_axis_error("dpnp_expand_dims", "axis", axis, axis_ndim)
120-
121-
if len(axis_norm) > len(set(axis_norm)):
122-
utils.checker_throw_value_error("dpnp_expand_dims", "axis", axis, "no repeated axis")
123-
124-
cdef shape_type_c shape_list
125-
axis_idx = 0
126-
for i in range(result_ndim):
127-
if i in axis_norm:
128-
shape_list.push_back(1)
129-
else:
130-
shape_list.push_back(in_array.shape[axis_idx])
131-
axis_idx = axis_idx + 1
132-
133-
return dpnp_reshape(in_array, shape_list)
134-
135-
136105
cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, axes=None):
137106
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
138107

@@ -165,17 +134,3 @@ cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, a
165134
c_dpctl.DPCTLEvent_Delete(event_ref)
166135

167136
return result
168-
169-
170-
cpdef utils.dpnp_descriptor dpnp_reshape(utils.dpnp_descriptor array1, newshape, order="C"):
171-
# return dpnp.get_dpnp_descriptor(dpctl.tensor.usm_ndarray(newshape, dtype=numpy.dtype(array1.dtype).name, buffer=array1.get_pyobj()))
172-
# return dpnp.get_dpnp_descriptor(dpctl.tensor.reshape(array1.get_pyobj(), newshape))
173-
array1_obj = array1.get_array()
174-
array_obj = dpctl.tensor.reshape(array1_obj, newshape, order=order)
175-
return dpnp.get_dpnp_descriptor(dpnp_array(array_obj.shape,
176-
buffer=array_obj,
177-
order=order,
178-
device=array1_obj.sycl_device,
179-
usm_type=array1_obj.usm_type,
180-
sycl_queue=array1_obj.sycl_queue),
181-
copy_when_nondefault_queue=False)

dpnp/dpnp_array.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,14 @@ def sum(
10591059
where=where,
10601060
)
10611061

1062-
# 'swapaxes',
1062+
def swapaxes(self, axis1, axis2):
1063+
"""
1064+
Interchange two axes of an array.
1065+
1066+
For full documentation refer to :obj:`numpy.swapaxes`.
1067+
"""
1068+
1069+
return dpnp.swapaxes(self, axis1=axis1, axis2=axis2)
10631070

10641071
def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
10651072
"""

0 commit comments

Comments
 (0)