Skip to content

Commit aae4cff

Browse files
authored
Merge branch 'master' into implement_tile
2 parents 243b046 + a73d959 commit aae4cff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1502
-709
lines changed

.github/workflows/conda-package.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,13 @@ env:
3636
third_party/cupy/manipulation_tests/test_join.py
3737
third_party/cupy/manipulation_tests/test_rearrange.py
3838
third_party/cupy/manipulation_tests/test_transpose.py
39+
third_party/cupy/math_tests/test_arithmetic.py
3940
third_party/cupy/math_tests/test_explog.py
41+
third_party/cupy/math_tests/test_floating.py
42+
third_party/cupy/math_tests/test_hyperbolic.py
43+
third_party/cupy/math_tests/test_matmul.py
4044
third_party/cupy/math_tests/test_misc.py
45+
third_party/cupy/math_tests/test_rounding.py
4146
third_party/cupy/math_tests/test_trigonometric.py
4247
third_party/cupy/sorting_tests/test_sort.py
4348
VER_JSON_NAME: 'version.json'

dpnp/backend/extensions/vm/abs.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event abs_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AbsOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::abs(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/acos.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event acos_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AcosOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::acos(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/acosh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event acosh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AcoshOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::acosh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/add.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event add_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::AddOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::add(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/asin.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event asin_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AsinOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::asin(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/asinh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event asinh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AsinhOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::asinh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/atan.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event atan_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AtanOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::atan(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/atan2.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event atan2_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::Atan2OutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::atan2(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/atanh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event atanh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::AtanhOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::atanh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/ceil.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event ceil_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::CeilOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::ceil(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/common.hpp

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,8 @@ std::pair<sycl::event, sycl::event>
8282
{
8383
// check type_nums
8484
int src_typenum = src.get_typenum();
85-
int dst_typenum = dst.get_typenum();
86-
8785
auto array_types = dpctl_td_ns::usm_ndarray_types();
8886
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
89-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
90-
91-
if (src_typeid != dst_typeid) {
92-
throw py::value_error("Input and output arrays have different types.");
93-
}
9487

9588
// check that queues are compatible
9689
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
@@ -155,7 +148,7 @@ std::pair<sycl::event, sycl::event>
155148
throw py::value_error("Input and outpur arrays must be C-contiguous.");
156149
}
157150

158-
auto dispatch_fn = dispatch_vector[dst_typeid];
151+
auto dispatch_fn = dispatch_vector[src_typeid];
159152
if (dispatch_fn == nullptr) {
160153
throw py::value_error("No implementation is defined for ufunc.");
161154
}
@@ -179,16 +172,13 @@ std::pair<sycl::event, sycl::event> binary_ufunc(
179172
// check type_nums
180173
int src1_typenum = src1.get_typenum();
181174
int src2_typenum = src2.get_typenum();
182-
int dst_typenum = dst.get_typenum();
183175

184176
auto array_types = dpctl_td_ns::usm_ndarray_types();
185177
int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
186178
int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
187-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
188179

189-
if (src1_typeid != src2_typeid || src2_typeid != dst_typeid) {
190-
throw py::value_error(
191-
"Either any of input arrays or output array have different types.");
180+
if (src1_typeid != src2_typeid) {
181+
throw py::value_error("Input arrays have different types.");
192182
}
193183

194184
// check that queues are compatible
@@ -259,7 +249,7 @@ std::pair<sycl::event, sycl::event> binary_ufunc(
259249
throw py::value_error("Input and outpur arrays must be C-contiguous.");
260250
}
261251

262-
auto dispatch_fn = dispatch_vector[dst_typeid];
252+
auto dispatch_fn = dispatch_vector[src1_typeid];
263253
if (dispatch_fn == nullptr) {
264254
throw py::value_error("No implementation is defined for ufunc.");
265255
}
@@ -279,16 +269,8 @@ bool need_to_call_unary_ufunc(sycl::queue exec_q,
279269
{
280270
// check type_nums
281271
int src_typenum = src.get_typenum();
282-
int dst_typenum = dst.get_typenum();
283-
284272
auto array_types = dpctl_td_ns::usm_ndarray_types();
285273
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
286-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
287-
288-
// types must be the same
289-
if (src_typeid != dst_typeid) {
290-
return false;
291-
}
292274

293275
// OneMKL VM functions perform a copy on host if no double type support
294276
if (!exec_q.get_device().has(sycl::aspect::fp64)) {
@@ -356,7 +338,7 @@ bool need_to_call_unary_ufunc(sycl::queue exec_q,
356338
}
357339

358340
// MKL function is not defined for the type
359-
if (dispatch_vector[dst_typeid] == nullptr) {
341+
if (dispatch_vector[src_typeid] == nullptr) {
360342
return false;
361343
}
362344
return true;
@@ -372,15 +354,13 @@ bool need_to_call_binary_ufunc(sycl::queue exec_q,
372354
// check type_nums
373355
int src1_typenum = src1.get_typenum();
374356
int src2_typenum = src2.get_typenum();
375-
int dst_typenum = dst.get_typenum();
376357

377358
auto array_types = dpctl_td_ns::usm_ndarray_types();
378359
int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
379360
int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
380-
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
381361

382362
// types must be the same
383-
if (src1_typeid != src2_typeid || src2_typeid != dst_typeid) {
363+
if (src1_typeid != src2_typeid) {
384364
return false;
385365
}
386366

@@ -454,7 +434,7 @@ bool need_to_call_binary_ufunc(sycl::queue exec_q,
454434
}
455435

456436
// MKL function is not defined for the type
457-
if (dispatch_vector[dst_typeid] == nullptr) {
437+
if (dispatch_vector[src1_typeid] == nullptr) {
458438
return false;
459439
}
460440
return true;

dpnp/backend/extensions/vm/conj.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event conj_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::ConjOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::conj(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/cos.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event cos_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::CosOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::cos(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/cosh.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ sycl::event cosh_contig_impl(sycl::queue exec_q,
4848
type_utils::validate_type_for_device<T>(exec_q);
4949

5050
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
51+
using resTy = typename types::CoshOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
5253

5354
return mkl_vm::cosh(exec_q,
5455
n, // number of elements to be calculated

dpnp/backend/extensions/vm/div.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ sycl::event div_contig_impl(sycl::queue exec_q,
5050

5151
const T *a = reinterpret_cast<const T *>(in_a);
5252
const T *b = reinterpret_cast<const T *>(in_b);
53-
T *y = reinterpret_cast<T *>(out_y);
53+
using resTy = typename types::DivOutputType<T>::value_type;
54+
resTy *y = reinterpret_cast<resTy *>(out_y);
5455

5556
return mkl_vm::div(exec_q,
5657
n, // number of elements to be calculated

dpnp/backend/extensions/vm/exp.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event exp_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
using resTy = typename types::ExpOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
53+
54+
return mkl_vm::exp(exec_q,
55+
n, // number of elements to be calculated
56+
a, // pointer `a` containing input vector of size n
57+
y, // pointer `y` to the output vector of size n
58+
depends);
59+
}
60+
61+
template <typename fnT, typename T>
62+
struct ExpContigFactory
63+
{
64+
fnT get()
65+
{
66+
if constexpr (std::is_same_v<
67+
typename types::ExpOutputType<T>::value_type, void>)
68+
{
69+
return nullptr;
70+
}
71+
else {
72+
return exp_contig_impl<T>;
73+
}
74+
}
75+
};
76+
} // namespace vm
77+
} // namespace ext
78+
} // namespace backend
79+
} // namespace dpnp

0 commit comments

Comments
 (0)