Skip to content

Commit 5bfeb2c

Browse files
committed
use_dpctl_ceil_floor_trunc_in_dpnp
1 parent 1650774 commit 5bfeb2c

15 files changed

+773
-247
lines changed

dpnp/backend/extensions/vm/ceil.hpp

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event ceil_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::ceil(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct CeilContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::CeilOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return ceil_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/floor.hpp

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event floor_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::floor(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct FloorContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::FloorOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return floor_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/trunc.hpp

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event trunc_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::trunc(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct TruncContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::TruncOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return trunc_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

+46
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ struct DivOutputType
6868
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
6969
};
7070

71+
/**
72+
* @brief A factory to define pairs of supported types for which
73+
* MKL VM library provides support in oneapi::mkl::vm::ceil<T> function.
74+
*
75+
* @tparam T Type of input vector `a` and of result vector `y`.
76+
*/
77+
template <typename T>
78+
struct CeilOutputType
79+
{
80+
using value_type = typename std::disjunction<
81+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
82+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
83+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
84+
};
85+
7186
/**
7287
* @brief A factory to define pairs of supported types for which
7388
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.
@@ -87,6 +102,21 @@ struct CosOutputType
87102
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
88103
};
89104

105+
/**
106+
* @brief A factory to define pairs of supported types for which
107+
* MKL VM library provides support in oneapi::mkl::vm::floor<T> function.
108+
*
109+
* @tparam T Type of input vector `a` and of result vector `y`.
110+
*/
111+
template <typename T>
112+
struct FloorOutputType
113+
{
114+
using value_type = typename std::disjunction<
115+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
116+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
117+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
118+
};
119+
90120
/**
91121
* @brief A factory to define pairs of supported types for which
92122
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.
@@ -158,6 +188,22 @@ struct SqrtOutputType
158188
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
159189
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
160190
};
191+
192+
/**
193+
* @brief A factory to define pairs of supported types for which
194+
* MKL VM library provides support in oneapi::mkl::vm::trunc<T> function.
195+
*
196+
* @tparam T Type of input vector `a` and of result vector `y`.
197+
*/
198+
template <typename T>
199+
struct TruncOutputType
200+
{
201+
using value_type = typename std::disjunction<
202+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
203+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
204+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
205+
};
206+
161207
} // namespace types
162208
} // namespace vm
163209
} // namespace ext

0 commit comments

Comments
 (0)