Skip to content

Commit 9018745

Browse files
Merge pull request #1463 from IntelPython/optimize-small-size-tree-reduction
Improve performance of reduction for small number of elements to reduce for types where tree-reduction is needed
2 parents 11ecba8 + d4d4992 commit 9018745

12 files changed

+921
-188
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,13 @@ set(_reduction_sources
113113
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp
114114
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
115115
)
116+
set(_boolean_reduction_sources
117+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
118+
)
116119
set(_tensor_impl_sources
117-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
118-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
120+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_ctors.cpp
119121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
122+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
120123
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
121124
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
122125
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
@@ -128,19 +131,39 @@ set(_tensor_impl_sources
128131
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
129132
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp
130133
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
131-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
132134
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
133135
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
134136
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
135137
)
136-
list(APPEND _tensor_impl_sources
137-
${_elementwise_sources}
138-
${_reduction_sources}
138+
set(_tensor_elementwise_impl_sources
139+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_elementwise.cpp
140+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
141+
${_elementwise_sources}
142+
)
143+
set(_tensor_reductions_impl_sources
144+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_reductions.cpp
145+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
146+
${_boolean_reduction_sources}
147+
${_reduction_sources}
139148
)
140149

150+
set(_py_trgts)
151+
141152
set(python_module_name _tensor_impl)
142153
pybind11_add_module(${python_module_name} MODULE ${_tensor_impl_sources})
143154
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_impl_sources})
155+
list(APPEND _py_trgts ${python_module_name})
156+
157+
set(python_module_name _tensor_elementwise_impl)
158+
pybind11_add_module(${python_module_name} MODULE ${_tensor_elementwise_impl_sources})
159+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_elementwise_impl_sources})
160+
list(APPEND _py_trgts ${python_module_name})
161+
162+
set(python_module_name _tensor_reductions_impl)
163+
pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sources})
164+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources})
165+
list(APPEND _py_trgts ${python_module_name})
166+
144167
set(_clang_prefix "")
145168
if (WIN32)
146169
set(_clang_prefix "/clang:")
@@ -170,19 +193,22 @@ if (UNIX)
170193
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
171194
PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES")
172195
endif()
173-
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
174-
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
175-
if(UNIX)
176-
# this option is supported on Linux only
177-
target_link_options(${python_module_name} PRIVATE -fsycl-link-huge-device-code)
178-
endif()
179-
target_include_directories(${python_module_name}
180-
PRIVATE
181-
${CMAKE_CURRENT_SOURCE_DIR}/../include
182-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/include
183-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/
184-
)
196+
185197
set(_linker_options "LINKER:${DPCTL_LDFLAGS}")
186-
target_link_options(${python_module_name} PRIVATE ${_linker_options})
187-
add_dependencies(${python_module_name} _dpctl4pybind11_deps)
188-
install(TARGETS ${python_module_name} DESTINATION "dpctl/tensor")
198+
foreach(python_module_name ${_py_trgts})
199+
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
200+
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
201+
if(UNIX)
202+
# this option is supported on Linux only
203+
target_link_options(${python_module_name} PRIVATE -fsycl-link-huge-device-code)
204+
endif()
205+
target_include_directories(${python_module_name}
206+
PRIVATE
207+
${CMAKE_CURRENT_SOURCE_DIR}/../include
208+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/include
209+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/
210+
)
211+
target_link_options(${python_module_name} PRIVATE ${_linker_options})
212+
add_dependencies(${python_module_name} _dpctl4pybind11_deps)
213+
install(TARGETS ${python_module_name} DESTINATION "dpctl/tensor")
214+
endforeach()

dpctl/tensor/_clip.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import dpctl
1818
import dpctl.tensor as dpt
19+
import dpctl.tensor._tensor_elementwise_impl as tei
1920
import dpctl.tensor._tensor_impl as ti
2021
from dpctl.tensor._copy_utils import (
2122
_empty_like_orderK,
@@ -429,9 +430,9 @@ def clip(x, min=None, max=None, out=None, order="K"):
429430
"only one of `min` and `max` is permitted to be `None`"
430431
)
431432
elif max is None:
432-
return _clip_none(x, min, out, order, ti._maximum)
433+
return _clip_none(x, min, out, order, tei._maximum)
433434
elif min is None:
434-
return _clip_none(x, max, out, order, ti._minimum)
435+
return _clip_none(x, max, out, order, tei._minimum)
435436
else:
436437
q1, x_usm_type = x.sycl_queue, x.usm_type
437438
q2, min_usm_type = _get_queue_usm_type(min)

dpctl/tensor/_elementwise_common.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,31 @@
3939
class UnaryElementwiseFunc:
4040
"""
4141
Class that implements unary element-wise functions.
42+
43+
Args:
44+
name (str):
45+
Name of the unary function
46+
result_type_resolver_fn (callable):
47+
Function that takes dtype of the input and
48+
returns the dtype of the result if the
49+
implementation function supports it, or
50+
returns `None` otherwise.
51+
unary_dp_impl_fn (callable):
52+
Data-parallel implementation function with signature
53+
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
54+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
55+
where the `src` is the argument array, `dst` is the
56+
array to be populated with function values, effectively
57+
evaluating `dst = func(src)`.
58+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
59+
The first event corresponds to data-management host tasks,
60+
including lifetime management of argument Python objects to ensure
61+
that their associated USM allocation is not freed before offloaded
62+
computational tasks complete execution, while the second event
63+
corresponds to computational tasks associated with function
64+
evaluation.
65+
docs (str):
66+
Documentation string for the unary function.
4267
"""
4368

4469
def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
@@ -55,8 +80,31 @@ def __str__(self):
5580
def __repr__(self):
5681
return f"<{self.__name__} '{self.name_}'>"
5782

83+
def get_implementation_function(self):
84+
"""Returns the implementation function for
85+
this elementwise unary function.
86+
87+
"""
88+
return self.unary_fn_
89+
90+
def get_type_result_resolver_function(self):
91+
"""Returns the type resolver function for this
92+
elementwise unary function.
93+
"""
94+
return self.result_type_resolver_fn_
95+
5896
@property
5997
def types(self):
98+
"""Returns information about types supported by
99+
implementation function, using NumPy's character
100+
encoding for data types, e.g.
101+
102+
:Example:
103+
.. code-block:: python
104+
105+
dpctl.tensor.sin.types
106+
# Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
107+
"""
60108
types = self.types_
61109
if not types:
62110
types = []
@@ -363,6 +411,56 @@ def _get_shape(o):
363411
class BinaryElementwiseFunc:
364412
"""
365413
Class that implements binary element-wise functions.
414+
415+
Args:
416+
name (str):
417+
Name of the binary function
418+
result_type_resolver_fn (callable):
419+
Function that takes dtypes of the input and
420+
returns the dtype of the result if the
421+
implementation function supports it, or
422+
returns `None` otherwise.
423+
binary_dp_impl_fn (callable):
424+
Data-parallel implementation function with signature
425+
`impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
426+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
427+
where the `src1` and `src2` are the argument arrays, `dst` is the
428+
array to be populated with function values,
429+
i.e. `dst=func(src1, src2)`.
430+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
431+
The first event corresponds to data-management host tasks,
432+
including lifetime management of argument Python objects to ensure
433+
that their associated USM allocation is not freed before offloaded
434+
computational tasks complete execution, while the second event
435+
corresponds to computational tasks associated with function
436+
evaluation.
437+
docs (str):
438+
Documentation string for the binary function.
439+
binary_inplace_fn (callable, optional):
440+
Data-parallel implementation function with signature
441+
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
442+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
443+
where the `src` is the argument array, `dst` is the
444+
array to be populated with function values,
445+
i.e. `dst=func(dst, src)`.
446+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
447+
The first event corresponds to data-management host tasks,
448+
including async lifetime management of Python arguments,
449+
while the second event corresponds to computational tasks
450+
associated with function evaluation.
451+
acceptance_fn (callable, optional):
452+
Function to influence type promotion behavior of this binary
453+
function. The function takes 6 arguments:
454+
arg1_dtype - Data type of the first argument
455+
arg2_dtype - Data type of the second argument
456+
ret_buf1_dtype - Data type the first argument would be cast to
457+
ret_buf2_dtype - Data type the second argument would be cast to
458+
res_dtype - Data type of the output array with function values
459+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
460+
evaluation is carried out.
461+
The function is only called when both arguments of the binary
462+
function require casting, e.g. both arguments of
463+
`dpctl.tensor.logaddexp` are arrays with integral data type.
366464
"""
367465

368466
def __init__(
@@ -392,8 +490,60 @@ def __str__(self):
392490
def __repr__(self):
393491
return f"<{self.__name__} '{self.name_}'>"
394492

493+
def get_implementation_function(self):
494+
"""Returns the out-of-place implementation
495+
function for this elementwise binary function.
496+
497+
"""
498+
return self.binary_fn_
499+
500+
def get_implementation_inplace_function(self):
501+
"""Returns the in-place implementation
502+
function for this elementwise binary function.
503+
504+
"""
505+
return self.binary_inplace_fn_
506+
507+
def get_type_result_resolver_function(self):
508+
"""Returns the type resolver function for this
509+
elementwise binary function.
510+
"""
511+
return self.result_type_resolver_fn_
512+
513+
def get_type_promotion_path_acceptance_function(self):
514+
"""Returns the acceptance function for this
515+
elementwise binary function.
516+
517+
Acceptance function influences the type promotion
518+
behavior of this binary function.
519+
The function takes 6 arguments:
520+
arg1_dtype - Data type of the first argument
521+
arg2_dtype - Data type of the second argument
522+
ret_buf1_dtype - Data type the first argument would be cast to
523+
ret_buf2_dtype - Data type the second argument would be cast to
524+
res_dtype - Data type of the output array with function values
525+
sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
526+
is carried out.
527+
528+
The acceptance function is only invoked if both input arrays must be
529+
cast to intermediary data types, as would happen during call of
530+
`dpctl.tensor.hypot` with both arrays being of integral data type.
531+
"""
532+
return self.acceptance_fn_
533+
395534
@property
396535
def types(self):
536+
"""Returns information about types supported by
537+
implementation function, using NumPy's character
538+
encoding for data types, e.g.
539+
540+
:Example:
541+
.. code-block:: python
542+
543+
dpctl.tensor.divide.types
544+
# Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D',
545+
# 'Ff->F', 'FF->F', 'Dd->D', 'DD->D']
546+
"""
397547
types = self.types_
398548
if not types:
399549
types = []

dpctl/tensor/_elementwise_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import dpctl.tensor._tensor_impl as ti
17+
import dpctl.tensor._tensor_elementwise_impl as ti
1818

1919
from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
2020
from ._type_utils import _acceptance_fn_divide

dpctl/tensor/_reduction.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl
2020
import dpctl.tensor as dpt
2121
import dpctl.tensor._tensor_impl as ti
22+
import dpctl.tensor._tensor_reductions_impl as tri
2223

2324
from ._type_utils import _to_device_supported_dtype
2425

@@ -220,8 +221,8 @@ def sum(x, axis=None, dtype=None, keepdims=False):
220221
axis,
221222
dtype,
222223
keepdims,
223-
ti._sum_over_axis,
224-
ti._sum_over_axis_dtype_supported,
224+
tri._sum_over_axis,
225+
tri._sum_over_axis_dtype_supported,
225226
_default_reduction_dtype,
226227
_identity=0,
227228
)
@@ -281,8 +282,8 @@ def prod(x, axis=None, dtype=None, keepdims=False):
281282
axis,
282283
dtype,
283284
keepdims,
284-
ti._prod_over_axis,
285-
ti._prod_over_axis_dtype_supported,
285+
tri._prod_over_axis,
286+
tri._prod_over_axis_dtype_supported,
286287
_default_reduction_dtype,
287288
_identity=1,
288289
)
@@ -335,8 +336,8 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
335336
axis,
336337
dtype,
337338
keepdims,
338-
ti._logsumexp_over_axis,
339-
lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported(
339+
tri._logsumexp_over_axis,
340+
lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported(
340341
inp_dt, res_dt
341342
),
342343
_default_reduction_dtype_fp_types,
@@ -391,8 +392,8 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
391392
axis,
392393
dtype,
393394
keepdims,
394-
ti._hypot_over_axis,
395-
lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported(
395+
tri._hypot_over_axis,
396+
lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported(
396397
inp_dt, res_dt
397398
),
398399
_default_reduction_dtype_fp_types,
@@ -468,7 +469,7 @@ def max(x, axis=None, keepdims=False):
468469
entire array, a zero-dimensional array is returned. The returned
469470
array has the same data type as `x`.
470471
"""
471-
return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis)
472+
return _comparison_over_axis(x, axis, keepdims, tri._max_over_axis)
472473

473474

474475
def min(x, axis=None, keepdims=False):
@@ -496,7 +497,7 @@ def min(x, axis=None, keepdims=False):
496497
entire array, a zero-dimensional array is returned. The returned
497498
array has the same data type as `x`.
498499
"""
499-
return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis)
500+
return _comparison_over_axis(x, axis, keepdims, tri._min_over_axis)
500501

501502

502503
def _search_over_axis(x, axis, keepdims, _reduction_fn):
@@ -577,7 +578,7 @@ def argmax(x, axis=None, keepdims=False):
577578
zero-dimensional array is returned. The returned array has the
578579
default array index data type for the device of `x`.
579580
"""
580-
return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis)
581+
return _search_over_axis(x, axis, keepdims, tri._argmax_over_axis)
581582

582583

583584
def argmin(x, axis=None, keepdims=False):
@@ -609,4 +610,4 @@ def argmin(x, axis=None, keepdims=False):
609610
zero-dimensional array is returned. The returned array has the
610611
default array index data type for the device of `x`.
611612
"""
612-
return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis)
613+
return _search_over_axis(x, axis, keepdims, tri._argmin_over_axis)

0 commit comments

Comments
 (0)