
cuda_tensor.norm(dim=(X, Y)) is broken #30704

@ssnl

Description

In [1]: torch.__version__
Out[1]: '1.4.0a0+91c6d2e'

In [2]: torch.randn(3, 3, device='cuda').norm(dim=(0, 1))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-32861c8633ea> in <module>
----> 1 torch.randn(3, 3, device='cuda').norm(dim=(0, 1))

/data/packages/pytorch/torch/tensor.py in norm(self, p, dim, keepdim, dtype)
    337     def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
    338         r"""See :func:`torch.norm`"""
--> 339         return torch.norm(self, p, dim, keepdim, dtype=dtype)
    340
    341     def lu(self, pivot=True, get_infos=False):

/data/packages/pytorch/torch/functional.py in norm(input, p, dim, keepdim, out, dtype)
    780             dim = tuple(range(ndim))
    781         if out is None:
--> 782             return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
    783         return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
    784     elif p == "nuc":

RuntimeError: Could not run 'aten::conj.out' with arguments from the 'CUDATensorId' backend. 'aten::conj.out' is only available for these backends: [CPUTensorId, VariableTensorId].

In [3]: torch.randn(3, 3, device='cuda').norm()
Out[3]: tensor(3.6015, device='cuda:0')

In [4]: torch.randn(3, 3).norm(dim=(0, 1))
Out[4]: tensor(3.5681)

In [5]: torch.randn(3, 3).norm()
Out[5]: tensor(2.7136)
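Per the traceback, the `dim=(0, 1)` call is routed to `frobenius_norm`, which fails because `aten::conj.out` has no CUDA kernel, while the plain `.norm()` call in `In [3]` succeeds on the same device. Until the kernel is fixed, a user-side workaround is to compute the Frobenius norm directly as the square root of the sum of squared entries over the two dims. This is a minimal sketch, not the upstream fix: `frobenius_norm_fallback` and `matrix_norm_with_fallback` are hypothetical helper names, the sketch assumes real-valued floating-point tensors, and matching on `"aten::conj"` in the error text is an assumption based on the message above.

```python
import torch


def frobenius_norm_fallback(x, dim=(0, 1), keepdim=False):
    """Frobenius norm over two dims without calling frobenius_norm.

    Hypothetical helper; assumes a real-valued floating-point tensor.
    """
    # sqrt(sum of squared entries) over `dim` is the Frobenius norm of each
    # matrix slice spanned by those two dimensions.
    return x.pow(2).sum(dim=dim, keepdim=keepdim).sqrt()


def matrix_norm_with_fallback(x, dim=(0, 1), keepdim=False):
    """Try Tensor.norm first; fall back only on the missing-conj-kernel error."""
    try:
        return x.norm(dim=dim, keepdim=keepdim)
    except RuntimeError as err:
        # Assumption: affected builds report the missing 'aten::conj.out' kernel.
        if "aten::conj" not in str(err):
            raise
        return frobenius_norm_fallback(x, dim=dim, keepdim=keepdim)


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3, 3, device=device)
print(matrix_norm_with_fallback(x, dim=(0, 1)))
```

On a build without the bug, `matrix_norm_with_fallback` simply returns `x.norm(dim=dim)`; any unrelated `RuntimeError` is re-raised rather than silently masked.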

Yet passing a 2-tuple as dim is perfectly valid, per the torch.norm docs:

dim (int, 2-tuple of ints, 2-list of ints, optional) – If it is an int, vector norm will be calculated, if it is 2-tuple of ints, matrix norm will be calculated. If the value is None, matrix norm will be calculated when the input tensor only has two dimensions, vector norm will be calculated when the input tensor only has one dimension. If the input tensor has more than two dimensions, the vector norm will be applied to last dimension.
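The dispatch rule in the quoted docs can be restated as a small pure-Python function, which makes it clear that `dim=(0, 1)` on a 2-D tensor is an ordinary matrix-norm request. `norm_kind` is a hypothetical helper that only encodes the documented rule; it does not call into PyTorch and is not how `torch.norm` is implemented. Treating 0-d inputs as an error is an assumption, since the quoted text does not cover them.

```python
def norm_kind(ndim, dim=None):
    """Classify a torch.norm call as a vector or matrix norm, per the quoted docs.

    Hypothetical helper. Returns ("vector" | "matrix", dims_reduced).
    """
    if dim is None:
        if ndim == 1:
            return "vector", (0,)
        if ndim == 2:
            return "matrix", (0, 1)
        if ndim > 2:
            # More than two dims: vector norm over the last dimension.
            return "vector", (ndim - 1,)
        raise ValueError("the quoted docs do not cover 0-d inputs")
    if isinstance(dim, int):
        return "vector", (dim,)
    dims = tuple(dim)
    if len(dims) == 2 and all(isinstance(d, int) for d in dims):
        # 2-tuple or 2-list of ints: matrix norm, the case that fails on CUDA.
        return "matrix", dims
    raise ValueError("dim must be None, an int, or a 2-tuple/2-list of ints")


print(norm_kind(2, (0, 1)))
```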

cc @ezyang @gchanan @zou3519 @anjali411 @dylanbespalko

Metadata

Assignees: No one assigned

Labels: high priority; module: complex (related to complex number support in PyTorch); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
