Skip to content

Commit 05402b0

Browse files
committed
Use tril() function from dpctl.tensor
1 parent 351c6a6 commit 05402b0

10 files changed

+51
-84
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

-1
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,6 @@ enum class DPNPFuncName : size_t
370370
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
371371
DPNP_FN_TRI_EXT, /**< Used in numpy.tri() impl, requires extra parameters */
372372
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
373-
DPNP_FN_TRIL_EXT, /**< Used in numpy.tril() impl, requires extra parameters */
374373
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
375374
DPNP_FN_TRIU_EXT, /**< Used in numpy.triu() impl, requires extra parameters */
376375
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

-16
Original file line numberDiff line numberDiff line change
@@ -1055,17 +1055,6 @@ void (*dpnp_tril_default_c)(void*,
10551055
const size_t,
10561056
const size_t) = dpnp_tril_c<_DataType>;
10571057

1058-
template <typename _DataType>
1059-
DPCTLSyclEventRef (*dpnp_tril_ext_c)(DPCTLSyclQueueRef,
1060-
void*,
1061-
void*,
1062-
const int,
1063-
shape_elem_type*,
1064-
shape_elem_type*,
1065-
const size_t,
1066-
const size_t,
1067-
const DPCTLEventVectorRef) = dpnp_tril_c<_DataType>;
1068-
10691058
template <typename _DataType>
10701059
DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
10711060
void* array_in,
@@ -1439,11 +1428,6 @@ void func_map_init_arraycreation(func_map_t& fmap)
14391428
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_default_c<float>};
14401429
fmap[DPNPFuncName::DPNP_FN_TRIL][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_default_c<double>};
14411430

1442-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_tril_ext_c<int32_t>};
1443-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_tril_ext_c<int64_t>};
1444-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_tril_ext_c<float>};
1445-
fmap[DPNPFuncName::DPNP_FN_TRIL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_tril_ext_c<double>};
1446-
14471431
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_triu_default_c<int32_t>};
14481432
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_triu_default_c<int64_t>};
14491433
fmap[DPNPFuncName::DPNP_FN_TRIU][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_triu_default_c<float>};

dpnp/dpnp_algo/dpnp_algo_arraycreation.pyx

-45
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ __all__ += [
4545
"dpnp_ptp",
4646
"dpnp_trace",
4747
"dpnp_tri",
48-
"dpnp_tril",
4948
"dpnp_triu",
5049
"dpnp_vander",
5150
]
@@ -426,50 +425,6 @@ cpdef utils.dpnp_descriptor dpnp_tri(N, M=None, k=0, dtype=dpnp.float):
426425
return result
427426

428427

429-
cpdef utils.dpnp_descriptor dpnp_tril(utils.dpnp_descriptor m, int k):
430-
cdef shape_type_c input_shape = m.shape
431-
cdef shape_type_c result_shape
432-
433-
if m.ndim == 1:
434-
result_shape = (m.shape[0], m.shape[0])
435-
else:
436-
result_shape = m.shape
437-
438-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(m.dtype)
439-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRIL_EXT, param1_type, param1_type)
440-
441-
m_obj = m.get_array()
442-
443-
# ceate result array with type given by FPTR data
444-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
445-
kernel_data.return_type,
446-
None,
447-
device=m_obj.sycl_device,
448-
usm_type=m_obj.usm_type,
449-
sycl_queue=m_obj.sycl_queue)
450-
451-
result_sycl_queue = result.get_array().sycl_queue
452-
453-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
454-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
455-
456-
cdef custom_1in_1out_func_ptr_t func = <custom_1in_1out_func_ptr_t > kernel_data.ptr
457-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
458-
m.get_data(),
459-
result.get_data(),
460-
k,
461-
input_shape.data(),
462-
result_shape.data(),
463-
m.ndim,
464-
result.ndim,
465-
NULL) # dep_events_ref
466-
467-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
468-
c_dpctl.DPCTLEvent_Delete(event_ref)
469-
470-
return result
471-
472-
473428
cpdef utils.dpnp_descriptor dpnp_triu(utils.dpnp_descriptor m, int k):
474429
cdef shape_type_c input_shape = m.shape
475430
cdef shape_type_c result_shape

dpnp/dpnp_container.py

+7
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"eye",
4949
"full",
5050
"ones"
51+
"tril"
5152
"zeros",
5253
]
5354

@@ -200,6 +201,12 @@ def ones(shape,
200201
return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
201202

202203

204+
def tril(x1, /, *, k=0):
205+
""""Creates `dpnp_array` as lower triangle of an input array."""
206+
array_obj = dpt.tril(x1.get_array() if isinstance(x1, dpnp_array) else x1, k)
207+
return dpnp_array(array_obj.shape, buffer=array_obj)
208+
209+
203210
def zeros(shape,
204211
*,
205212
dtype=None,

dpnp/dpnp_iface.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
"dpnp_queue_is_cpu",
6767
"get_dpnp_descriptor",
6868
"get_include",
69-
"get_normalized_queue_device"
69+
"get_normalized_queue_device",
70+
"isarray"
7071
]
7172

7273
from dpnp import (
@@ -338,3 +339,13 @@ def get_normalized_queue_device(obj=None,
338339
if hasattr(dpt._device, 'normalize_queue_device'):
339340
return dpt._device.normalize_queue_device(sycl_queue=sycl_queue, device=device)
340341
return sycl_queue
342+
343+
344+
def isarray(obj):
345+
"""
346+
Return True if:
347+
`obj` has a supported array type
348+
Return False if:
349+
`obj` has an unsupported array type or other data type
350+
"""
351+
return isinstance(obj, (dpnp_array, dpt.usm_ndarray))

dpnp/dpnp_iface_arraycreation.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -1331,14 +1331,20 @@ def tri(N, M=None, k=0, dtype=dpnp.float, **kwargs):
13311331
return call_origin(numpy.tri, N, M, k, dtype, **kwargs)
13321332

13331333

1334-
def tril(x1, k=0):
1334+
def tril(x1, /, *, k=0):
13351335
"""
13361336
Lower triangle of an array.
13371337
13381338
Return a copy of an array with elements above the `k`-th diagonal zeroed.
13391339
13401340
For full documentation refer to :obj:`numpy.tril`.
13411341
1342+
Limitations
1343+
-----------
1344+
Parameter ``x1`` is supported only as :class:`dpnp.dpnp_array` with two or more dimensions.
1345+
Parameter ``k`` is supported only as int data type.
1346+
Otherwise the function will be executed sequentially on CPU.
1347+
13421348
Examples
13431349
--------
13441350
>>> import dpnp as np
@@ -1350,12 +1356,14 @@ def tril(x1, k=0):
13501356
13511357
"""
13521358

1353-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
1354-
if x1_desc:
1355-
if not isinstance(k, int):
1356-
pass
1357-
else:
1358-
return dpnp_tril(x1_desc, k).get_pyobj()
1359+
if not dpnp.isarray(x1):
1360+
pass
1361+
elif x1.ndim < 2:
1362+
pass
1363+
elif not isinstance(k, int):
1364+
pass
1365+
else:
1366+
return dpnp_container.tril(x1, k=k)
13591367

13601368
return call_origin(numpy.tril, x1, k)
13611369

tests/skipped_tests.tbl

+3
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_
430430
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_mixed_start_stop
431431
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_mixed_start_stop2
432432
tests/third_party/cupy/creation_tests/test_ranges.py::TestRanges::test_linspace_start_stop_list
433+
tests/third_party/cupy/creation_tests/test_matrix.py::TestTriLowerAndUpper_param_0_{shape=(2,)}::test_tril
434+
tests/third_party/cupy/creation_tests/test_matrix.py::TestTriLowerAndUpper_param_0_{shape=(2,)}::test_tril_nega
435+
tests/third_party/cupy/creation_tests/test_matrix.py::TestTriLowerAndUpper_param_0_{shape=(2,)}::test_tril_posi
433436
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_AxisConcatenator_init1
434437
tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::test_len
435438
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1

tests/skipped_tests_gpu.tbl

+3
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_cons
190190
tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_construction_from_tuple
191191
tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_extraction_from_nested_list
192192
tests/third_party/cupy/creation_tests/test_matrix.py::TestMatrix::test_diag_extraction_from_nested_tuple
193+
tests/third_party/cupy/creation_tests/test_matrix.py::TestTriLowerAndUpper_param_0_{shape=(2,)}::test_tril
194+
tests/third_party/cupy/creation_tests/test_matrix.py::TestTriLowerAndUpper_param_0_{shape=(2,)}::test_tril_nega
195+
tests/third_party/cupy/creation_tests/test_matrix.py::TestTriLowerAndUpper_param_0_{shape=(2,)}::test_tril_posi
193196

194197
tests/third_party/cupy/indexing_tests/test_insert.py::TestFillDiagonal_param_4_{shape=(3, 3), val=(2,), wrap=True}::test_1darray
195198
tests/third_party/cupy/indexing_tests/test_insert.py::TestFillDiagonal_param_4_{shape=(3, 3), val=(2,), wrap=True}::test_fill_diagonal

tests/test_arraycreation.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -258,28 +258,25 @@ def test_tri_default_dtype():
258258

259259

260260
@pytest.mark.parametrize("k",
261-
[-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6],
262-
ids=['-6', '-5', '-4', '-3', '-2', '-1', '0', '1', '2', '3', '4', '5', '6'])
261+
[-3, -2, -1, 0, 1, 2, 3, 4, 5],
262+
ids=['-3', '-2', '-1', '0', '1', '2', '3', '4', '5'])
263263
@pytest.mark.parametrize("m",
264-
[[0, 1, 2, 3, 4],
265-
[1, 1, 1, 1, 1],
266-
[[0, 0], [0, 0]],
264+
[[[0, 0], [0, 0]],
267265
[[1, 2], [1, 2]],
268266
[[1, 2], [3, 4]],
269267
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
270268
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]],
271-
ids=['[0, 1, 2, 3, 4]',
272-
'[1, 1, 1, 1, 1]',
273-
'[[0, 0], [0, 0]]',
269+
ids=['[[0, 0], [0, 0]]',
274270
'[[1, 2], [1, 2]]',
275271
'[[1, 2], [3, 4]]',
276272
'[[0, 1, 2], [3, 4, 5], [6, 7, 8]]',
277273
'[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]'])
278-
def test_tril(m, k):
279-
a = numpy.array(m)
274+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False))
275+
def test_tril(m, k, dtype):
276+
a = numpy.array(m, dtype=dtype)
280277
ia = dpnp.array(a)
281-
expected = numpy.tril(a, k)
282-
result = dpnp.tril(ia, k)
278+
expected = numpy.tril(a, k=k)
279+
result = dpnp.tril(ia, k=k)
283280
assert_array_equal(expected, result)
284281

285282

tests/third_party/cupy/creation_tests/test_matrix.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ def test_tril_array_like(self, xp):
157157
@testing.numpy_cupy_array_equal()
158158
def test_tril_nega(self, xp, dtype):
159159
m = testing.shaped_arange(self.shape, xp, dtype)
160-
return xp.tril(m, -1)
160+
return xp.tril(m, k=-1)
161161

162162
@testing.for_all_dtypes(no_complex=True)
163163
@testing.numpy_cupy_array_equal()
164164
def test_tril_posi(self, xp, dtype):
165165
m = testing.shaped_arange(self.shape, xp, dtype)
166-
return xp.tril(m, 1)
166+
return xp.tril(m, k=1)
167167

168168
@testing.for_all_dtypes(no_complex=True)
169169
@testing.numpy_cupy_array_equal()

0 commit comments

Comments
 (0)