Skip to content

Commit abfa422

Browse files
author
Vahid Tavanashad
committed
improve coverage
1 parent c4040d0 commit abfa422

File tree

3 files changed

+60
-49
lines changed

3 files changed

+60
-49
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,7 @@ def matmul(
830830
>>> np.dot(a, c).shape
831831
(9, 5, 7, 9, 5, 3)
832832
>>> np.matmul(a, c).shape
833-
(9, 5, 7, 3)
834-
>>> # n is 7, k is 4, m is 3
833+
(9, 5, 7, 3) # n is 7, k is 4, m is 3
835834
836835
Examples
837836
--------
@@ -872,7 +871,7 @@ def matmul(
872871
>>> np.matmul(x1, x2)
873872
array(-13+0j)
874873
875-
The ``@`` operator can be used as a shorthand for ``matmul`` on
874+
The ``@`` operator can be used as a shorthand for :obj:`dpnp.matmul` on
876875
:class:`dpnp.ndarray`.
877876
878877
>>> x1 @ x2
@@ -881,6 +880,7 @@ def matmul(
881880
"""
882881

883882
dpnp.check_limitations(subok_linalg=subok, signature=signature)
883+
dpnp.check_supported_arrays_type(x1, x2)
884884

885885
return dpnp_multiplication(
886886
"matmul",
@@ -913,7 +913,7 @@ def matvec(
913913
Matrix-vector dot product of two arrays.
914914
915915
Given a matrix (or stack of matrices) :math:`\mathbf{A}` in `x1` and
916-
a vector (or stack of vectors) :math:`\mathbf{v}` `x2`, the
916+
a vector (or stack of vectors) :math:`\mathbf{v}` in `x2`, the
917917
matrix-vector product is defined as:
918918
919919
.. math::
@@ -952,11 +952,11 @@ def matvec(
952952
953953
Default: ``"K"``.
954954
axes : {None, list of tuples}, optional
955-
A list of tuples with indices of axes the matrix product should operate
956-
on. For instance, for the signature of ``(i,j),(j)->(i)``, the base
957-
elements are 2d matrices and 1d vectors, where the matrices are assumed
958-
to be stored in the last two axes of the first argument, and the
959-
vectors in the last axis of the second argument. The corresponding
955+
A list of tuples with indices of axes the matrix-vector product should
956+
operate on. For instance, for the signature of ``(i,j),(j)->(i)``, the
957+
base elements are 2d matrices and 1d vectors, where the matrices are
958+
assumed to be stored in the last two axes of the first argument, and
959+
the vectors in the last axis of the second argument. The corresponding
960960
axes keyword would be ``[(-2, -1), (-1,), (-1,)]``.
961961
962962
Default: ``None``.
@@ -1000,6 +1000,7 @@ def matvec(
10001000
"""
10011001

10021002
dpnp.check_limitations(subok_linalg=subok, signature=signature)
1003+
dpnp.check_supported_arrays_type(x1, x2)
10031004

10041005
return dpnp_multiplication(
10051006
"matvec",
@@ -1333,7 +1334,7 @@ def vecdot(
13331334
13341335
Default: ``None``.
13351336
axes : {None, list of tuples}, optional
1336-
A list of tuples with indices of axes the matrix product should operate
1337+
A list of tuples with indices of axes the dot product should operate
13371338
on. For instance, for the signature of ``(i),(i)->()``, the base
13381339
elements are vectors and these are taken to be stored in the last axes
13391340
of each argument. The corresponding axes keyword would be
@@ -1414,7 +1415,7 @@ def vecmat(
14141415
Vector-matrix dot product of two arrays.
14151416
14161417
Given a vector (or stack of vector) :math:`\mathbf{v}` in `x1` and a matrix
1417-
(or stack of matrices) :math:`\mathbf{A}` `x2`, the vector-matrix product
1418+
(or stack of matrices) :math:`\mathbf{A}` in `x2`, the vector-matrix product
14181419
is defined as:
14191420
14201421
.. math::
@@ -1455,12 +1456,12 @@ def vecmat(
14551456
14561457
Default: ``"K"``.
14571458
axes : {None, list of tuples}, optional
1458-
A list of tuples with indices of axes the matrix product should operate
1459-
on. For instance, for the signature of ``(i),(i,j)->(j)``, the base
1460-
elements are 1D vectors and 2D matrices, where the vectors are assumed
1461-
to be stored in the last axis of the first argument, and the matrices
1462-
in the last two axes of the second argument. The corresponding axes
1463-
keyword would be ``[(-1,), (-2, -1), (-1,)]``.
1459+
A list of tuples with indices of axes the vector-matrix product should
1460+
operate on. For instance, for the signature of ``(i),(i,j)->(j)``, the
1461+
base elements are 1D vectors and 2D matrices, where the vectors are
1462+
assumed to be stored in the last axis of the first argument, and the
1463+
matrices in the last two axes of the second argument. The corresponding
1464+
axes keyword would be ``[(-1,), (-2, -1), (-1,)]``.
14641465
14651466
Default: ``None``.
14661467
@@ -1490,14 +1491,18 @@ def vecmat(
14901491
>>> v = np.array([0., 4., 2.])
14911492
>>> a = np.array([[1., 0., 0.],
14921493
... [0., 1., 0.],
1493-
... [0., 0., 1.],
1494-
... [0., 6., 8.]])
1494+
... [0., 0., 0.]])
14951495
>>> np.vecmat(v, a)
1496-
array([ 0., 4., 0.])
1496+
array([0., 4., 0.])
14971497
14981498
"""
14991499

15001500
dpnp.check_limitations(subok_linalg=subok, signature=signature)
1501+
dpnp.check_supported_arrays_type(x1, x2)
1502+
1503+
# cannot directly use dpnp.conj(x1) as it returns `int8` for boolean input
1504+
if dpnp.issubdtype(x1.dtype, dpnp.complexfloating):
1505+
x1 = dpnp.conj(x1)
15011506

15021507
return dpnp_multiplication(
15031508
"vecmat",

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,9 @@ def _gemm_matmul(exec_q, x1, x2, res):
508508
def _shape_error(shape1, shape2, func, err_msg):
509509
"""Validate the shapes of input and output arrays."""
510510

511+
# func=None is applicable when err_msg == 2
511512
if func is not None:
512513
signature, _ = _get_signature(func)
513-
else:
514-
# applicable when err_msg == 2
515-
assert func is None
516514

517515
if err_msg == 0:
518516
raise ValueError(
@@ -569,7 +567,13 @@ def _validate_axes(x1, x2, axes, func):
569567
"""Check axes is valid for linear algebra functions."""
570568

571569
def _validate_internal(axes, op, ncores, ndim=None):
572-
if ncores == 1:
570+
if ncores == 0:
571+
if axes != ():
572+
raise AxisError(
573+
f"{func}: operand {op} has 0 core dimensions. "
574+
f"Axes item {op} should be an empty tuple."
575+
)
576+
elif ncores == 1:
573577
if isinstance(axes, int):
574578
axes = (axes,)
575579
elif not isinstance(axes, tuple):
@@ -616,32 +620,34 @@ def _validate_internal(axes, op, ncores, ndim=None):
616620
axes[0] = _validate_internal(axes[0], 0, x1_ncore, x1_ndim)
617621
axes[1] = _validate_internal(axes[1], 1, x2_ncore, x2_ndim)
618622

619-
if x1_ncore == 1 and x2_ncore == 1:
623+
if func == "vecdot":
620624
if len(axes) == 3:
621-
if axes[2] != ():
622-
raise AxisError(
623-
f"{func}: output has 0 core dimensions. "
624-
"Axes item 2 should be an empty tuple."
625-
)
626-
elif len(axes) != 2:
627-
raise ValueError(
628-
"Axes should be a list of three tuples: two inputs and one "
629-
"output. Entry for output can only be omitted if it does not "
630-
"have a core axis."
631-
)
625+
axes[2] = _validate_internal(axes[2], 2, 0)
626+
return axes
627+
628+
if len(axes) == 2:
629+
return [axes[0], axes[1], ()]
630+
631+
raise ValueError(
632+
"Axes should be a list of three tuples: two inputs and one "
633+
"output. Entry for output can only be omitted if it does not "
634+
"have a core axis."
635+
)
632636
else:
633637
if len(axes) != 3:
634638
raise ValueError(
635639
"Axes should be a list of three tuples: two inputs and one "
636640
"output; Entry for output can only be omitted if it does not "
637641
"have a core axis."
638642
)
639-
if x1_ncore == 1 or x2_ncore == 1:
643+
if x1_ncore == 1 and x2_ncore == 1:
644+
axes[2] = _validate_internal(axes[2], 2, 0)
645+
elif x1_ncore == 1 or x2_ncore == 1:
640646
axes[2] = _validate_internal(axes[2], 2, 1)
641647
else:
642648
axes[2] = _validate_internal(axes[2], 2, 2)
643649

644-
return axes
650+
return axes
645651

646652

647653
def _validate_out_array(out, exec_q):
@@ -829,14 +835,9 @@ def dpnp_multiplication(
829835
830836
"""
831837

832-
dpnp.check_supported_arrays_type(x1, x2)
833838
res_usm_type, exec_q = get_usm_allocations([x1, x2])
834839
_validate_out_array(out, exec_q)
835840

836-
if func == "vecmat":
837-
if dpnp.issubdtype(x1.dtype, dpnp.complexfloating):
838-
x1 = dpnp.conj(x1)
839-
840841
if order in "aA":
841842
if x1.flags.fnc and x2.flags.fnc:
842843
order = "F"
@@ -894,7 +895,7 @@ def dpnp_multiplication(
894895
)
895896

896897
if axes is not None:
897-
# Now that we have result array shape, check axes is within range
898+
# Now that result array shape is calculated, check axes is within range
898899
axes_res = normalize_axis_tuple(axes_res, len(result_shape), "axes")
899900

900901
# Determine the appropriate data types

dpnp/tests/test_product.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def setup_method(self):
674674
((7, 4, 3), (0, 7, 3, 5)),
675675
],
676676
)
677-
def test_matmul(self, order1, order2, shape1, shape2):
677+
def test_basic(self, order1, order2, shape1, shape2):
678678
# input should be float type otherwise they are copied to c-contigous array
679679
# so testing order becomes meaningless
680680
dtype = dpnp.default_float_type()
@@ -701,7 +701,7 @@ def test_matmul(self, order1, order2, shape1, shape2):
701701
"((6, 7, 4, 3), (6, 7, 3, 5))",
702702
],
703703
)
704-
def test_matmul_bool(self, shape1, shape2):
704+
def test_bool(self, shape1, shape2):
705705
x = numpy.arange(2, dtype=numpy.bool_)
706706
a = numpy.resize(x, numpy.prod(shape1)).reshape(shape1)
707707
b = numpy.resize(x, numpy.prod(shape2)).reshape(shape2)
@@ -766,7 +766,7 @@ def test_axes_1D_ND(self, func, axes):
766766
expected = getattr(numpy, func)(a, b, axes=axes)
767767
assert_dtype_allclose(result, expected)
768768

769-
def test_matmul_axes_1D_1D(self):
769+
def test_axes_1D_1D(self):
770770
a = numpy.arange(3)
771771
ia = dpnp.array(a)
772772

@@ -1417,7 +1417,7 @@ def test_invalid_axes(self, xp):
14171417
assert_raises(AxisError, xp.matmul, a, a, axes=axes)
14181418

14191419
# axes should be a list of three tuples
1420-
axes = [0, 0, 0, 0]
1420+
axes = [0, 0]
14211421
assert_raises(ValueError, xp.matmul, a, a, axes=axes)
14221422

14231423
a = xp.arange(3 * 4 * 5).reshape(3, 4, 5)
@@ -1940,7 +1940,7 @@ def test_axis2(self):
19401940
[(1,), (1,), ()],
19411941
[(0), (0), ()],
19421942
[0, 1, ()],
1943-
[-2, -1, ()],
1943+
[-2, -1],
19441944
],
19451945
)
19461946
def test_axes(self, axes):
@@ -2145,6 +2145,11 @@ def test_error(self, xp):
21452145
a = xp.ones((5, 5))
21462146
assert_raises(TypeError, xp.vecdot, a, a, axes=[0, 0, ()], axis=-1)
21472147

2148+
# axes should be a list of three tuples
2149+
a = xp.ones(5)
2150+
axes = [0, 0, 0, 0]
2151+
assert_raises(ValueError, xp.vecdot, a, a, axes=axes)
2152+
21482153

21492154
@testing.with_requires("numpy>=2.2")
21502155
class TestVecmat:

0 commit comments

Comments
 (0)