From 2a4e201d625d4d2c71925efe2fdd9f08e513ed61 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 23 Feb 2024 07:33:39 -0600 Subject: [PATCH 1/2] update dpnp.kron --- dpnp/backend/include/dpnp_iface_fptr.hpp | 2 - dpnp/backend/kernels/dpnp_krnl_linalg.cpp | 73 ------------- dpnp/dpnp_algo/CMakeLists.txt | 1 - dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/dpnp_algo/dpnp_algo.pyx | 1 - dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi | 106 ------------------ dpnp/dpnp_iface_linearalgebra.py | 68 ++++++++++-- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 30 +++++- tests/test_product.py | 113 ++++++++++++++++++++ tests/test_sycl_queue.py | 1 + tests/test_usm_type.py | 1 + 11 files changed, 202 insertions(+), 196 deletions(-) delete mode 100644 dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 9660d290e5cf..4032b0de5a99 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -174,8 +174,6 @@ enum class DPNPFuncName : size_t DPNP_FN_INV, /**< Used in numpy.linalg.inv() impl */ DPNP_FN_INVERT, /**< Used in numpy.invert() impl */ DPNP_FN_KRON, /**< Used in numpy.kron() impl */ - DPNP_FN_KRON_EXT, /**< Used in numpy.kron() impl, requires extra parameters - */ DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */ DPNP_FN_LOG, /**< Used in numpy.log() impl */ DPNP_FN_LOG10, /**< Used in numpy.log10() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp index 8f70ddd01e33..1dc2783d48cc 100644 --- a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp @@ -499,18 +499,6 @@ void (*dpnp_kron_default_c)(void *, size_t) = dpnp_kron_c<_DataType1, _DataType2, _ResultType>; -template -DPCTLSyclEventRef (*dpnp_kron_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - void *, - shape_elem_type *, - shape_elem_type *, - shape_elem_type *, - size_t, - const DPCTLEventVectorRef) = - dpnp_kron_c<_DataType1, _DataType2, _ResultType>; - template DPCTLSyclEventRef dpnp_matrix_rank_c(DPCTLSyclQueueRef q_ref, @@ -890,67 +878,6 @@ void func_map_init_linalg_func(func_map_t &fmap) (void *)dpnp_kron_default_c, std::complex, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_LNG] = { - eft_LNG, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_FLT] = { - eft_FLT, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_DBL] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - // fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_C128] = { - // eft_C128, (void*)dpnp_kron_ext_c, - // std::complex>}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_INT] = { - eft_LNG, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_FLT] = { - eft_FLT, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_DBL] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - // fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_C128] = { - // eft_C128, (void*)dpnp_kron_ext_c, - // std::complex>}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_INT] = { - eft_FLT, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_LNG] = { - eft_FLT, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_DBL] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - // fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_C128] = { - // eft_C128, (void*)dpnp_kron_ext_c, - // std::complex>}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_INT] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_LNG] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_FLT] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_kron_ext_c}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_C128] = { - eft_C128, (void *)dpnp_kron_ext_c, - std::complex>}; - // fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_INT] = { - // eft_C128, (void*)dpnp_kron_ext_c, int32_t, - // std::complex>}; - // fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_LNG] = { - // eft_C128, (void*)dpnp_kron_ext_c, int64_t, - // std::complex>}; - // fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_FLT] = { - // eft_C128, (void*)dpnp_kron_ext_c, float, - // std::complex>}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_DBL] = { - eft_C128, (void *)dpnp_kron_ext_c, double, - std::complex>}; - fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_C128] = { - eft_C128, - (void *)dpnp_kron_ext_c, std::complex, - std::complex>}; - fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_INT][eft_INT] = { eft_INT, (void *)dpnp_matrix_rank_default_c}; fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_LNG][eft_LNG] = { diff --git a/dpnp/dpnp_algo/CMakeLists.txt b/dpnp/dpnp_algo/CMakeLists.txt index 442fa8e82b16..2c3a49c6be4b 100644 --- a/dpnp/dpnp_algo/CMakeLists.txt +++ b/dpnp/dpnp_algo/CMakeLists.txt @@ -1,6 +1,5 @@ set(dpnp_algo_pyx_deps - ${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_linearalgebra.pxi ${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_statistics.pxi ${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_trigonometric.pxi ${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_sorting.pxi diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 67db6a07f756..528b03e3b583 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -74,8 +74,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_FMOD_EXT DPNP_FN_FULL DPNP_FN_FULL_LIKE - DPNP_FN_KRON - DPNP_FN_KRON_EXT DPNP_FN_MAXIMUM DPNP_FN_MAXIMUM_EXT DPNP_FN_MEDIAN diff --git a/dpnp/dpnp_algo/dpnp_algo.pyx b/dpnp/dpnp_algo/dpnp_algo.pyx index 9cbdaf3f1df3..3013cd76094f 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pyx +++ b/dpnp/dpnp_algo/dpnp_algo.pyx @@ -60,7 +60,6 @@ __all__ = [ include "dpnp_algo_arraycreation.pxi" include "dpnp_algo_indexing.pxi" -include "dpnp_algo_linearalgebra.pxi" include "dpnp_algo_logic.pxi" include "dpnp_algo_mathematical.pxi" include "dpnp_algo_sorting.pxi" diff --git a/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi b/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi deleted file mode 100644 index 6585645c45a9..000000000000 --- a/dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi +++ /dev/null @@ -1,106 +0,0 @@ -# cython: language_level=3 -# cython: linetrace=True -# -*- coding: utf-8 -*- -# ***************************************************************************** -# Copyright (c) 2016-2024, Intel Corporation -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# - Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -# THE POSSIBILITY OF SUCH DAMAGE. -# ***************************************************************************** - -"""Module Backend (linear algebra routines) - -This module contains interface functions between C backend layer -and the rest of the library - -""" - -# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file - -__all__ += [ - "dpnp_kron", -] - - -# C function pointer to the C library template functions -ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_shapes_t)(c_dpctl.DPCTLSyclQueueRef, - void *, void * , void * , shape_elem_type * , - shape_elem_type *, shape_elem_type * , size_t, - const c_dpctl.DPCTLEventVectorRef) - - -cpdef utils.dpnp_descriptor dpnp_kron(dpnp_descriptor in_array1, dpnp_descriptor in_array2): - cdef size_t ndim = max(in_array1.ndim, in_array2.ndim) - - cdef shape_type_c in_array1_shape - if in_array1.ndim < ndim: - for i in range(ndim - in_array1.ndim): - in_array1_shape.push_back(1) - for i in range(in_array1.ndim): - in_array1_shape.push_back(in_array1.shape[i]) - - cdef shape_type_c in_array2_shape - if in_array2.ndim < ndim: - for i in range(ndim - in_array2.ndim): - in_array2_shape.push_back(1) - for i in range(in_array2.ndim): - in_array2_shape.push_back(in_array2.shape[i]) - - cdef shape_type_c result_shape - for i in range(ndim): - result_shape.push_back(in_array1_shape[i] * in_array2_shape[i]) - - # convert string type names (array.dtype) to C enum DPNPFuncType - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype) - cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype) - - # get the FPTR data structure - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_KRON_EXT, param1_type, param2_type) - - result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(in_array1, in_array2) - - # create result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, - kernel_data.return_type, - None, - device=result_sycl_device, - usm_type=result_usm_type, - sycl_queue=result_sycl_queue) - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef fptr_2in_1out_shapes_t func = kernel_data.ptr - # call FPTR function - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - in_array1.get_data(), - in_array2.get_data(), - result.get_data(), - in_array1_shape.data(), - in_array2_shape.data(), - result_shape.data(), - ndim, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py index 312a101524b8..4f8f33015a0b 100644 --- a/dpnp/dpnp_iface_linearalgebra.py +++ b/dpnp/dpnp_iface_linearalgebra.py @@ -44,14 +44,12 @@ import dpnp # pylint: disable=no-name-in-module -from .dpnp_algo import ( - dpnp_kron, -) from .dpnp_utils import ( call_origin, ) from .dpnp_utils.dpnp_utils_linearalgebra import ( dpnp_dot, + dpnp_kron, dpnp_matmul, ) @@ -305,22 +303,72 @@ def inner(a, b): return dpnp.tensordot(a, b, axes=(-1, -1)) -def kron(x1, x2): +def kron(a, b): """ Returns the kronecker product of two arrays. For full documentation refer to :obj:`numpy.kron`. - .. seealso:: :obj:`dpnp.outer` returns the outer product of two arrays. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray, scalar} + First input array. Both inputs `a` and `b` can not be scalars + at the same time. + b : {dpnp.ndarray, usm_ndarray, scalar} + Second input array. Both inputs `a` and `b` can not be scalars + at the same time. + + Returns + ------- + out : dpnp.ndarray + Returns the Kronecker product. + + See Also + -------- + :obj:`dpnp.outer` : Returns the outer product of two arrays. + + Examples + -------- + >>> import dpnp as np + >>> a = np.array([1, 10, 100]) + >>> b = np.array([5, 6, 7]) + >>> np.kron(a, b) + array([ 5, 6, 7, ..., 500, 600, 700]) + >>> np.kron(b, a) + array([ 5, 50, 500, ..., 7, 70, 700]) + + >>> np.kron(np.eye(2), np.ones((2,2))) + array([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> a = np.arange(100).reshape((2,5,2,5)) + >>> b = np.arange(24).reshape((2,3,4)) + >>> c = np.kron(a,b) + >>> c.shape + (2, 10, 6, 20) + >>> I = (1,3,0,2) + >>> J = (0,2,1) + >>> J1 = (0,) + J # extend to ndim=4 + >>> S1 = (1,) + b.shape + >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1)) + >>> c[K] == a[I]*b[J] + array(True) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) - if x1_desc and x2_desc: - return dpnp_kron(x1_desc, x2_desc).get_pyobj() + dpnp.check_supported_arrays_type(a, b, scalar_type=True) + + if dpnp.isscalar(a) or dpnp.isscalar(b): + return dpnp.multiply(a, b) + + a_ndim = a.ndim + b_ndim = b.ndim + if a_ndim == 0 or b_ndim == 0: + return dpnp.multiply(a, b) - return call_origin(numpy.kron, x1, x2) + return dpnp_kron(a, b, a_ndim, b_ndim) def matmul( diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index ff7b7a05972f..080c3c52b3c8 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -35,7 +35,7 @@ from dpnp.dpnp_array import dpnp_array from dpnp.dpnp_utils import get_usm_allocations -__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_matmul"] +__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul"] def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): @@ -476,6 +476,34 @@ def dpnp_cross(a, b, cp, exec_q): return cp +def dpnp_kron(a, b, a_ndim, b_ndim): + """Returns the kronecker product of two arrays.""" + + a_shape = a.shape + b_shape = b.shape + if not a.flags.contiguous: + a = dpnp.reshape(a, a_shape) + if not b.flags.contiguous: + b = dpnp.reshape(b, b_shape) + + # Equalise the shapes by prepending smaller one with 1s + a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape + b_shape = (1,) * max(0, a_ndim - b_ndim) + b_shape + + # Insert empty dimensions + a_arr = dpnp.expand_dims(a, axis=tuple(range(b_ndim - a_ndim))) + b_arr = dpnp.expand_dims(b, axis=tuple(range(a_ndim - b_ndim))) + + # Compute the product + ndim = max(b_ndim, a_ndim) + a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2))) + b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2))) + result = dpnp.multiply(a_arr, b_arr) + + # Reshape back + return result.reshape(tuple(numpy.multiply(a_shape, b_shape))) + + def dpnp_dot(a, b, /, out=None, *, conjugate=False): """ Return the dot product of two arrays. diff --git a/tests/test_product.py b/tests/test_product.py index e1b95b939f9e..37ef8177c606 100644 --- a/tests/test_product.py +++ b/tests/test_product.py @@ -660,6 +660,119 @@ def test_inner_error(self): dpnp.inner(a, b) +class TestKron: + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_kron_scalar(self, dtype): + a = 2 + b = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype) + ib = dpnp.array(b) + + result = dpnp.kron(a, ib) + expected = numpy.kron(a, b) + if dtype in [numpy.int32, numpy.float32, numpy.complex64]: + assert_dtype_allclose(result, expected, check_only_type_kind=True) + else: + assert_dtype_allclose(result, expected) + + result = dpnp.kron(ib, a) + expected = numpy.kron(b, a) + if dtype in [numpy.int32, numpy.float32, numpy.complex64]: + assert_dtype_allclose(result, expected, check_only_type_kind=True) + else: + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) + @pytest.mark.parametrize( + "shape1, shape2", + [ + ((5,), (5,)), + ((3, 5), (4, 6)), + ((2, 4, 3, 5), (3, 5, 6, 2)), + ((4, 3, 5), (3, 5, 6, 2)), + ((2, 4, 3, 5), (3, 5, 6)), + ((2, 4, 3, 5), (3,)), + ((), (3, 4)), + ((5,), ()), + ], + ) + def test_kron(self, dtype, shape1, shape2): + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) + a = numpy.array( + numpy.random.uniform(-5, 5, size1), dtype=dtype + ).reshape(shape1) + b = numpy.array( + numpy.random.uniform(-5, 5, size2), dtype=dtype + ).reshape(shape2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.kron(ia, ib) + expected = numpy.kron(a, b) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize( + "shape1, shape2", + [ + ((5,), (5,)), + ((3, 5), (4, 6)), + ((2, 4, 3, 5), (3, 5, 6, 2)), + ((4, 3, 5), (3, 5, 6, 2)), + ((2, 4, 3, 5), (3, 5, 6)), + ((2, 4, 3, 5), (3,)), + ((), (3, 4)), + ((5,), ()), + ], + ) + def test_kron(self, dtype, shape1, shape2): + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) + x11 = numpy.random.uniform(-5, 5, size1) + x12 = numpy.random.uniform(-5, 5, size1) + x21 = numpy.random.uniform(-5, 5, size2) + x22 = numpy.random.uniform(-5, 5, size2) + a = numpy.array(x11 + 1j * x12, dtype=dtype).reshape(shape1) + b = numpy.array(x21 + 1j * x22, dtype=dtype).reshape(shape2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.kron(ia, ib) + expected = numpy.kron(a, b) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype1", get_all_dtypes()) + @pytest.mark.parametrize("dtype2", get_all_dtypes()) + def test_kron_input_dtype_matrix(self, dtype1, dtype2): + a = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype1) + b = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.kron(ia, ib) + expected = numpy.kron(a, b) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_kron_strided(self, dtype): + a = numpy.arange(20, dtype=dtype) + b = numpy.arange(20, dtype=dtype) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.kron(ia[::3], ib[::3]) + expected = numpy.kron(a[::3], b[::3]) + assert_dtype_allclose(result, expected) + + result = dpnp.kron(ia, ib[::-1]) + expected = numpy.kron(a, b[::-1]) + assert_dtype_allclose(result, expected) + + result = dpnp.kron(ia[::-4], ib[::-4]) + expected = numpy.kron(a[::-4], b[::-4]) + assert_dtype_allclose(result, expected) + + class TestMultiDot: def setup_method(self): numpy.random.seed(70) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index a7b3e896d3e5..842baf0ade4b 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -580,6 +580,7 @@ def test_reduce_hypot(device): "hypot", [[1.0, 2.0, 3.0, 4.0]], [[-1.0, -2.0, -4.0, -5.0]] ), pytest.param("inner", [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]), + pytest.param("kron", [3.0, 4.0, 5.0], [1.0, 2.0]), pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param( "matmul", [[1.0, 0.0], [0.0, 1.0]], [[4.0, 1.0], [1.0, 2.0]] diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index b8d35e80fd68..a90aeb9fab1f 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -525,6 +525,7 @@ def test_1in_1out(func, data, usm_type): "hypot", [[1.0, 2.0, 3.0, 4.0]], [[-1.0, -2.0, -4.0, -5.0]] ), pytest.param("inner", [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]), + pytest.param("kron", [3.0, 4.0, 5.0], [1.0, 2.0]), pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param("maximum", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]), pytest.param("minimum", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]), From 5e89313c1fb4d1c3e4be351c3e3887dd79df3222 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Sat, 24 Feb 2024 14:59:21 -0600 Subject: [PATCH 2/2] address comments --- dpnp/linalg/dpnp_utils_linalg.py | 2 +- tests/test_product.py | 42 ++++++++----------- .../cupy/linalg_tests/test_product.py | 13 ++++++ 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index e6564c70a71f..0305182389b7 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -401,7 +401,7 @@ def _multi_dot(arrays, order, i, j, out=None): def _multi_dot_matrix_chain_order(n, arrays, return_costs=False): """ - Return a dpnp.ndarray that encodes the optimal order of mutiplications. + Return a dpnp.ndarray that encodes the optimal order of multiplications. The optimal order array is then used by `_multi_dot()` to do the multiplication. diff --git a/tests/test_product.py b/tests/test_product.py index 37ef8177c606..daeff9177622 100644 --- a/tests/test_product.py +++ b/tests/test_product.py @@ -8,6 +8,13 @@ from .helper import assert_dtype_allclose, get_all_dtypes, get_complex_dtypes +def _assert_selective_dtype_allclose(result, expected, dtype): + if dtype in [numpy.int32, numpy.float32, numpy.complex64]: + assert_dtype_allclose(result, expected, check_only_type_kind=True) + else: + assert_dtype_allclose(result, expected) + + class TestCross: def setup_method(self): numpy.random.seed(42) @@ -237,11 +244,11 @@ def test_dot_scalar(self, dtype): result = dpnp.dot(a, ib) expected = numpy.dot(a, b) - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) result = dpnp.dot(ib, a) expected = numpy.dot(b, a) - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize( @@ -408,7 +415,7 @@ def test_dot_out_scalar(self, dtype): expected = numpy.dot(a, b) assert result is dp_out - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( @@ -555,17 +562,11 @@ def test_inner_scalar(self, dtype): result = dpnp.inner(a, ib) expected = numpy.inner(a, b) - if dtype in [numpy.int32, numpy.float32, numpy.complex64]: - assert_dtype_allclose(result, expected, check_only_type_kind=True) - else: - assert_dtype_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) result = dpnp.inner(ib, a) expected = numpy.inner(b, a) - if dtype in [numpy.int32, numpy.float32, numpy.complex64]: - assert_dtype_allclose(result, expected, check_only_type_kind=True) - else: - assert_dtype_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize( @@ -669,17 +670,11 @@ def test_kron_scalar(self, dtype): result = dpnp.kron(a, ib) expected = numpy.kron(a, b) - if dtype in [numpy.int32, numpy.float32, numpy.complex64]: - assert_dtype_allclose(result, expected, check_only_type_kind=True) - else: - assert_dtype_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) result = dpnp.kron(ib, a) expected = numpy.kron(b, a) - if dtype in [numpy.int32, numpy.float32, numpy.complex64]: - assert_dtype_allclose(result, expected, check_only_type_kind=True) - else: - assert_dtype_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize( @@ -1001,11 +996,11 @@ def test_tensordot_scalar(self, dtype): result = dpnp.tensordot(a, ib, axes=0) expected = numpy.tensordot(a, b, axes=0) - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) result = dpnp.tensordot(ib, a, axes=0) expected = numpy.tensordot(b, a, axes=0) - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize("axes", [0, 1, 2]) @@ -1060,7 +1055,6 @@ def test_tensordot_axes(self, dtype, axes): ia = dpnp.array(a) ib = dpnp.array(b) - print(a.dtype, ia.dtype) result = dpnp.tensordot(ia, ib, axes=axes) expected = numpy.tensordot(a, b, axes=axes) assert_dtype_allclose(result, expected) @@ -1154,11 +1148,11 @@ def test_vdot_scalar(self, dtype): result = dpnp.vdot(ia, b) expected = numpy.vdot(a, b) - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) result = dpnp.vdot(b, ia) expected = numpy.vdot(b, a) - assert_allclose(result, expected) + _assert_selective_dtype_allclose(result, expected, dtype) @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize( diff --git a/tests/third_party/cupy/linalg_tests/test_product.py b/tests/third_party/cupy/linalg_tests/test_product.py index 94e288beafd6..17d7861cb64e 100644 --- a/tests/third_party/cupy/linalg_tests/test_product.py +++ b/tests/third_party/cupy/linalg_tests/test_product.py @@ -392,6 +392,19 @@ def test_zerodim_kron(self, xp, dtype): b = testing.shaped_arange((4, 5), xp, dtype) return xp.kron(a, b) + @pytest.mark.parametrize( + "a, b", + [ + # (2, 3.0), # dpnp does not support both inputs as scalar + (2, [[0, -1j / 2], [1j / 2, 0]]), + ([[0, -1j / 2], [1j / 2, 0]], 2), + ], + ) + @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) + def test_kron_accepts_numbers_as_arguments(self, a, b, xp): + args = [xp.array(arg) if type(arg) == list else arg for arg in [a, b]] + return xp.kron(*args) + @testing.parameterize( *testing.product(