Skip to content

Commit 439f2b5

Browse files
authored
Use tril() and triu() function from dpctl.tensor (#1286)
* Use tril() function from dpctl.tensor * Use triu() function from dpctl.tensor * Changed tests for tril() and triu() functions. * Skip tests for tril() and triu() functions with usm_type.
1 parent 0f7420e commit 439f2b5

File tree

9 files changed

+107
-168
lines changed

9 files changed

+107
-168
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,7 @@ enum class DPNPFuncName : size_t
370370
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
371371
DPNP_FN_TRI_EXT, /**< Used in numpy.tri() impl, requires extra parameters */
372372
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
373-
DPNP_FN_TRIL_EXT, /**< Used in numpy.tril() impl, requires extra parameters */
374373
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
375-
DPNP_FN_TRIU_EXT, /**< Used in numpy.triu() impl, requires extra parameters */
376374
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
377375
DPNP_FN_TRUNC_EXT, /**< Used in numpy.trunc() impl, requires extra parameters */
378376
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

-32
Original file line numberDiff line numberDiff line change
@@ -1055,17 +1055,6 @@ void (*dpnp_tril_default_c)(void*,
10551055
const size_t,
10561056
const size_t) = dpnp_tril_c<_DataType>;
10571057

1058-
template <typename _DataType>
1059-
DPCTLSyclEventRef (*dpnp_tril_ext_c)(DPCTLSyclQueueRef,
1060-
void*,
1061-
void*,
1062-
const int,
1063-
shape_elem_type*,
1064-
shape_elem_type*,
1065-
const size_t,
1066-
const size_t,
1067-
const DPCTLEventVectorRef) = dpnp_tril_c<_DataType>;
1068-
10691058
template <typename _DataType>
10701059
DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
10711060
void* array_in,
@@ -1218,17 +1207,6 @@ void (*dpnp_triu_default_c)(void*,
12181207
const size_t,
12191208
const size_t) = dpnp_triu_c<_DataType>;
12201209

1221-
template <typename _DataType>
1222-
DPCTLSyclEventRef (*dpnp_triu_ext_c)(DPCTLSyclQueueRef,
1223-
void*,
1224-
void*,
1225-
const int,
1226-
shape_elem_type*,
1227-
shape_elem_type*,
1228-
const size_t,
1229-
const size_t,
1230-
const DPCTLEventVectorRef) = dpnp_triu_c<_DataType>;
1231-
12321210
template <typename _DataType>
12331211
DPCTLSyclEventRef dpnp_zeros_c(DPCTLSyclQueueRef q_ref,
12341212
void* result,
@@ -1439,21 +1417,11 @@ void func_map_init_arraycreation(func_map_t& fmap)
14391417
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_default_c<float>};
14401418
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_default_c<double>};
14411419

1442-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tril_ext_c<int32_t>};
1443-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tril_ext_c<int64_t>};
1444-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_ext_c<float>};
1445-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_ext_c<double>};
1446-
14471420
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_triu_default_c<int32_t>};
14481421
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_triu_default_c<int64_t>};
14491422
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_triu_default_c<float>};
14501423
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_triu_default_c<double>};
14511424

1452-
fmap[DPNPFuncName::DPNP_FN_TRIU_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_triu_ext_c<int32_t>};
1453-
fmap[DPNPFuncName::DPNP_FN_TRIU_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_triu_ext_c<int64_t>};
1454-
fmap[DPNPFuncName::DPNP_FN_TRIU_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_triu_ext_c<float>};
1455-
fmap[DPNPFuncName::DPNP_FN_TRIU_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_triu_ext_c<double>};
1456-
14571425
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_zeros_default_c<int32_t>};
14581426
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_zeros_default_c<int64_t>};
14591427
fmap[DPNPFuncName::DPNP_FN_ZEROS][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_zeros_default_c<float>};

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

-90
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ __all__ += [
4545
"dpnp_ptp",
4646
"dpnp_trace",
4747
"dpnp_tri",
48-
"dpnp_tril",
49-
"dpnp_triu",
5048
"dpnp_vander",
5149
]
5250

@@ -426,94 +424,6 @@ cpdef utils.dpnp_descriptor dpnp_tri(N, M=None, k=0, dtype=dpnp.float):
426424
return result
427425

428426

429-
cpdef utils.dpnp_descriptor dpnp_tril(utils.dpnp_descriptor m, int k):
430-
cdef shape_type_c input_shape = m.shape
431-
cdef shape_type_c result_shape
432-
433-
if m.ndim == 1:
434-
result_shape = (m.shape[0], m.shape[0])
435-
else:
436-
result_shape = m.shape
437-
438-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(m.dtype)
439-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRIL_EXT, param1_type, param1_type)
440-
441-
m_obj = m.get_array()
442-
443-
# ceate result array with type given by FPTR data
444-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
445-
kernel_data.return_type,
446-
None,
447-
device=m_obj.sycl_device,
448-
usm_type=m_obj.usm_type,
449-
sycl_queue=m_obj.sycl_queue)
450-
451-
result_sycl_queue = result.get_array().sycl_queue
452-
453-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
454-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
455-
456-
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
457-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
458-
m.get_data(),
459-
result.get_data(),
460-
k,
461-
input_shape.data(),
462-
result_shape.data(),
463-
m.ndim,
464-
result.ndim,
465-
NULL) # dep_events_ref
466-
467-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
468-
c_dpctl.DPCTLEvent_Delete(event_ref)
469-
470-
return result
471-
472-
473-
cpdef utils.dpnp_descriptor dpnp_triu(utils.dpnp_descriptor m, int k):
474-
cdef shape_type_c input_shape = m.shape
475-
cdef shape_type_c result_shape
476-
477-
if m.ndim == 1:
478-
result_shape = (m.shape[0], m.shape[0])
479-
else:
480-
result_shape = m.shape
481-
482-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(m.dtype)
483-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRIU_EXT, param1_type, param1_type)
484-
485-
m_obj = m.get_array()
486-
487-
# ceate result array with type given by FPTR data
488-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
489-
kernel_data.return_type,
490-
None,
491-
device=m_obj.sycl_device,
492-
usm_type=m_obj.usm_type,
493-
sycl_queue=m_obj.sycl_queue)
494-
495-
result_sycl_queue = result.get_array().sycl_queue
496-
497-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
498-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
499-
500-
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
501-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
502-
m.get_data(),
503-
result.get_data(),
504-
k,
505-
input_shape.data(),
506-
result_shape.data(),
507-
m.ndim,
508-
result.ndim,
509-
NULL) # dep_events_ref
510-
511-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
512-
c_dpctl.DPCTLEvent_Delete(event_ref)
513-
514-
return result
515-
516-
517427
cpdef utils.dpnp_descriptor dpnp_vander(utils.dpnp_descriptor x1, int N, int increasing):
518428
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
519429
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_VANDER_EXT, param1_type, DPNP_FT_NONE)

dpnp/dpnp_container.py

+14
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
"eye",
4949
"full",
5050
"ones"
51+
"tril",
52+
"triu",
5153
"zeros",
5254
]
5355

@@ -200,6 +202,18 @@ def ones(shape,
200202
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
201203

202204

205+
def tril(x1, /, *, k=0):
206+
""""Creates `dpnp_array` as lower triangular part of an input array."""
207+
array_obj = dpt.tril(x1.get_array() if isinstance(x1, dpnp_array) else x1, k)
208+
return dpnp_array(array_obj.shape, buffer=array_obj, order="K")
209+
210+
211+
def triu(x1, /, *, k=0):
212+
""""Creates `dpnp_array` as upper triangular part of an input array."""
213+
array_obj = dpt.triu(x1.get_array() if isinstance(x1, dpnp_array) else x1, k)
214+
return dpnp_array(array_obj.shape, buffer=array_obj, order="K")
215+
216+
203217
def zeros(shape,
204218
*,
205219
dtype=None,

dpnp/dpnp_iface_arraycreation.py

+43-14
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import numpy
4444
import dpnp
45+
import operator
4546

4647
import dpnp.config as config
4748
from dpnp.dpnp_algo import *
@@ -1332,14 +1333,20 @@ def tri(N, M=None, k=0, dtype=dpnp.float, **kwargs):
13321333
return call_origin(numpy.tri, N, M, k, dtype, **kwargs)
13331334

13341335

1335-
def tril(x1, k=0):
1336+
def tril(x1, /, *, k=0):
13361337
"""
13371338
Lower triangle of an array.
13381339
13391340
Return a copy of an array with elements above the `k`-th diagonal zeroed.
13401341
13411342
For full documentation refer to :obj:`numpy.tril`.
13421343
1344+
Limitations
1345+
-----------
1346+
Parameter `x1` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray` with two or more dimensions.
1347+
Parameter `k` is supported only of integer data type.
1348+
Otherwise the function will be executed sequentially on CPU.
1349+
13431350
Examples
13441351
--------
13451352
>>> import dpnp as np
@@ -1351,17 +1358,25 @@ def tril(x1, k=0):
13511358
13521359
"""
13531360

1354-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
1355-
if x1_desc:
1356-
if not isinstance(k, int):
1357-
pass
1358-
else:
1359-
return dpnp_tril(x1_desc, k).get_pyobj()
1361+
_k = None
1362+
try:
1363+
_k = operator.index(k)
1364+
except TypeError:
1365+
pass
1366+
1367+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
1368+
pass
1369+
elif x1.ndim < 2:
1370+
pass
1371+
elif _k is None:
1372+
pass
1373+
else:
1374+
return dpnp_container.tril(x1, k=_k)
13601375

13611376
return call_origin(numpy.tril, x1, k)
13621377

13631378

1364-
def triu(x1, k=0):
1379+
def triu(x1, /, *, k=0):
13651380
"""
13661381
Upper triangle of an array.
13671382
@@ -1370,6 +1385,12 @@ def triu(x1, k=0):
13701385
13711386
For full documentation refer to :obj:`numpy.triu`.
13721387
1388+
Limitations
1389+
-----------
1390+
Parameter `x1` is supported as :class:`dpnp.dpnp_array` or :class:`dpctl.tensor.usm_ndarray` with two or more dimensions.
1391+
Parameter `k` is supported only of integer data type.
1392+
Otherwise the function will be executed sequentially on CPU.
1393+
13731394
Examples
13741395
--------
13751396
>>> import dpnp as np
@@ -1381,12 +1402,20 @@ def triu(x1, k=0):
13811402
13821403
"""
13831404

1384-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
1385-
if x1_desc:
1386-
if not isinstance(k, int):
1387-
pass
1388-
else:
1389-
return dpnp_triu(x1_desc, k).get_pyobj()
1405+
_k = None
1406+
try:
1407+
_k = operator.index(k)
1408+
except TypeError:
1409+
pass
1410+
1411+
if not isinstance(x1, (dpnp.ndarray, dpt.usm_ndarray)):
1412+
pass
1413+
elif x1.ndim < 2:
1414+
pass
1415+
elif _k is None:
1416+
pass
1417+
else:
1418+
return dpnp_container.triu(x1, k=_k)
13901419

13911420
return call_origin(numpy.triu, x1, k)
13921421

tests/test_arraycreation.py

+25-24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
import tempfile
18+
import operator
1819

1920

2021
@pytest.mark.parametrize("start",
@@ -258,48 +259,48 @@ def test_tri_default_dtype():
258259

259260

260261
@pytest.mark.parametrize("k",
261-
[-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6],
262-
ids=['-6', '-5', '-4', '-3', '-2', '-1', '0', '1', '2', '3', '4', '5', '6'])
262+
[-3, -2, -1, 0, 1, 2, 3, 4, 5,
263+
numpy.array(1), dpnp.array(2), dpt.asarray(3)],
264+
ids=['-3', '-2', '-1', '0', '1', '2', '3', '4', '5',
265+
'np.array(1)', 'dpnp.array(2)', 'dpt.asarray(3)'])
263266
@pytest.mark.parametrize("m",
264-
[[0, 1, 2, 3, 4],
265-
[1, 1, 1, 1, 1],
266-
[[0, 0], [0, 0]],
267+
[[[0, 0], [0, 0]],
267268
[[1, 2], [1, 2]],
268269
[[1, 2], [3, 4]],
269270
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
270271
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]],
271-
ids=['[0, 1, 2, 3, 4]',
272-
'[1, 1, 1, 1, 1]',
273-
'[[0, 0], [0, 0]]',
272+
ids=['[[0, 0], [0, 0]]',
274273
'[[1, 2], [1, 2]]',
275274
'[[1, 2], [3, 4]]',
276275
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
277276
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]'])
278-
def test_tril(m, k):
279-
a = numpy.array(m)
277+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
278+
def test_tril(m, k, dtype):
279+
a = numpy.array(m, dtype=dtype)
280280
ia = dpnp.array(a)
281-
expected = numpy.tril(a, k)
282-
result = dpnp.tril(ia, k)
281+
expected = numpy.tril(a, k=operator.index(k))
282+
result = dpnp.tril(ia, k=k)
283283
assert_array_equal(expected, result)
284284

285285

286286
@pytest.mark.parametrize("k",
287-
[-4, -3, -2, -1, 0, 1, 2, 3, 4],
288-
ids=['-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'])
287+
[-3, -2, -1, 0, 1, 2, 3, 4, 5,
288+
numpy.array(1), dpnp.array(2), dpt.asarray(3)],
289+
ids=['-3', '-2', '-1', '0', '1', '2', '3', '4', '5',
290+
'np.array(1)', 'dpnp.array(2)', 'dpt.asarray(3)'])
289291
@pytest.mark.parametrize("m",
290-
[[0, 1, 2, 3, 4],
291-
[[1, 2], [3, 4]],
292+
[[[1, 2], [3, 4]],
292293
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
293294
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]],
294-
ids=['[0, 1, 2, 3, 4]',
295-
'[[1, 2], [3, 4]]',
295+
ids=['[[1, 2], [3, 4]]',
296296
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
297297
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]'])
298-
def test_triu(m, k):
299-
a = numpy.array(m)
298+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
299+
def test_triu(m, k, dtype):
300+
a = numpy.array(m, dtype=dtype)
300301
ia = dpnp.array(a)
301-
expected = numpy.triu(a, k)
302-
result = dpnp.triu(ia, k)
302+
expected = numpy.triu(a, k=operator.index(k))
303+
result = dpnp.triu(ia, k=k)
303304
assert_array_equal(expected, result)
304305

305306

@@ -309,8 +310,8 @@ def test_triu(m, k):
309310
def test_triu_size_null(k):
310311
a = numpy.ones(shape=(1, 2, 0))
311312
ia = dpnp.array(a)
312-
expected = numpy.triu(a, k)
313-
result = dpnp.triu(ia, k)
313+
expected = numpy.triu(a, k=k)
314+
result = dpnp.triu(ia, k=k)
314315
assert_array_equal(expected, result)
315316

316317

0 commit comments

Comments
 (0)