-
Notifications
You must be signed in to change notification settings - Fork 90
Closed
Closed
Copy link
Labels
contribution welcome (We welcome code contributions for this) · module: torchlib (Related to the torch/aten function lib in development)
Description
onnxscript/onnxscript/function_libs/torch_lib/ops/core.py
Lines 7739 to 7749 in 81f8444
@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
def aten_scatter(
    self: TReal,
    dim: int,  # we have to use int here because ScatterElements() will use this attribute
    index: TInt,
    src: TReal,
) -> TReal:
    """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""

    update = op.Expand(src, op.Shape(index))
    return op.ScatterElements(self, index, update, axis=dim)
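Since aten::scatter.src never receives inputs that require Expand (as the session below shows, PyTorch rejects any src that would have to be broadcast to the shape of index), separating its implementation from aten::scatter.value could simplify the resulting ONNX output.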
>>> data = torch.zeros((3, 3)).float()
>>> indices = torch.tensor([[1, 0, 2], [0, 2, 1]]).long()
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones(()).float())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Index tensor must have the same number of dimensions as src tensor
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones((1, )).float())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Index tensor must have the same number of dimensions as src tensor
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones((1, 1)).float())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected index [2, 3] to be no larger than self [3, 3] apart from dimension 0 and to be no larger size than src [1, 1]
>>> torch.ops.aten.scatter.src(data, 0, indices, torch.ones((2, 3)).float())
tensor([[1., 1., 0.],
[1., 0., 1.],
[0., 1., 1.]])
Metadata
Assignees
Labels
contribution welcome (We welcome code contributions for this) · module: torchlib (Related to the torch/aten function lib in development)