diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 74e387f91a..f717f4b309 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -486,7 +486,7 @@ jobs: done array-api-conformity: - needs: test_linux + needs: build_linux runs-on: ${{ matrix.runner }} strategy: diff --git a/dpctl/_sycl_queue.pxd b/dpctl/_sycl_queue.pxd index 729adfc3cb..c906ada4d6 100644 --- a/dpctl/_sycl_queue.pxd +++ b/dpctl/_sycl_queue.pxd @@ -29,7 +29,7 @@ from ._sycl_event cimport SyclEvent from .program._program cimport SyclKernel -cdef void default_async_error_handler(int) nogil except * +cdef void default_async_error_handler(int) except * nogil cdef public api class _SyclQueue [ object Py_SyclQueueObject, type Py_SyclQueueType diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index e759571790..565a11dec7 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import builtins import operator import numpy as np @@ -289,6 +290,96 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src): _copy_same_shape(dst, src_same_shape) +def _empty_like_orderK(X, dt, usm_type=None, dev=None): + """Returns empty array like `x`, using order='K' + + For an array `x` that was obtained by permutation of a contiguous + array the returned array will have the same shape and the same + strides as `x`. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray, got {type(X)}") + if usm_type is None: + usm_type = X.usm_type + if dev is None: + dev = X.device + fl = X.flags + if fl["C"] or X.size <= 1: + return dpt.empty_like( + X, dtype=dt, usm_type=usm_type, device=dev, order="C" + ) + elif fl["F"]: + return dpt.empty_like( + X, dtype=dt, usm_type=usm_type, device=dev, order="F" + ) + st = list(X.strides) + perm = sorted( + range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True + ) + inv_perm = sorted(range(X.ndim), key=lambda i: perm[i]) + st_sorted = [st[i] for i in perm] + sh = X.shape + sh_sorted = tuple(sh[i] for i in perm) + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") + if min(st_sorted) < 0: + sl = tuple( + slice(None, None, -1) + if st_sorted[i] < 0 + else slice(None, None, None) + for i in range(X.ndim) + ) + R = R[sl] + return dpt.permute_dims(R, inv_perm) + + +def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev): + if not isinstance(X1, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray, got {type(X1)}") + if not isinstance(X2, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray, got {type(X2)}") + nd1 = X1.ndim + nd2 = X2.ndim + if nd1 > nd2 and X1.shape == res_shape: + return _empty_like_orderK(X1, dt, usm_type, dev) + elif nd1 < nd2 and X2.shape == res_shape: + return _empty_like_orderK(X2, dt, usm_type, dev) + fl1 = X1.flags + fl2 = X2.flags + if fl1["C"] or fl2["C"]: + return dpt.empty( + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C" + ) + if fl1["F"] and fl2["F"]: + return dpt.empty( + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F" + ) + st1 = list(X1.strides) + st2 = list(X2.strides) + max_ndim = max(nd1, nd2) + st1 += [0] * (max_ndim - len(st1)) + st2 += [0] * (max_ndim - len(st2)) + perm = sorted( + range(max_ndim), + key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])), + 
reverse=True, + ) + inv_perm = sorted(range(max_ndim), key=lambda i: perm[i]) + st1_sorted = [st1[i] for i in perm] + st2_sorted = [st2[i] for i in perm] + sh = res_shape + sh_sorted = tuple(sh[i] for i in perm) + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") + if max(min(st1_sorted), min(st2_sorted)) < 0: + sl = tuple( + slice(None, None, -1) + if (st1_sorted[i] < 0 and st2_sorted[i] < 0) + else slice(None, None, None) + for i in range(nd1) + ) + R = R[sl] + return dpt.permute_dims(R, inv_perm) + + def copy(usm_ary, order="K"): """copy(ary, order="K") @@ -334,28 +425,15 @@ def copy(usm_ary, order="K"): "Unrecognized value of the order keyword. " "Recognized values are 'A', 'C', 'F', or 'K'" ) - c_contig = usm_ary.flags.c_contiguous - f_contig = usm_ary.flags.f_contiguous - R = dpt.usm_ndarray( - usm_ary.shape, - dtype=usm_ary.dtype, - buffer=usm_ary.usm_type, - order=copy_order, - buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, - ) - if order == "K" and (not c_contig and not f_contig): - original_strides = usm_ary.strides - ind = sorted( - range(usm_ary.ndim), - key=lambda i: abs(original_strides[i]), - reverse=True, - ) - new_strides = tuple(R.strides[ind[i]] for i in ind) + if order == "K": + R = _empty_like_orderK(usm_ary, usm_ary.dtype) + else: R = dpt.usm_ndarray( usm_ary.shape, dtype=usm_ary.dtype, - buffer=R.usm_data, - strides=new_strides, + buffer=usm_ary.usm_type, + order=copy_order, + buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, ) _copy_same_shape(R, usm_ary) return R @@ -432,26 +510,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): "Unrecognized value of the order keyword. " "Recognized values are 'A', 'C', 'F', or 'K'" ) - R = dpt.usm_ndarray( - usm_ary.shape, - dtype=target_dtype, - buffer=usm_ary.usm_type, - order=copy_order, - buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, - ) - if order == "K" and (not c_contig and not f_contig): - original_strides = usm_ary.strides - ind = sorted( - range(usm_ary.ndim), - key=lambda i: abs(original_strides[i]), - reverse=True, - ) - new_strides = tuple(R.strides[ind[i]] for i in ind) + if order == "K": + R = _empty_like_orderK(usm_ary, target_dtype) + else: R = dpt.usm_ndarray( usm_ary.shape, dtype=target_dtype, - buffer=R.usm_data, - strides=new_strides, + buffer=usm_ary.usm_type, + order=copy_order, + buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, ) _copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary) return R @@ -492,6 +559,8 @@ def _extract_impl(ary, ary_mask, axis=0): dst = dpt.empty( dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device ) + if dst.size == 0: + return dst hev, _ = ti._extract( src=ary, cumsum=cumsum, diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 78c79fb2ad..f924ee31cd 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -26,10 +26,9 @@ from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer from dpctl.utils import ExecutionPlacementError +from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK from ._type_utils import ( _acceptance_fn_default, - _empty_like_orderK, - _empty_like_pair_orderK, _find_buf_dtype, _find_buf_dtype2, _find_inplace_dtype, diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 7b066417af..26c1ab60cf 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -25,6 +25,8 @@ import 
dpctl.tensor._tensor_impl as ti import dpctl.utils as dputils +from ._type_utils import _to_device_supported_dtype + __doc__ = ( "Implementation module for array manipulation " "functions in :module:`dpctl.tensor`" @@ -504,8 +506,10 @@ def _arrays_validation(arrays, check_ndim=True): _supported_dtype(Xi.dtype for Xi in arrays) res_dtype = X0.dtype + dev = exec_q.sycl_device for i in range(1, n): res_dtype = np.promote_types(res_dtype, arrays[i]) + res_dtype = _to_device_supported_dtype(res_dtype, dev) if check_ndim: for i in range(1, n): @@ -554,8 +558,13 @@ def _concat_axis_None(arrays): sycl_queue=exec_q, ) else: + src_ = array + # _copy_usm_ndarray_for_reshape requires src and dst to have + # the same data type + if not array.dtype == res_dtype: + src_ = dpt.astype(src_, res_dtype) hev, _ = ti._copy_usm_ndarray_for_reshape( - src=array, + src=src_, dst=res[fill_start:fill_end], shift=0, sycl_queue=exec_q, diff --git a/dpctl/tensor/_slicing.pxi b/dpctl/tensor/_slicing.pxi index 361dd906c3..4f45d57391 100644 --- a/dpctl/tensor/_slicing.pxi +++ b/dpctl/tensor/_slicing.pxi @@ -33,9 +33,13 @@ cdef Py_ssize_t _slice_len( if sl_start == sl_stop: return 0 if sl_step > 0: + if sl_start > sl_stop: + return 0 # 1 + argmax k such htat sl_start + sl_step*k < sl_stop return 1 + ((sl_stop - sl_start - 1) // sl_step) else: + if sl_start < sl_stop: + return 0 return 1 + ((sl_stop - sl_start + 1) // sl_step) @@ -221,6 +225,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): k_new = k + ellipses_count new_shape.extend(shape[k:k_new]) new_strides.extend(strides[k:k_new]) + if any(dim == 0 for dim in shape[k:k_new]): + is_empty = True + new_offset = offset k = k_new elif ind_i is None: new_shape.append(1) @@ -236,6 +243,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int): new_offset = new_offset + sl_start * strides[k] if sh_i == 0: is_empty = True + new_offset = offset k = k_new elif _is_boolean(ind_i): new_shape.append(1 if ind_i else 0) diff --git a/dpctl/tensor/_stride_utils.pxi b/dpctl/tensor/_stride_utils.pxi index 896c31e65a..ea59ec5402 100644 --- a/dpctl/tensor/_stride_utils.pxi +++ b/dpctl/tensor/_stride_utils.pxi @@ -64,6 +64,8 @@ cdef int _from_input_shape_strides( cdef int j cdef bint all_incr = 1 cdef bint all_decr = 1 + cdef bint all_incr_modified = 0 + cdef bint all_decr_modified = 0 cdef Py_ssize_t elem_count = 1 cdef Py_ssize_t min_shift = 0 cdef Py_ssize_t max_shift = 0 @@ -166,12 +168,14 @@ cdef int _from_input_shape_strides( j = j + 1 if j < nd: if all_incr: + all_incr_modified = 1 all_incr = ( (strides_arr[i] > 0) and (strides_arr[j] > 0) and (strides_arr[i] <= strides_arr[j]) ) if all_decr: + all_decr_modified = 1 all_decr = ( (strides_arr[i] > 0) and (strides_arr[j] > 0) and @@ -180,6 +184,10 @@ cdef int _from_input_shape_strides( i = j else: break + # should only set contig flags on actually obtained + # values, rather than default values + all_incr = all_incr and all_incr_modified + all_decr = all_decr and all_decr_modified if all_incr and all_decr: contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS) elif all_incr: diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index fb2223f292..b576764689 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import builtins - import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as ti @@ -116,96 +114,6 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): return can_cast_v -def _empty_like_orderK(X, dt, usm_type=None, dev=None): - """Returns empty array like `x`, using order='K' - - For an array `x` that was obtained by permutation of a contiguous - array the returned array will have the same shape and the same - strides as `x`. - """ - if not isinstance(X, dpt.usm_ndarray): - raise TypeError(f"Expected usm_ndarray, got {type(X)}") - if usm_type is None: - usm_type = X.usm_type - if dev is None: - dev = X.device - fl = X.flags - if fl["C"] or X.size <= 1: - return dpt.empty_like( - X, dtype=dt, usm_type=usm_type, device=dev, order="C" - ) - elif fl["F"]: - return dpt.empty_like( - X, dtype=dt, usm_type=usm_type, device=dev, order="F" - ) - st = list(X.strides) - perm = sorted( - range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True - ) - inv_perm = sorted(range(X.ndim), key=lambda i: perm[i]) - st_sorted = [st[i] for i in perm] - sh = X.shape - sh_sorted = tuple(sh[i] for i in perm) - R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") - if min(st_sorted) < 0: - sl = tuple( - slice(None, None, -1) - if st_sorted[i] < 0 - else slice(None, None, None) - for i in range(X.ndim) - ) - R = R[sl] - return dpt.permute_dims(R, inv_perm) - - -def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev): - if not isinstance(X1, dpt.usm_ndarray): - raise TypeError(f"Expected usm_ndarray, got {type(X1)}") - if not isinstance(X2, dpt.usm_ndarray): - raise TypeError(f"Expected usm_ndarray, got {type(X2)}") - nd1 = X1.ndim - nd2 = X2.ndim - if nd1 > nd2 and X1.shape == res_shape: - return _empty_like_orderK(X1, dt, usm_type, dev) - elif nd1 < nd2 and X2.shape == res_shape: - return _empty_like_orderK(X2, dt, usm_type, dev) - fl1 = X1.flags - fl2 = X2.flags - if fl1["C"] or fl2["C"]: - return dpt.empty( - res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C" - ) - if fl1["F"] and fl2["F"]: - return dpt.empty( - res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F" - ) - st1 = list(X1.strides) - st2 = list(X2.strides) - max_ndim = max(nd1, nd2) - st1 += [0] * (max_ndim - len(st1)) - st2 += [0] * (max_ndim - len(st2)) - perm = sorted( - range(max_ndim), - key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])), - reverse=True, - ) - inv_perm = sorted(range(max_ndim), key=lambda i: perm[i]) - st1_sorted = [st1[i] for i in perm] - st2_sorted = [st2[i] for i in perm] - sh = res_shape - sh_sorted = tuple(sh[i] for i in perm) - R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") - if max(min(st1_sorted), min(st2_sorted)) < 0: - sl = tuple( - slice(None, None, -1) - if (st1_sorted[i] < 0 and st2_sorted[i] < 0) - else slice(None, None, None) - for i in range(nd1) - ) - R = R[sl] - return dpt.permute_dims(R, inv_perm) - - def _to_device_supported_dtype(dt, dev): has_fp16 = dev.has_aspect_fp16 has_fp64 = dev.has_aspect_fp64 @@ -339,8 +247,6 @@ def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev): "_find_buf_dtype", "_find_buf_dtype2", "_find_inplace_dtype", - "_empty_like_orderK", - "_empty_like_pair_orderK", "_to_device_supported_dtype", "_acceptance_fn_default", "_acceptance_fn_divide", diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 5b1bd5f6a3..20a226bd3e 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -764,6 +764,8 @@ cdef class 
usm_ndarray: ind, (self).shape, ( self).strides, self.get_offset()) cdef usm_ndarray res + cdef int i = 0 + cdef bint matching = 1 if len(_meta) < 5: raise RuntimeError @@ -787,7 +789,20 @@ cdef class usm_ndarray: from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool: - return _extract_impl(res, adv_ind[0], axis=adv_ind_start_p) + key_ = adv_ind[0] + adv_ind_end_p = key_.ndim + adv_ind_start_p + if adv_ind_end_p > res.ndim: + raise IndexError("too many indices for the array") + key_shape = key_.shape + arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p] + for i in range(key_.ndim): + if matching: + if not key_shape[i] == arr_shape[i] and key_shape[i] > 0: + matching = 0 + if not matching: + raise IndexError("boolean index did not match indexed array in dimensions") + res = _extract_impl(res, key_, axis=adv_ind_start_p) + return res if any(ind.dtype == dpt_bool for ind in adv_ind): adv_ind_int = list() @@ -857,7 +872,7 @@ cdef class usm_ndarray: strides=self.strides, offset=self.get_offset() ) - res.flags_ = self.flags.flags + res.flags_ = self.flags_ return res else: nbytes = self.usm_data.nbytes @@ -872,7 +887,7 @@ cdef class usm_ndarray: strides=self.strides, offset=self.get_offset() ) - res.flags_ = self.flags.flags + res.flags_ = self.flags_ return res def _set_namespace(self, mod): @@ -884,12 +899,14 @@ cdef class usm_ndarray: Returns array namespace, member functions of which implement data API. """ - return self.array_namespace_ + return self.array_namespace_ if self.array_namespace_ is not None else dpctl.tensor def __bool__(self): if self.size == 1: mem_view = dpmem.as_usm_memory(self) - return mem_view.copy_to_host().view(self.dtype).__bool__() + host_buf = mem_view.copy_to_host() + view = host_buf.view(self.dtype) + return view.__bool__() if self.size == 0: raise ValueError( @@ -898,13 +915,15 @@ cdef class usm_ndarray: raise ValueError( "The truth value of an array with more than one element is " - "ambiguous. Use a.any() or a.all()" + "ambiguous. 
Use dpctl.tensor.any() or dpctl.tensor.all()" ) def __float__(self): if self.size == 1: mem_view = dpmem.as_usm_memory(self) - return mem_view.copy_to_host().view(self.dtype).__float__() + host_buf = mem_view.copy_to_host() + view = host_buf.view(self.dtype) + return view.__float__() raise ValueError( "only size-1 arrays can be converted to Python scalars" @@ -913,7 +932,9 @@ cdef class usm_ndarray: def __complex__(self): if self.size == 1: mem_view = dpmem.as_usm_memory(self) - return mem_view.copy_to_host().view(self.dtype).__complex__() + host_buf = mem_view.copy_to_host() + view = host_buf.view(self.dtype) + return view.__complex__() raise ValueError( "only size-1 arrays can be converted to Python scalars" @@ -922,7 +943,9 @@ cdef class usm_ndarray: def __int__(self): if self.size == 1: mem_view = dpmem.as_usm_memory(self) - return mem_view.copy_to_host().view(self.dtype).__int__() + host_buf = mem_view.copy_to_host() + view = host_buf.view(self.dtype) + return view.__int__() raise ValueError( "only size-1 arrays can be converted to Python scalars" @@ -957,9 +980,9 @@ cdef class usm_ndarray: def __and__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "logical_and", other) + return _dispatch_binary_elementwise(first, "bitwise_and", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "logical_and", other) + return _dispatch_binary_elementwise2(first, "bitwise_and", other) return NotImplemented def __dlpack__(self, stream=None): @@ -1023,7 +1046,7 @@ cdef class usm_ndarray: return _dispatch_binary_elementwise(self, "greater", other) def __invert__(self): - return _dispatch_unary_elementwise(self, "invert") + return _dispatch_unary_elementwise(self, "bitwise_invert") def __le__(self, other): return _dispatch_binary_elementwise(self, "less_equal", other) @@ -1037,9 +1060,9 @@ cdef class usm_ndarray: def __lshift__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "left_shift", other) + return _dispatch_binary_elementwise(first, "bitwise_left_shift", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "left_shift", other) + return _dispatch_binary_elementwise2(first, "bitwise_left_shift", other) return NotImplemented def __lt__(self, other): @@ -1056,9 +1079,9 @@ cdef class usm_ndarray: def __mod__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "mod", other) + return _dispatch_binary_elementwise(first, "remainder", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "mod", other) + return _dispatch_binary_elementwise2(first, "remainder", other) return NotImplemented def __mul__(first, other): @@ -1078,9 +1101,9 @@ cdef class usm_ndarray: def __or__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "logical_or", other) + return _dispatch_binary_elementwise(first, "bitwise_or", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "logical_or", other) + return _dispatch_binary_elementwise2(first, "bitwise_or", other) return NotImplemented def __pos__(self): @@ -1090,17 +1113,17 @@ cdef class usm_ndarray: "See comment in __add__" if mod is None: if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "power", other) + return 
_dispatch_binary_elementwise(first, "pow", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise(first, "power", other) + return _dispatch_binary_elementwise(first, "pow", other) return NotImplemented def __rshift__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "right_shift", other) + return _dispatch_binary_elementwise(first, "bitwise_right_shift", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "right_shift", other) + return _dispatch_binary_elementwise2(first, "bitwise_right_shift", other) return NotImplemented def __setitem__(self, key, rhs): @@ -1144,6 +1167,8 @@ cdef class usm_ndarray: if adv_ind_start_p < 0: # basic slicing if isinstance(rhs, usm_ndarray): + if Xv.size == 0: + return _copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs) else: if hasattr(rhs, "__sycl_usm_array_interface__"): @@ -1202,57 +1227,57 @@ cdef class usm_ndarray: def __truediv__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "true_divide", other) + return _dispatch_binary_elementwise(first, "divide", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "true_divide", other) + return _dispatch_binary_elementwise2(first, "divide", other) return NotImplemented def __xor__(first, other): "See comment in __add__" if isinstance(first, usm_ndarray): - return _dispatch_binary_elementwise(first, "logical_xor", other) + return _dispatch_binary_elementwise(first, "bitwise_xor", other) elif isinstance(other, usm_ndarray): - return _dispatch_binary_elementwise2(first, "logical_xor", other) + return _dispatch_binary_elementwise2(first, "bitwise_xor", other) return NotImplemented def __radd__(self, other): return _dispatch_binary_elementwise(self, "add", other) def __rand__(self, other): - return _dispatch_binary_elementwise(self, "logical_and", other) + return _dispatch_binary_elementwise(self, "bitwise_and", other) def __rfloordiv__(self, other): return _dispatch_binary_elementwise2(other, "floor_divide", self) def __rlshift__(self, other): - return _dispatch_binary_elementwise2(other, "left_shift", self) + return _dispatch_binary_elementwise2(other, "bitwise_left_shift", self) def __rmatmul__(self, other): return _dispatch_binary_elementwise2(other, "matmul", self) def __rmod__(self, other): - return _dispatch_binary_elementwise2(other, "mod", self) + return _dispatch_binary_elementwise2(other, "remainder", self) def __rmul__(self, other): return _dispatch_binary_elementwise(self, "multiply", other) def __ror__(self, other): - return _dispatch_binary_elementwise(self, "logical_or", other) + return _dispatch_binary_elementwise(self, "bitwise_or", other) def __rpow__(self, other): - return _dispatch_binary_elementwise2(other, "power", self) + return _dispatch_binary_elementwise2(other, "pow", self) def __rrshift__(self, other): - return _dispatch_binary_elementwise2(other, "right_shift", self) + return _dispatch_binary_elementwise2(other, "bitwise_right_shift", self) def __rsub__(self, other): return _dispatch_binary_elementwise2(other, "subtract", self) def __rtruediv__(self, other): - return _dispatch_binary_elementwise2(other, "true_divide", self) + return _dispatch_binary_elementwise2(other, "divide", self) def __rxor__(self, other): - return _dispatch_binary_elementwise2(other, "logical_xor", self) + return _dispatch_binary_elementwise2(other, "bitwise_xor", self) def 
__iadd__(self, other): from ._elementwise_funcs import add diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp index 39ba5c3fcf..b718a5f991 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "utils/offset_utils.hpp" @@ -55,16 +56,12 @@ using dpctl::tensor::type_utils::vec_cast; template struct LogAddExpFunctor { - using supports_sg_loadstore = typename std::negation< - std::disjunction, tu_ns::is_complex>>; - using supports_vec = typename std::negation< - std::disjunction, tu_ns::is_complex>>; + using supports_sg_loadstore = std::true_type; + using supports_vec = std::true_type; resT operator()(const argT1 &in1, const argT2 &in2) { - resT max = std::max(in1, in2); - resT min = std::min(in1, in2); - return max + std::log1p(std::exp(min - max)); + return impl(in1, in2); } template @@ -72,16 +69,48 @@ template struct LogAddExpFunctor const sycl::vec &in2) { sycl::vec res; - auto diff = in1 - in2; + auto diff = in1 - in2; // take advantange of faster vec arithmetic #pragma unroll for (int i = 0; i < vec_sz; ++i) { - resT max = std::max(in1[i], in2[i]); - res[i] = max + std::log1p(std::exp(std::abs(diff[i]))); + if (std::isfinite(diff[i])) { + res[i] = std::max(in1[i], in2[i]) + + impl_finite(-std::abs(diff[i])); + } + else { + res[i] = impl(in1[i], in2[i]); + } } return res; } + +private: + template T impl(T const &in1, T const &in2) + { + if (in1 == in2) { // handle signed infinities + const T log2 = std::log(T(2)); + return in1 + log2; + } + else { + const T tmp = in1 - in2; + if (tmp > 0) { + return in1 + std::log1p(std::exp(-tmp)); + } + else if (tmp <= 0) { + return in2 + std::log1p(std::exp(tmp)); + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + } + + template T impl_finite(T const &in) + { + return (in > 0) ? (in + std::log1p(std::exp(-in))) + : std::log1p(std::exp(in)); + } }; template
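
Note on the order='K' allocation helper moved into dpctl/tensor/_copy_utils.py: _empty_like_orderK sorts the axes by decreasing stride magnitude, allocates a C-contiguous buffer for the permuted shape, and then permutes the axes back, so the result inherits the memory layout of the input. The snippet below is a minimal NumPy-based sketch of that scheme for illustration only; the name empty_like_orderK_sketch is hypothetical, numpy stands in for dpctl.tensor, and the negative-stride flip and usm_type/device handling of the real helper are omitted.

import numpy as np

def empty_like_orderK_sketch(x, dtype):
    # contiguous inputs keep their layout directly
    if x.flags["C_CONTIGUOUS"]:
        return np.empty_like(x, dtype=dtype, order="C")
    if x.flags["F_CONTIGUOUS"]:
        return np.empty_like(x, dtype=dtype, order="F")
    st = x.strides
    # axes ordered from largest to smallest stride magnitude
    perm = sorted(range(x.ndim), key=lambda d: abs(st[d]), reverse=True)
    inv_perm = sorted(range(x.ndim), key=lambda i: perm[i])
    # allocate C-contiguously in the sorted axis order ...
    r = np.empty(tuple(x.shape[i] for i in perm), dtype=dtype, order="C")
    # ... and undo the permutation to recover x's axis order
    return np.transpose(r, inv_perm)

x = np.transpose(np.zeros((2, 3, 4)), (2, 0, 1))   # non-contiguous permuted view
r = empty_like_orderK_sketch(x, np.float64)
# for a permutation of a contiguous array the strides are reproduced exactly
assert r.strides == x.strides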
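
Note on the dpctl/tensor/_slicing.pxi change: _slice_len previously produced a negative count for already-empty slices (positive step with start greater than stop, or negative step with start less than stop); the patch adds early returns of 0 before the closed-form count. A minimal pure-Python sketch of the corrected formula, checked against range() semantics; slice_len_sketch is a hypothetical name used only for illustration, the real routine operates on already-normalized slice bounds.

def slice_len_sketch(start, stop, step):
    if start == stop:
        return 0
    if step > 0:
        if start > stop:
            return 0          # empty slice; old code returned a negative value here
        return 1 + (stop - start - 1) // step
    else:
        if start < stop:
            return 0          # empty slice for a negative step
        return 1 + (stop - start + 1) // step

# agrees with Python's own range() length, including the empty cases
for start, stop, step in [(5, 2, 1), (2, 5, -1), (0, 10, 3), (9, -1, -2)]:
    assert slice_len_sketch(start, stop, step) == len(range(start, stop, step))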
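
Note on the logaddexp.hpp change: the new LogAddExpFunctor evaluates log(exp(a) + exp(b)) anchored at the larger argument so that exp() never overflows, and short-circuits a == b so that equal infinities do not run into inf - inf = nan; the vectorized path reuses the finite branch whenever a - b is finite. A minimal Python sketch of the same scheme, for illustration only; logaddexp_sketch is a hypothetical name and math stands in for the sycl/std calls used in the C++ functor.

import math

def logaddexp_sketch(a, b):
    if a == b:                        # also covers +inf/+inf and -inf/-inf
        return a + math.log(2.0)
    diff = a - b
    if diff > 0:
        return a + math.log1p(math.exp(-diff))   # exp argument is <= 0, no overflow
    elif diff <= 0:
        return b + math.log1p(math.exp(diff))
    else:                             # diff is nan
        return math.nan

# naive log(exp(1000) + exp(1000)) would overflow; the stable form does not
assert math.isclose(logaddexp_sketch(1000.0, 1000.0), 1000.0 + math.log(2.0))
assert logaddexp_sketch(float("-inf"), float("-inf")) == float("-inf")
assert math.isclose(logaddexp_sketch(0.0, 0.0), math.log(2.0))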