Separation of implementations for aten::scatter.value and aten::scatter.src #2601

@linshokaku

@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
def aten_scatter(
    self: TReal,
    dim: int,  # we have to use int here because ScatterElements() will use this attribute
    index: TInt,
    src: TReal,
) -> TReal:
    """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
    scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
    update = op.Expand(src, op.Shape(index))
    return op.ScatterElements(self, index, update, axis=dim)

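aten::scatter.src never receives inputs that require the Expand. PyTorch rejects a src whose rank differs from that of index, or that is smaller than index in any dimension, as the session below shows. Only aten::scatter.value, whose value is a scalar, has to be broadcast to the shape of index. Separating the two implementations would therefore let the .src overload drop the Shape/Expand nodes and simplify the resulting ONNX graph.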
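
To make the difference concrete, here is a minimal NumPy model of the two ONNX operators involved and of the proposed split. It is a sketch under stated assumptions, not onnxscript code: onnx_expand, onnx_scatter_elements, scatter_value and scatter_src are hypothetical helpers, ScatterElements is modelled only for reduction='none', and scatter_src assumes src already has exactly the shape of index (as in the successful call in the session below).

```python
import numpy as np


def onnx_expand(x, shape):
    """NumPy model of ONNX Expand: bidirectional broadcast of x with `shape`."""
    out_shape = np.broadcast_shapes(x.shape, tuple(shape))
    return np.broadcast_to(x, out_shape)


def onnx_scatter_elements(data, indices, updates, axis):
    """NumPy model of ONNX ScatterElements with reduction='none'."""
    # ScatterElements requires updates to match indices exactly.
    if updates.shape != indices.shape:
        raise ValueError("updates must have the same shape as indices")
    out = data.copy()
    for pos in np.ndindex(*indices.shape):
        target = list(pos)
        target[axis] = indices[pos]  # replace the coordinate along `axis`
        out[tuple(target)] = updates[pos]
    return out


def scatter_value(data, dim, index, value):
    """Hypothetical split lowering for aten::scatter.value.

    The scalar must be broadcast to index's shape, so the Expand is kept.
    """
    update = onnx_expand(np.asarray(value, dtype=data.dtype), index.shape)
    return onnx_scatter_elements(data, index, update, axis=dim)


def scatter_src(data, dim, index, src):
    """Hypothetical split lowering for aten::scatter.src.

    Assumes src.shape == index.shape, so src is passed through unchanged.
    """
    return onnx_scatter_elements(data, index, src, axis=dim)
```

With the data and indices from the session below, both scatter_src(data, 0, index, np.ones((2, 3), dtype=np.float32)) and scatter_value(data, 0, index, 1.0) produce [[1, 1, 0], [1, 0, 1], [0, 1, 1]], the same values PyTorch prints; only the .value path needs the broadcast.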

>>> import torch
>>> data = torch.zeros((3, 3)).float()
>>> indices = torch.tensor([[1, 0, 2], [0, 2, 1]]).long()
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones(()).float())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Index tensor must have the same number of dimensions as src tensor
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones((1, )).float())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Index tensor must have the same number of dimensions as src tensor
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones((1, 1)).float())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected index [2, 3] to be no larger than self [3, 3] apart from dimension 0 and to be no larger size than src [1, 1]
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones((2, 3)).float())
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
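
The three rejected calls follow a simple pattern. Below is a simplified restatement of the shape rules they expose, written as a hypothetical helper (check_scatter_src_shapes is not PyTorch's actual check; the rank check against self is an assumption that the session does not exercise).

```python
def check_scatter_src_shapes(self_shape, dim, index_shape, src_shape):
    """Hypothetical, simplified restatement of the aten::scatter.src shape rules.

    Raises ValueError for shapes like those rejected in the session above.
    """
    # Rank check seen in the first two tracebacks.
    if len(index_shape) != len(src_shape):
        raise ValueError("index and src must have the same number of dimensions")
    # Assumed: PyTorch also requires index and self to share a rank.
    if len(index_shape) != len(self_shape):
        raise ValueError("index and self must have the same number of dimensions")
    # Size check seen in the third traceback: index may not exceed src in any
    # dimension, nor self in any dimension other than `dim`.
    for d, n in enumerate(index_shape):
        if n > src_shape[d] or (d != dim and n > self_shape[d]):
            raise ValueError(f"index {list(index_shape)} is too large in dimension {d}")


# Replay the four src shapes tried in the session (self (3, 3), dim 0, index (2, 3)).
for shape in [(), (1,), (1, 1), (2, 3)]:
    try:
        check_scatter_src_shapes((3, 3), 0, (2, 3), shape)
        print(shape, "accepted")
    except ValueError as err:
        print(shape, "rejected:", err)
```

Any shape that passes has src.size(d) >= index.size(d) for every d, so broadcasting src up to the shape of index never has to grow it; only the scalar of aten::scatter.value ever needs the Expand.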

Labels: contribution welcome, module: torchlib
