|
27 | 27 | # *****************************************************************************
|
28 | 28 |
|
29 | 29 |
|
| 30 | +from sys import platform |
| 31 | + |
30 | 32 | import dpctl.tensor._tensor_impl as ti
|
31 | 33 | from dpctl.tensor._elementwise_common import (
|
32 | 34 | BinaryElementwiseFunc,
|
|
68 | 70 | "dpnp_multiply",
|
69 | 71 | "dpnp_negative",
|
70 | 72 | "dpnp_not_equal",
|
| 73 | + "dpnp_power", |
71 | 74 | "dpnp_proj",
|
72 | 75 | "dpnp_remainder",
|
73 | 76 | "dpnp_right_shift",
|
@@ -1460,6 +1463,69 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
|
1460 | 1463 | return dpnp_array._create_from_usm_ndarray(res_usm)
|
1461 | 1464 |
|
1462 | 1465 |
|
| 1466 | +_power_docstring_ = """ |
| 1467 | +power(x1, x2, out=None, order="K") |
| 1468 | +
|
| 1469 | +Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array |
| 1470 | +`x1` with the respective element `x2_i` of the input array `x2`. |
| 1471 | +
|
| 1472 | +Args: |
| 1473 | + x1 (dpnp.ndarray): |
| 1474 | + First input array, expected to have numeric data type. |
| 1475 | + x2 (dpnp.ndarray): |
| 1476 | + Second input array, also expected to have numeric data type. |
| 1477 | + out ({None, dpnp.ndarray}, optional): |
| 1478 | + Output array to populate. Array must have the correct |
| 1479 | + shape and the expected data type. |
| 1480 | + order ("C","F","A","K", None, optional): |
| 1481 | + Output array, if parameter `out` is `None`. |
| 1482 | + Default: "K". |
| 1483 | +Returns: |
| 1484 | + dpnp.ndarray: |
| 1485 | + An array containing the result of element-wise of raising each element |
| 1486 | + to a specified power. |
| 1487 | + The data type of the returned array is determined by the Type Promotion Rules. |
| 1488 | +""" |
| 1489 | + |
| 1490 | + |
| 1491 | +def _call_pow(src1, src2, dst, sycl_queue, depends=None): |
| 1492 | + """A callback to register in BinaryElementwiseFunc class of dpctl.tensor""" |
| 1493 | + |
| 1494 | + if depends is None: |
| 1495 | + depends = [] |
| 1496 | + |
| 1497 | + # TODO: remove this check when OneMKL is fixed on Windows |
| 1498 | + is_win = platform.startswith("win") |
| 1499 | + |
| 1500 | + if not is_win and vmi._mkl_pow_to_call(sycl_queue, src1, src2, dst): |
| 1501 | + # call pybind11 extension for pow() function from OneMKL VM |
| 1502 | + return vmi._pow(sycl_queue, src1, src2, dst, depends) |
| 1503 | + return ti._pow(src1, src2, dst, sycl_queue, depends) |
| 1504 | + |
| 1505 | + |
| 1506 | +pow_func = BinaryElementwiseFunc( |
| 1507 | + "pow", ti._pow_result_type, _call_pow, _power_docstring_ |
| 1508 | +) |
| 1509 | + |
| 1510 | + |
| 1511 | +def dpnp_power(x1, x2, out=None, order="K"): |
| 1512 | + """ |
| 1513 | + Invokes pow() function from pybind11 extension of OneMKL VM if possible. |
| 1514 | +
|
| 1515 | + Otherwise fully relies on dpctl.tensor implementation for pow() function. |
| 1516 | + """ |
| 1517 | + |
| 1518 | + # dpctl.tensor only works with usm_ndarray or scalar |
| 1519 | + x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1) |
| 1520 | + x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) |
| 1521 | + out_usm = None if out is None else dpnp.get_usm_ndarray(out) |
| 1522 | + |
| 1523 | + res_usm = pow_func( |
| 1524 | + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order |
| 1525 | + ) |
| 1526 | + return dpnp_array._create_from_usm_ndarray(res_usm) |
| 1527 | + |
| 1528 | + |
1463 | 1529 | _proj_docstring = """
|
1464 | 1530 | proj(x, out=None, order="K")
|
1465 | 1531 |
|
|
0 commit comments