@@ -394,36 +394,6 @@ def dpnp_ceil(x, out=None, order="K"):
394394"""
395395
396396
397- def dpnp_cos (x , out = None , order = "K" ):
398- """
399- Invokes cos() function from pybind11 extension of OneMKL VM if possible.
400-
401- Otherwise fully relies on dpctl.tensor implementation for cos() function.
402-
403- """
404-
405- def _call_cos (src , dst , sycl_queue , depends = None ):
406- """A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
407-
408- if depends is None :
409- depends = []
410-
411- if vmi ._mkl_cos_to_call (sycl_queue , src , dst ):
412- # call pybind11 extension for cos() function from OneMKL VM
413- return vmi ._cos (sycl_queue , src , dst , depends )
414- return ti ._cos (src , dst , sycl_queue , depends )
415-
416- # dpctl.tensor only works with usm_ndarray
417- x1_usm = dpnp .get_usm_ndarray (x )
418- out_usm = None if out is None else dpnp .get_usm_ndarray (out )
419-
420- func = UnaryElementwiseFunc (
421- "cos" , ti ._cos_result_type , _call_cos , _cos_docstring
422- )
423- res_usm = func (x1_usm , out = out_usm , order = order )
424- return dpnp_array ._create_from_usm_ndarray (res_usm )
425-
426-
427397_conj_docstring = """
428398conj(x, out=None, order='K')
429399
@@ -462,6 +432,36 @@ def _call_conj(src, dst, sycl_queue, depends=None):
462432)
463433
464434
435+ def dpnp_cos (x , out = None , order = "K" ):
436+ """
437+ Invokes cos() function from pybind11 extension of OneMKL VM if possible.
438+
439+ Otherwise fully relies on dpctl.tensor implementation for cos() function.
440+
441+ """
442+
443+ def _call_cos (src , dst , sycl_queue , depends = None ):
444+ """A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
445+
446+ if depends is None :
447+ depends = []
448+
449+ if vmi ._mkl_cos_to_call (sycl_queue , src , dst ):
450+ # call pybind11 extension for cos() function from OneMKL VM
451+ return vmi ._cos (sycl_queue , src , dst , depends )
452+ return ti ._cos (src , dst , sycl_queue , depends )
453+
454+ # dpctl.tensor only works with usm_ndarray
455+ x1_usm = dpnp .get_usm_ndarray (x )
456+ out_usm = None if out is None else dpnp .get_usm_ndarray (out )
457+
458+ func = UnaryElementwiseFunc (
459+ "cos" , ti ._cos_result_type , _call_cos , _cos_docstring
460+ )
461+ res_usm = func (x1_usm , out = out_usm , order = order )
462+ return dpnp_array ._create_from_usm_ndarray (res_usm )
463+
464+
465465def dpnp_conj (x , out = None , order = "K" ):
466466 """
467467 Invokes conj() function from pybind11 extension of OneMKL VM if possible.
0 commit comments