diff --git a/dpnp/backend/CMakeLists.txt b/dpnp/backend/CMakeLists.txt index f66aa4be1ae5..52e9cb21985b 100644 --- a/dpnp/backend/CMakeLists.txt +++ b/dpnp/backend/CMakeLists.txt @@ -93,6 +93,7 @@ string(CONCAT COMMON_COMPILE_FLAGS "-fsycl " "-fsycl-device-code-split=per_kernel " "-fno-approx-func " + "-fno-finite-math-only " ) string(CONCAT COMMON_LINK_FLAGS "-fsycl " diff --git a/dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp b/dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp index 9a3c69aee8e5..e345c6eefea7 100644 --- a/dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp +++ b/dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp @@ -132,10 +132,10 @@ MACRO_2ARG_3TYPES_OP(dpnp_copysign_c, MACRO_2ARG_3TYPES_OP(dpnp_divide_c, input1_elem / input2_elem, - nullptr, - std::false_type, + x1 / x2, + MACRO_UNPACK_TYPES(bool, std::int32_t, std::int64_t), oneapi::mkl::vm::div, - MACRO_UNPACK_TYPES(float, double)) + MACRO_UNPACK_TYPES(float, double, std::complex, std::complex)) MACRO_2ARG_3TYPES_OP(dpnp_fmod_c, sycl::fmod((double)input1_elem, (double)input2_elem), @@ -169,7 +169,7 @@ MACRO_2ARG_3TYPES_OP(dpnp_minimum_c, // pytest "tests/third_party/cupy/creation_tests/test_ranges.py::TestMgrid::test_mgrid3" // requires multiplication shape1[10] with shape2[10,1] and result expected as shape[10,10] MACRO_2ARG_3TYPES_OP(dpnp_multiply_c, - input1_elem* input2_elem, + input1_elem * input2_elem, x1 * x2, MACRO_UNPACK_TYPES(bool, std::int32_t, std::int64_t), oneapi::mkl::vm::mul, diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 70a2d860910b..fb154fcabfac 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -417,8 +417,26 @@ size_t operator-(DPNPFuncType lhs, DPNPFuncType rhs); */ typedef struct DPNPFuncData { - DPNPFuncType return_type; /**< return type identifier which expected by the @ref ptr function */ - void* ptr; /**< C++ backend function pointer */ + DPNPFuncData(const DPNPFuncType gen_type, void* gen_ptr, const DPNPFuncType type_no_fp64, void* ptr_no_fp64) + : return_type(gen_type) + , ptr(gen_ptr) + , return_type_no_fp64(type_no_fp64) + , ptr_no_fp64(ptr_no_fp64) + { + } + DPNPFuncData(const DPNPFuncType gen_type, void* gen_ptr) + : DPNPFuncData(gen_type, gen_ptr, DPNPFuncType::DPNP_FT_NONE, nullptr) + { + } + DPNPFuncData() + : DPNPFuncData(DPNPFuncType::DPNP_FT_NONE, nullptr) + { + } + + DPNPFuncType return_type; /**< return type identifier which expected by the @ref ptr function */ + void* ptr; /**< C++ backend function pointer */ + DPNPFuncType return_type_no_fp64; /**< alternative return type identifier when no fp64 support by device */ + void* ptr_no_fp64; /**< alternative C++ backend function pointer when no fp64 support by device */ } DPNPFuncData_t; /** diff --git a/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp b/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp index 057e0805db6a..5133473d3935 100644 --- a/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp @@ -1029,18 +1029,42 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap) \ if (start + static_cast(vec_sz) * max_sg_size < result_size) \ { \ - sycl::vec<_DataType_input1, vec_sz> x1 = \ - sg.load(sycl::multi_ptr<_DataType_input1, global_space>(&input1_data[start])); \ - sycl::vec<_DataType_input2, vec_sz> x2 = \ - sg.load(sycl::multi_ptr<_DataType_input2, global_space>(&input2_data[start])); \ + using input1_ptrT = sycl::multi_ptr<_DataType_input1, global_space>; \ + using input2_ptrT = sycl::multi_ptr<_DataType_input2, global_space>; \ + using result_ptrT = sycl::multi_ptr<_DataType_output, global_space>; \ + \ sycl::vec<_DataType_output, vec_sz> res_vec; \ \ - if constexpr (both_types_are_same<_DataType_input1, _DataType_input2, __vec_types__>) \ + if constexpr (both_types_are_any_of<_DataType_input1, _DataType_input2, __vec_types__>) \ { \ - res_vec = __vec_operation__; \ + if constexpr (both_types_are_same<_DataType_input1, _DataType_input2, _DataType_output>) \ + { \ + sycl::vec<_DataType_input1, vec_sz> x1 = \ + sg.load(input1_ptrT(&input1_data[start])); \ + sycl::vec<_DataType_input2, vec_sz> x2 = \ + sg.load(input2_ptrT(&input2_data[start])); \ + \ + res_vec = __vec_operation__; \ + } \ + else /* input types don't match result type, so explicit casting is required */ \ + { \ + sycl::vec<_DataType_output, vec_sz> x1 = \ + dpnp_vec_cast<_DataType_output, _DataType_input1, vec_sz>( \ + sg.load(input1_ptrT(&input1_data[start]))); \ + sycl::vec<_DataType_output, vec_sz> x2 = \ + dpnp_vec_cast<_DataType_output, _DataType_input2, vec_sz>( \ + sg.load(input2_ptrT(&input2_data[start]))); \ + \ + res_vec = __vec_operation__; \ + } \ } \ else \ { \ + sycl::vec<_DataType_input1, vec_sz> x1 = \ + sg.load(input1_ptrT(&input1_data[start])); \ + sycl::vec<_DataType_input2, vec_sz> x2 = \ + sg.load(input2_ptrT(&input2_data[start])); \ + \ for (size_t k = 0; k < vec_sz; ++k) \ { \ const _DataType_output input1_elem = x1[k]; \ @@ -1048,7 +1072,7 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap) res_vec[k] = __operation__; \ } \ } \ - sg.store(sycl::multi_ptr<_DataType_output, global_space>(&result[start]), res_vec); \ + sg.store(result_ptrT(&result[start]), res_vec); \ } \ else \ { \ @@ -1173,6 +1197,47 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap) #include +template +static constexpr DPNPFuncType get_divide_res_type() +{ + constexpr auto widest_type = populate_func_types(); + constexpr auto shortes_type = (widest_type == FT1) ? FT2 : FT1; + + if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX128 || widest_type == DPNPFuncType::DPNP_FT_DOUBLE) + { + return widest_type; + } + else if constexpr (widest_type == DPNPFuncType::DPNP_FT_CMPLX64) + { + if constexpr (shortes_type == DPNPFuncType::DPNP_FT_DOUBLE) + { + return DPNPFuncType::DPNP_FT_CMPLX128; + } + else if constexpr (has_fp64::value && + (shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG)) + { + return DPNPFuncType::DPNP_FT_CMPLX128; + } + } + else if constexpr (widest_type == DPNPFuncType::DPNP_FT_FLOAT) + { + if constexpr (has_fp64::value && + (shortes_type == DPNPFuncType::DPNP_FT_INT || shortes_type == DPNPFuncType::DPNP_FT_LONG)) + { + return DPNPFuncType::DPNP_FT_DOUBLE; + } + } + else if constexpr (has_fp64::value) + { + return DPNPFuncType::DPNP_FT_DOUBLE; + } + else + { + return DPNPFuncType::DPNP_FT_FLOAT; + } + return widest_type; +} + template static void func_map_elemwise_2arg_3type_core(func_map_t& fmap) { @@ -1194,6 +1259,16 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap) func_type_map_t::find_type, func_type_map_t::find_type>}), ...); + ((fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][FT1][FTs] = + {get_divide_res_type(), + (void*)dpnp_divide_c_ext()>, + func_type_map_t::find_type, + func_type_map_t::find_type>, + get_divide_res_type(), + (void*)dpnp_divide_c_ext()>, + func_type_map_t::find_type, + func_type_map_t::find_type>}), + ...); } template @@ -1402,39 +1477,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap) fmap[DPNPFuncName::DPNP_FN_DIVIDE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_divide_c_default}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_INT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_LNG] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_FLT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_INT][eft_DBL] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_INT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_LNG] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_FLT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_LNG][eft_DBL] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_INT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_LNG] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_FLT] = {eft_FLT, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_FLT][eft_DBL] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_INT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_LNG] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_FLT] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][eft_DBL][eft_DBL] = {eft_DBL, - (void*)dpnp_divide_c_ext}; - fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_fmod_c_default}; fmap[DPNPFuncName::DPNP_FN_FMOD][eft_INT][eft_LNG] = {eft_LNG, diff --git a/dpnp/backend/src/dpnp_fptr.hpp b/dpnp/backend/src/dpnp_fptr.hpp index 4cb664858319..742e6dff3783 100644 --- a/dpnp/backend/src/dpnp_fptr.hpp +++ b/dpnp/backend/src/dpnp_fptr.hpp @@ -35,6 +35,8 @@ #include #include +#include + #include /** @@ -116,6 +118,31 @@ static constexpr DPNPFuncType populate_func_types() return (FT1 < FT2) ? FT2 : FT1; } +/** + * @brief A helper function to cast SYCL vector between types. + */ +template +static auto dpnp_vec_cast_impl(const Vec& v, std::index_sequence) +{ + return Op{v[I]...}; +} + +/** + * @brief A casting function for SYCL vector. + * + * @tparam dstT A result type upon casting. + * @tparam srcT An incoming type of the vector. + * @tparam N A number of elements with the vector. + * @tparam Indices A sequence of integers + * @param s An incoming SYCL vector to cast. + * @return SYCL vector casted to desctination type. + */ +template > +static auto dpnp_vec_cast(const sycl::vec& s) +{ + return dpnp_vec_cast_impl, sycl::vec>(s, Indices{}); +} + /** * Removes parentheses for a passed list of types separated by comma. * It's intended to be used in operations macro. @@ -142,6 +169,12 @@ struct are_same : std::conjunction...> {}; template constexpr auto both_types_are_same = std::conjunction_v, are_same>; +/** + * A template constat to check if both types T1 and T2 match any type from Ts. + */ +template +constexpr auto both_types_are_any_of = std::conjunction_v, is_any>; + /** * A template constat to check if both types T1 and T2 don't match any type from Ts sequence. */ diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 485e8adb1a66..65e07a9c7046 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -374,6 +374,8 @@ cdef extern from "dpnp_iface_fptr.hpp": struct DPNPFuncData: DPNPFuncType return_type void * ptr + DPNPFuncType return_type_no_fp64 + void *ptr_no_fp64 DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except + diff --git a/dpnp/dpnp_algo/dpnp_algo.pyx b/dpnp/dpnp_algo/dpnp_algo.pyx index aaa7334e18a8..f12707ccc761 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pyx +++ b/dpnp/dpnp_algo/dpnp_algo.pyx @@ -481,8 +481,6 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name, # get the FPTR data structure cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type) - result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type) - # Create result array cdef shape_type_c x1_shape = x1_obj.shape @@ -495,15 +493,26 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name, result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj) + # get FPTR function and return type + cdef fptr_2in_1out_strides_t func = NULL + cdef DPNPFuncType return_type = DPNP_FT_NONE + if fptr_name != DPNP_FN_DIVIDE_EXT or result_sycl_device.has_aspect_fp64: + return_type = kernel_data.return_type + func = < fptr_2in_1out_strides_t > kernel_data.ptr + else: + return_type = kernel_data.return_type_no_fp64 + func = < fptr_2in_1out_strides_t > kernel_data.ptr_no_fp64 + if out is None: """ Create result array with type given by FPTR data """ result = utils.create_output_descriptor(result_shape, - kernel_data.return_type, + return_type, None, device=result_sycl_device, usm_type=result_usm_type, sycl_queue=result_sycl_queue) else: + result_type = dpnp_DPNPFuncType_to_dtype(< size_t > return_type) if out.dtype != result_type: utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type) if out.shape != result_shape: @@ -517,11 +526,10 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name, result_obj = result.get_array() - cdef c_dpctl.SyclQueue q = result_obj.sycl_queue + cdef c_dpctl.SyclQueue q = < c_dpctl.SyclQueue > result_obj.sycl_queue cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() """ Call FPTR function """ - cdef fptr_2in_1out_strides_t func = kernel_data.ptr cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, result.get_data(), result.size, diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 64886de23c02..feff53288cfd 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -544,55 +544,66 @@ def diff(x1, n=1, axis=-1, prepend=numpy._NoValue, append=numpy._NoValue): return call_origin(numpy.diff, x1, n=n, axis=axis, prepend=prepend, append=append) -def divide(x1, x2, dtype=None, out=None, where=True, **kwargs): +def divide(x1, + x2, + /, + out=None, + *, + where=True, + dtype=None, + subok=True, + **kwargs): """ Divide arguments element-wise. For full documentation refer to :obj:`numpy.divide`. + Returns + ------- + y : dpnp.ndarray + The quotient ``x1/x2``, element-wise. + Limitations ----------- - Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar. - Parameters ``dtype``, ``out`` and ``where`` are supported with their default values. + Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar, + but not both (at least either `x1` or `x2` should be as :class:`dpnp.ndarray`). + Parameters `out`, `where`, `dtype` and `subok` are supported with their default values. Keyword arguments ``kwargs`` are currently unsupported. - Otherwise the functions will be executed sequentially on CPU. + Otherwise the function will be executed sequentially on CPU. Input array data types are limited by supported DPNP :ref:`Data types`. Examples -------- - >>> import dpnp as np - >>> result = np.divide(np.array([1, -2, 6, -9]), np.array([-2, -2, -2, -2])) - >>> [x for x in result] + >>> import dpnp as dp + >>> result = dp.divide(dp.array([1, -2, 6, -9]), dp.array([-2, -2, -2, -2])) + >>> print(result) [-0.5, 1.0, -3.0, 4.5] """ - x1_is_scalar = dpnp.isscalar(x1) - x2_is_scalar = dpnp.isscalar(x2) - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False) - x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False) + if out is not None: + pass + elif where is not True: + pass + elif dtype is not None: + pass + elif subok is not True: + pass + elif dpnp.isscalar(x1) and dpnp.isscalar(x2): + # at least either x1 or x2 has to be an array + pass + else: + # get USM type and queue to copy scalar from the host memory into a USM allocation + usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None) - if x1_desc and x2_desc and not kwargs: - if not x1_desc and not x1_is_scalar: - pass - elif not x2_desc and not x2_is_scalar: - pass - elif x1_is_scalar and x2_is_scalar: - pass - elif x1_desc and x1_desc.ndim == 0: - pass - elif x2_desc and x2_desc.ndim == 0: - pass - elif dtype is not None: - pass - elif out is not None: - pass - elif not where: - pass - else: + x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False, + alloc_usm_type=usm_type, alloc_queue=queue) + x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False, + alloc_usm_type=usm_type, alloc_queue=queue) + if x1_desc and x2_desc: return dpnp_divide(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj() - return call_origin(numpy.divide, x1, x2, dtype=dtype, out=out, where=where, **kwargs) + return call_origin(numpy.divide, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs) def ediff1d(x1, to_end=None, to_begin=None): diff --git a/tests/conftest.py b/tests/conftest.py index 78d3180bac08..22276f125f26 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # ***************************************************************************** -# Copyright (c) 2016-2020, Intel Corporation +# Copyright (c) 2016-2023, Intel Corporation # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -77,3 +77,22 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture def allow_fall_back_on_numpy(monkeypatch): monkeypatch.setattr(dpnp.config, '__DPNP_RAISE_EXCEPION_ON_NUMPY_FALLBACK__', 0) + +@pytest.fixture +def suppress_divide_numpy_warnings(): + # divide: treatment for division by zero (infinite result obtained from finite numbers) + old_settings = numpy.seterr(divide='ignore') + yield + numpy.seterr(**old_settings) # reset to default + +@pytest.fixture +def suppress_invalid_numpy_warnings(): + # invalid: treatment for invalid floating-point operation + # (result is not an expressible number, typically indicates that a NaN was produced) + old_settings = numpy.seterr(invalid='ignore') + yield + numpy.seterr(**old_settings) # reset to default + +@pytest.fixture +def suppress_divide_invalid_numpy_warnings(suppress_divide_numpy_warnings, suppress_invalid_numpy_warnings): + yield diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 25d1fd1bc0f5..2f0334077a06 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -769,9 +769,7 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticModf::test_m tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_10_{name='remainder', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_11_{name='mod', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_1_{name='angle', nargs=1}::test_raises_with_numpy_input -tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_4_{name='divide', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_5_{name='power', nargs=2}::test_raises_with_numpy_input -tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_7_{name='true_divide', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_8_{name='floor_divide', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 34d1795cc98d..e6598904e16f 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -988,9 +988,7 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticBinary2_para tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_10_{name='remainder', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_11_{name='mod', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_1_{name='angle', nargs=1}::test_raises_with_numpy_input -tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_4_{name='divide', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_5_{name='power', nargs=2}::test_raises_with_numpy_input -tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_7_{name='true_divide', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_8_{name='floor_divide', nargs=2}::test_raises_with_numpy_input tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp diff --git a/tests/test_linalg.py b/tests/test_linalg.py index ac8392d15384..d9784a41558f 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,9 +1,15 @@ import pytest +from .helper import get_all_dtypes import dpnp as inp import dpctl + import numpy +from numpy.testing import ( + assert_allclose, + assert_array_equal +) def vvsort(val, vec, size, xp): @@ -49,7 +55,7 @@ def test_cholesky(array): ia = inp.array(a) result = inp.linalg.cholesky(ia) expected = numpy.linalg.cholesky(a) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result) @pytest.mark.parametrize("arr", @@ -63,7 +69,7 @@ def test_cond(arr, p): ia = inp.array(a) result = inp.linalg.cond(ia, p) expected = numpy.linalg.cond(a, p) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result) @pytest.mark.parametrize("array", @@ -82,13 +88,11 @@ def test_det(array): ia = inp.array(a) result = inp.linalg.det(ia) expected = numpy.linalg.det(a) - numpy.testing.assert_allclose(expected, result) + assert_allclose(expected, result) @pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("type", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) @pytest.mark.parametrize("size", [2, 4, 8, 16, 300]) def test_eig_arange(type, size): @@ -115,21 +119,19 @@ def test_eig_arange(type, size): if np_vec[0, i] * dpnp_vec[0, i] < 0: np_vec[:, i] = -np_vec[:, i] - numpy.testing.assert_array_equal(symm_orig, symm) - numpy.testing.assert_array_equal(dpnp_symm_orig, dpnp_symm) + assert_array_equal(symm_orig, symm) + assert_array_equal(dpnp_symm_orig, dpnp_symm) assert (dpnp_val.dtype == np_val.dtype) assert (dpnp_vec.dtype == np_vec.dtype) assert (dpnp_val.shape == np_val.shape) assert (dpnp_vec.shape == np_vec.shape) - numpy.testing.assert_allclose(dpnp_val, np_val, rtol=1e-05, atol=1e-05) - numpy.testing.assert_allclose(dpnp_vec, np_vec, rtol=1e-05, atol=1e-05) + assert_allclose(dpnp_val, np_val, rtol=1e-05, atol=1e-05) + assert_allclose(dpnp_vec, np_vec, rtol=1e-05, atol=1e-05) -@pytest.mark.parametrize("type", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) def test_eigvals(type): if dpctl.get_current_device_type() != dpctl.device_type.gpu: pytest.skip("eigvals function doesn\'t work on CPU: https://github.com/IntelPython/dpnp/issues/1005") @@ -144,12 +146,10 @@ def test_eigvals(type): ia = inp.array(a) result = inp.linalg.eigvals(ia) expected = numpy.linalg.eigvals(a) - numpy.testing.assert_allclose(expected, result, atol=0.5) + assert_allclose(expected, result, atol=0.5) -@pytest.mark.parametrize("type", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) @pytest.mark.parametrize("array", [[[1., 2.], [3., 4.]], [[0, 1, 2], [3, 2, -1], [4, -2, 3]]], ids=['[[1., 2.], [3., 4.]]', '[[0, 1, 2], [3, 2, -1], [4, -2, 3]]']) @@ -158,12 +158,10 @@ def test_inv(type, array): ia = inp.array(a) result = inp.linalg.inv(ia) expected = numpy.linalg.inv(a) - numpy.testing.assert_allclose(expected, result) + assert_allclose(expected, result) -@pytest.mark.parametrize("type", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True, no_none=True)) @pytest.mark.parametrize("array", [[0, 0], [0, 1], [1, 2], [[0, 0], [0, 0]], [[1, 2], [1, 2]], [[1, 2], [3, 4]]], ids=['[0, 0]', '[0, 1]', '[1, 2]', '[[0, 0], [0, 0]]', '[[1, 2], [1, 2]]', '[[1, 2], [3, 4]]']) @@ -177,10 +175,11 @@ def test_matrix_rank(type, tol, array): result = inp.linalg.matrix_rank(ia, tol=tol) expected = numpy.linalg.matrix_rank(a, tol=tol) - numpy.testing.assert_allclose(expected, result) + assert_allclose(expected, result) @pytest.mark.usefixtures("allow_fall_back_on_numpy") +@pytest.mark.usefixtures("suppress_divide_numpy_warnings") @pytest.mark.parametrize("array", [[7], [1, 2], [1, 0]], ids=['[7]', '[1, 2]', '[1, 0]']) @@ -195,7 +194,7 @@ def test_norm1(array, ord, axis): ia = inp.array(a) result = inp.linalg.norm(ia, ord=ord, axis=axis) expected = numpy.linalg.norm(a, ord=ord, axis=axis) - numpy.testing.assert_allclose(expected, result) + assert_allclose(expected, result) @pytest.mark.usefixtures("allow_fall_back_on_numpy") @@ -213,7 +212,7 @@ def test_norm2(array, ord, axis): ia = inp.array(a) result = inp.linalg.norm(ia, ord=ord, axis=axis) expected = numpy.linalg.norm(a, ord=ord, axis=axis) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result) @pytest.mark.usefixtures("allow_fall_back_on_numpy") @@ -231,13 +230,11 @@ def test_norm3(array, ord, axis): ia = inp.array(a) result = inp.linalg.norm(ia, ord=ord, axis=axis) expected = numpy.linalg.norm(a, ord=ord, axis=axis) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result) @pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("type", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) @pytest.mark.parametrize("shape", [(2, 2), (3, 4), (5, 3), (16, 16)], ids=['(2,2)', '(3,4)', '(5,3)', '(16,16)']) @@ -262,7 +259,7 @@ def test_qr(type, shape, mode): tol = 1e-11 # check decomposition - numpy.testing.assert_allclose(ia, numpy.dot(inp.asnumpy(dpnp_q), inp.asnumpy(dpnp_r)), rtol=tol, atol=tol) + assert_allclose(ia, numpy.dot(inp.asnumpy(dpnp_q), inp.asnumpy(dpnp_r)), rtol=tol, atol=tol) # NP change sign for comparison ncols = min(a.shape[0], a.shape[1]) @@ -273,14 +270,12 @@ def test_qr(type, shape, mode): np_r[i, :] = -np_r[i, :] if numpy.any(numpy.abs(np_r[i, :]) > tol): - numpy.testing.assert_allclose(inp.asnumpy(dpnp_q)[:, i], np_q[:, i], rtol=tol, atol=tol) + assert_allclose(inp.asnumpy(dpnp_q)[:, i], np_q[:, i], rtol=tol, atol=tol) - numpy.testing.assert_allclose(dpnp_r, np_r, rtol=tol, atol=tol) + assert_allclose(dpnp_r, np_r, rtol=tol, atol=tol) -@pytest.mark.parametrize("type", - [numpy.float64, numpy.float32, numpy.int64, numpy.int32], - ids=['float64', 'float32', 'int64', 'int32']) +@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) @pytest.mark.parametrize("shape", [(2, 2), (3, 4), (5, 3), (16, 16)], ids=['(2,2)', '(3,4)', '(5,3)', '(16,16)']) @@ -309,10 +304,10 @@ def test_svd(type, shape): dpnp_diag_s[i, i] = dpnp_s[i] # check decomposition - numpy.testing.assert_allclose(ia, inp.dot(dpnp_u, inp.dot(dpnp_diag_s, dpnp_vt)), rtol=tol, atol=tol) + assert_allclose(ia, inp.dot(dpnp_u, inp.dot(dpnp_diag_s, dpnp_vt)), rtol=tol, atol=tol) # compare singular values - # numpy.testing.assert_allclose(dpnp_s, np_s, rtol=tol, atol=tol) + # assert_allclose(dpnp_s, np_s, rtol=tol, atol=tol) # change sign of vectors for i in range(min(shape[0], shape[1])): @@ -322,5 +317,5 @@ def test_svd(type, shape): # compare vectors for non-zero values for i in range(numpy.count_nonzero(np_s > tol)): - numpy.testing.assert_allclose(inp.asnumpy(dpnp_u)[:, i], np_u[:, i], rtol=tol, atol=tol) - numpy.testing.assert_allclose(inp.asnumpy(dpnp_vt)[i, :], np_vt[i, :], rtol=tol, atol=tol) + assert_allclose(inp.asnumpy(dpnp_u)[:, i], np_u[:, i], rtol=tol, atol=tol) + assert_allclose(inp.asnumpy(dpnp_vt)[i, :], np_vt[i, :], rtol=tol, atol=tol) diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 6f7ee58c0380..78f628908337 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -66,7 +66,7 @@ def test_diff(array): @pytest.mark.parametrize("dtype1", get_all_dtypes()) @pytest.mark.parametrize("dtype2", get_all_dtypes()) @pytest.mark.parametrize("func", - ['add', 'multiply', 'subtract']) + ['add', 'multiply', 'subtract', 'divide']) @pytest.mark.parametrize("data", [[[1, 2], [3, 4]]], ids=['[[1, 2], [3, 4]]']) @@ -132,8 +132,7 @@ def test_arctan2(self, dtype, lhs, rhs): def test_copysign(self, dtype, lhs, rhs): self._test_mathematical('copysign', dtype, lhs, rhs) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") - @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) + @pytest.mark.parametrize("dtype", get_all_dtypes()) def test_divide(self, dtype, lhs, rhs): self._test_mathematical('divide', dtype, lhs, rhs) @@ -181,12 +180,13 @@ def test_subtract(self, dtype, lhs, rhs): self._test_mathematical('subtract', dtype, lhs, rhs) +@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings") @pytest.mark.parametrize("val_type", [bool, int, float], ids=['bool', 'int', 'float']) @pytest.mark.parametrize("data_type", get_all_dtypes()) @pytest.mark.parametrize("func", - ['add', 'multiply', 'subtract']) + ['add', 'multiply', 'subtract', 'divide']) @pytest.mark.parametrize("val", [0, 1, 5], ids=['0', '1', '5']) @@ -216,11 +216,11 @@ def test_op_with_scalar(array, val, func, data_type, val_type): else: result = getattr(dpnp, func)(dpnp_a, val_) expected = getattr(numpy, func)(np_a, val_) - assert_array_equal(result, expected) + assert_allclose(result, expected) result = getattr(dpnp, func)(val_, dpnp_a) expected = getattr(numpy, func)(val_, np_a) - assert_array_equal(result, expected) + assert_allclose(result, expected) @pytest.mark.parametrize("shape", @@ -262,6 +262,19 @@ def test_subtract_scalar(shape, dtype): assert_allclose(result, expected) +@pytest.mark.parametrize("shape", + [(), (3, 2)], + ids=['()', '(3, 2)']) +@pytest.mark.parametrize("dtype", get_all_dtypes()) +def test_divide_scalar(shape, dtype): + np_a = numpy.ones(shape, dtype=dtype) + dpnp_a = dpnp.ones(shape, dtype=dtype) + + result = 0.5 / dpnp_a / 1.7 + expected = 0.5 / np_a / 1.7 + assert_allclose(result, expected) + + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @pytest.mark.parametrize("array", [[1, 2, 3, 4, 5], [1, 2, numpy.nan, 4, 5], @@ -442,7 +455,6 @@ def test_cross_3x3(self, x1, x2, axisa, axisb, axisc, axis): assert_array_equal(expected, result) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestGradient: @pytest.mark.parametrize("array", [[2, 3, 6, 8, 4, 9], @@ -456,6 +468,7 @@ def test_gradient_y1(self, array): expected = numpy.gradient(np_y) assert_array_equal(expected, result) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @pytest.mark.parametrize("array", [[2, 3, 6, 8, 4, 9], [3., 4., 7.5, 9.], [2, 6, 8, 10]]) diff --git a/tests/test_strides.py b/tests/test_strides.py index 3c0d86a44a5a..02e8c8689757 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -37,6 +37,7 @@ def test_strides(func_name, dtype): assert_allclose(expected, result) +@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings") @pytest.mark.parametrize("func_name", ["arccos", "arccosh", "arcsin", "arcsinh", "arctan", "arctanh", "cbrt", "ceil", "copy", "cos", "cosh", "conjugate", "degrees", "ediff1d", "exp", "exp2", "expm1", "fabs", "floor", "log", diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 46531cb78aae..1a33a1d655dd 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -50,6 +50,19 @@ def test_coerced_usm_types_subtract(usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types) +def test_coerced_usm_types_divide(usm_type_x, usm_type_y): + x = dp.arange(120, usm_type = usm_type_x) + y = dp.arange(120, usm_type = usm_type_y) + + z = 2 / x / y / 1.5 + + assert x.usm_type == usm_type_x + assert y.usm_type == usm_type_y + assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) + + @pytest.mark.parametrize( "func, args", [ diff --git a/tests/third_party/cupy/math_tests/test_arithmetic.py b/tests/third_party/cupy/math_tests/test_arithmetic.py index 21068ece8749..027722d8bef2 100644 --- a/tests/third_party/cupy/math_tests/test_arithmetic.py +++ b/tests/third_party/cupy/math_tests/test_arithmetic.py @@ -146,27 +146,35 @@ def check_binary(self, xp): y = y.astype(numpy.complex64) # NumPy returns an output array of another type than DPNP when input ones have diffrent types. - if self.name in ('add', 'multiply', 'subtract') and xp is cupy and dtype1 != dtype2 and not self.use_dtype: + if xp is cupy and dtype1 != dtype2 and not self.use_dtype: is_array_arg1 = not xp.isscalar(arg1) is_array_arg2 = not xp.isscalar(arg2) is_int_float = lambda _x, _y: numpy.issubdtype(_x, numpy.integer) and numpy.issubdtype(_y, numpy.floating) is_same_type = lambda _x, _y, _type: numpy.issubdtype(_x, _type) and numpy.issubdtype(_y, _type) - if is_array_arg1 and is_array_arg2: - # If both inputs are arrays where one is of floating type and another - integer, - # NumPy will return an output array of always "float64" type, - # while DPNP will return the array of a wider type from the input arrays. - if is_int_float(dtype1, dtype2) or is_int_float(dtype2, dtype1): - y = y.astype(numpy.float64) - elif is_same_type(dtype1, dtype2, numpy.floating) or is_same_type(dtype1, dtype2, numpy.integer): - # If one input is an array and another - scalar, - # NumPy will return an output array of the same type as the inpupt array has, - # while DPNP will return the array of a wider type from the inputs (considering both array and scalar). - if is_array_arg1 and not is_array_arg2: - y = y.astype(dtype1) - elif is_array_arg2 and not is_array_arg1: - y = y.astype(dtype2) + if self.name in ('add', 'multiply', 'subtract'): + if is_array_arg1 and is_array_arg2: + # If both inputs are arrays where one is of floating type and another - integer, + # NumPy will return an output array of always "float64" type, + # while DPNP will return the array of a wider type from the input arrays. + if is_int_float(dtype1, dtype2) or is_int_float(dtype2, dtype1): + y = y.astype(numpy.float64) + elif is_same_type(dtype1, dtype2, numpy.floating) or is_same_type(dtype1, dtype2, numpy.integer): + # If one input is an array and another - scalar, + # NumPy will return an output array of the same type as the inpupt array has, + # while DPNP will return the array of a wider type from the inputs (considering both array and scalar). + if is_array_arg1 and not is_array_arg2: + y = y.astype(dtype1) + elif is_array_arg2 and not is_array_arg1: + y = y.astype(dtype2) + elif self.name in ('divide', 'true_divide'): + # If one input is an array of float32 and another - an integer or floating scalar, + # NumPy will return an output array of float32, while DPNP will return the array of float64, + # since NumPy would use the same float64 type when instead of scalar here is array of integer of floating type. + if not (is_array_arg1 and is_array_arg2): + if (is_array_arg1 and arg1.dtype == numpy.float32) ^ (is_array_arg2 and arg2.dtype == numpy.float32): + y = y.astype(numpy.float32) # NumPy returns different values (nan/inf) on division by zero # depending on the architecture. diff --git a/tests/third_party/cupy/statistics_tests/test_meanvar.py b/tests/third_party/cupy/statistics_tests/test_meanvar.py index aea22d02c511..60d3413b0daa 100644 --- a/tests/third_party/cupy/statistics_tests/test_meanvar.py +++ b/tests/third_party/cupy/statistics_tests/test_meanvar.py @@ -89,7 +89,6 @@ def test_median_axis_sequence(self, xp, dtype): return xp.median(a, self.axis, keepdims=self.keepdims) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.gpu class TestAverage(unittest.TestCase): @@ -101,12 +100,14 @@ def test_average_all(self, xp, dtype): a = testing.shaped_arange((2, 3), xp, dtype) return xp.average(a) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose() def test_average_axis(self, xp, dtype): a = testing.shaped_arange((2, 3, 4), xp, dtype) return xp.average(a, axis=1) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose() def test_average_weights(self, xp, dtype): @@ -114,6 +115,7 @@ def test_average_weights(self, xp, dtype): w = testing.shaped_arange((2, 3), xp, dtype) return xp.average(a, weights=w) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose() def test_average_axis_weights(self, xp, dtype): @@ -132,6 +134,7 @@ def check_returned(self, a, axis, weights): testing.assert_allclose(average_cpu, average_gpu) testing.assert_allclose(sum_weights_cpu, sum_weights_gpu) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() def test_returned(self, dtype): a = testing.shaped_arange((2, 3), numpy, dtype) diff --git a/utils/command_build_clib.py b/utils/command_build_clib.py index 95887cc65aaa..d16bab3aec4a 100644 --- a/utils/command_build_clib.py +++ b/utils/command_build_clib.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # ***************************************************************************** -# Copyright (c) 2016-2022, Intel Corporation +# Copyright (c) 2016-2023, Intel Corporation # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -63,7 +63,7 @@ # default variables (for Linux) _project_compiler = "icpx" _project_linker = "icpx" -_project_cmplr_flag_sycl_devel = ["-fsycl-device-code-split=per_kernel", "-fno-approx-func"] +_project_cmplr_flag_sycl_devel = ["-fsycl-device-code-split=per_kernel", "-fno-approx-func", "-fno-finite-math-only"] _project_cmplr_flag_sycl = ["-fsycl"] _project_cmplr_flag_stdcpp_static = [] # This brakes TBB ["-static-libstdc++", "-static-libgcc"] _project_cmplr_flag_compatibility = ["-Wl,--enable-new-dtags"]