Skip to content

Commit 44fddf7

Browse files
authored
Merge pull request #1530 from IntelPython/dpnp_use_inplace_dpctl
utilize new functionality of dpctl for in-place operators
2 parents 57e7359 + 6c00c20 commit 44fddf7

6 files changed

+78
-42
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

-23
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
# *****************************************************************************
2828

2929

30-
import dpctl
31-
import dpctl.tensor as dpt
3230
import dpctl.tensor._tensor_impl as ti
3331
from dpctl.tensor._elementwise_common import (
3432
BinaryElementwiseFunc,
@@ -538,32 +536,11 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=None):
538536
return ti._divide(src1, src2, dst, sycl_queue, depends)
539537

540538

541-
def _call_divide_inplace(lhs, rhs, sycl_queue, depends=None):
542-
"""In place workaround until dpctl.tensor provides the functionality."""
543-
544-
if depends is None:
545-
depends = []
546-
547-
# allocate temporary memory for out array
548-
out = dpt.empty_like(lhs, dtype=dpnp.result_type(lhs.dtype, rhs.dtype))
549-
550-
# call a general callback
551-
div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends)
552-
553-
# store the result into left input array and return events
554-
cp_ht_, cp_ev_ = ti._copy_usm_ndarray_into_usm_ndarray(
555-
src=out, dst=lhs, sycl_queue=sycl_queue, depends=[div_ev_]
556-
)
557-
dpctl.SyclEvent.wait_for([div_ht_])
558-
return (cp_ht_, cp_ev_)
559-
560-
561539
divide_func = BinaryElementwiseFunc(
562540
"divide",
563541
ti._divide_result_type,
564542
_call_divide,
565543
_divide_docstring_,
566-
_call_divide_inplace,
567544
)
568545

569546

dpnp/dpnp_array.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,10 @@ def __eq__(self, other):
188188
def __float__(self):
189189
return self._array_obj.__float__()
190190

191-
# '__floordiv__',
191+
def __floordiv__(self, other):
192+
"""Return self//value."""
193+
return dpnp.floor_divide(self, other)
194+
192195
# '__format__',
193196

194197
def __ge__(self, other):
@@ -227,15 +230,22 @@ def __iand__(self, other):
227230
dpnp.bitwise_and(self, other, out=self)
228231
return self
229232

230-
# '__ifloordiv__',
233+
def __ifloordiv__(self, other):
234+
"""Return self//=value."""
235+
dpnp.floor_divide(self, other, out=self)
236+
return self
231237

232238
def __ilshift__(self, other):
233239
"""Return self<<=value."""
234240
dpnp.left_shift(self, other, out=self)
235241
return self
236242

237243
# '__imatmul__',
238-
# '__imod__',
244+
245+
def __imod__(self, other):
246+
"""Return self%=value."""
247+
dpnp.remainder(self, other, out=self)
248+
return self
239249

240250
def __imul__(self, other):
241251
"""Return self*=value."""
@@ -345,7 +355,8 @@ def __rand__(self, other):
345355
def __repr__(self):
346356
return dpt.usm_ndarray_repr(self._array_obj, prefix="array")
347357

348-
# '__rfloordiv__',
358+
def __rfloordiv__(self, other):
359+
return dpnp.floor_divide(self, other)
349360

350361
def __rlshift__(self, other):
351362
return dpnp.left_shift(other, self)

dpnp/dpnp_iface_mathematical.py

+8
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,14 @@ def floor_divide(
901901
902902
>>> np.floor_divide(np.array([1., 2., 3., 4.]), 2.5)
903903
array([ 0., 0., 1., 1.])
904+
905+
The ``//`` operator can be used as a shorthand for ``floor_divide`` on
906+
:class:`dpnp.ndarray`.
907+
908+
>>> x1 = np.array([1., 2., 3., 4.])
909+
>>> x1 // 2.5
910+
array([0., 0., 1., 1.])
911+
904912
"""
905913

906914
return check_nd_call_func(

tests/test_bitwise.py

-15
Original file line numberDiff line numberDiff line change
@@ -67,50 +67,41 @@ def test_bitwise_and(self, lhs, rhs, dtype):
6767
)
6868
assert_array_equal(dp_a & dp_b, np_a & np_b)
6969

70-
"""
71-
TODO: unmute once dpctl support that
7270
if (
7371
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
7472
and dp_a.shape == dp_b.shape
7573
):
7674
dp_a &= dp_b
7775
np_a &= np_b
7876
assert_array_equal(dp_a, np_a)
79-
"""
8077

8178
def test_bitwise_or(self, lhs, rhs, dtype):
8279
dp_a, dp_b, np_a, np_b = self._test_binary_int(
8380
"bitwise_or", lhs, rhs, dtype
8481
)
8582
assert_array_equal(dp_a | dp_b, np_a | np_b)
8683

87-
"""
88-
TODO: unmute once dpctl support that
8984
if (
9085
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
9186
and dp_a.shape == dp_b.shape
9287
):
9388
dp_a |= dp_b
9489
np_a |= np_b
9590
assert_array_equal(dp_a, np_a)
96-
"""
9791

9892
def test_bitwise_xor(self, lhs, rhs, dtype):
9993
dp_a, dp_b, np_a, np_b = self._test_binary_int(
10094
"bitwise_xor", lhs, rhs, dtype
10195
)
10296
assert_array_equal(dp_a ^ dp_b, np_a ^ np_b)
10397

104-
"""
105-
TODO: unmute once dpctl support that
10698
if (
10799
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
108100
and dp_a.shape == dp_b.shape
109101
):
110102
dp_a ^= dp_b
111103
np_a ^= np_b
112104
assert_array_equal(dp_a, np_a)
113-
"""
114105

115106
def test_invert(self, lhs, rhs, dtype):
116107
dp_a, np_a = self._test_unary_int("invert", lhs, dtype)
@@ -122,30 +113,24 @@ def test_left_shift(self, lhs, rhs, dtype):
122113
)
123114
assert_array_equal(dp_a << dp_b, np_a << np_b)
124115

125-
"""
126-
TODO: unmute once dpctl support that
127116
if (
128117
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
129118
and dp_a.shape == dp_b.shape
130119
):
131120
dp_a <<= dp_b
132121
np_a <<= np_b
133122
assert_array_equal(dp_a, np_a)
134-
"""
135123

136124
def test_right_shift(self, lhs, rhs, dtype):
137125
dp_a, dp_b, np_a, np_b = self._test_binary_int(
138126
"right_shift", lhs, rhs, dtype
139127
)
140128
assert_array_equal(dp_a >> dp_b, np_a >> np_b)
141129

142-
"""
143-
TODO: unmute once dpctl support that
144130
if (
145131
not (inp.isscalar(dp_a) or inp.isscalar(dp_b))
146132
and dp_a.shape == dp_b.shape
147133
):
148134
dp_a >>= dp_b
149135
np_a >>= np_b
150136
assert_array_equal(dp_a, np_a)
151-
"""

tests/test_mathematical.py

+28
Original file line numberDiff line numberDiff line change
@@ -1177,3 +1177,31 @@ def test_mean_scalar(self):
11771177
result = dp_array.mean()
11781178
expected = np_array.mean()
11791179
assert_allclose(expected, result)
1180+
1181+
1182+
@pytest.mark.parametrize(
1183+
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)
1184+
)
1185+
def test_inplace_remainder(dtype):
1186+
size = 21
1187+
np_a = numpy.arange(size, dtype=dtype)
1188+
dp_a = dpnp.arange(size, dtype=dtype)
1189+
1190+
np_a %= 4
1191+
dp_a %= 4
1192+
1193+
assert_allclose(dp_a, np_a)
1194+
1195+
1196+
@pytest.mark.parametrize(
1197+
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)
1198+
)
1199+
def test_inplace_floor_divide(dtype):
1200+
size = 21
1201+
np_a = numpy.arange(size, dtype=dtype)
1202+
dp_a = dpnp.arange(size, dtype=dtype)
1203+
1204+
np_a //= 4
1205+
dp_a //= 4
1206+
1207+
assert_allclose(dp_a, np_a)

tests/test_usm_type.py

+27
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,35 @@ def test_coerced_usm_types_remainder(usm_type_x, usm_type_y):
8383
y = dp.arange(100, usm_type=usm_type_y).reshape(10, 10)
8484
y = y.T + 1
8585

86+
z = 100 % y
87+
z = y % 7
8688
z = x % y
8789

90+
# inplace remainder
91+
z %= y
92+
z %= 5
93+
94+
assert x.usm_type == usm_type_x
95+
assert y.usm_type == usm_type_y
96+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
97+
98+
99+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
100+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
101+
def test_coerced_usm_types_floor_divide(usm_type_x, usm_type_y):
102+
x = dp.arange(100, usm_type=usm_type_x).reshape(10, 10)
103+
y = dp.arange(100, usm_type=usm_type_y).reshape(10, 10)
104+
x = x + 1.5
105+
y = y.T + 0.5
106+
107+
z = 3.4 // y
108+
z = y // 2.7
109+
z = x // y
110+
111+
# inplace floor_divide
112+
z //= y
113+
z //= 2.5
114+
88115
assert x.usm_type == usm_type_x
89116
assert y.usm_type == usm_type_y
90117
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])

0 commit comments

Comments
 (0)