
Commit 897345d

Authored by Linsho Kaku
Separated implementation of aten::scatter overloads (#2605)
Closes #2601 and #2602. This PR refactors the implementation of the `aten::scatter` overloads, improving the clarity of the ONNX output generated for `aten::scatter.src`. It also adds new tests to verify the correctness of these changes. Making the added tests pass required also addressing the issue reported in #2602, which is included in this PR's diff. Signed-off-by: Linsho Kaku <[email protected]>
1 parent 30ae54b · commit 897345d
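For readers unfamiliar with the two overloads being separated, here is a minimal PyTorch sketch (shapes chosen for illustration, not taken from the PR) of how they differ: scatter.src takes per-element sources from a tensor, while scatter.value broadcasts one scalar to every indexed position.

import torch

self_tensor = torch.zeros(5, 5)
index = torch.tensor([[0, 1, 2], [3, 4, 0]])

# aten::scatter.src — each written element comes from the matching
# position of a `src` tensor shaped like `index`.
out_src = self_tensor.scatter(0, index, torch.ones(2, 3))

# aten::scatter.value — every indexed element receives the same scalar.
out_value = self_tensor.scatter(0, index, 2.5)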

File tree: 3 files changed, +94 −5 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 17 additions & 5 deletions
@@ -7736,17 +7736,29 @@ def aten_scalar_tensor_sym_number(
     return common_ops.cast_to(s, dtype=dtype)
 
 
-@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
-def aten_scatter(
+@torch_op("aten::scatter.src", trace_only=True)
+def aten_scatter_src(
     self: TReal,
     dim: int,  # we have to use int here because ScatterElements() will use this attribute
     index: TInt,
     src: TReal,
 ) -> TReal:
-    """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
+    """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
+    return op.ScatterElements(self, index, src, axis=dim)
+
 
-    update = op.Expand(src, op.Shape(index))
-    return op.ScatterElements(self, index, update, axis=dim)
+@torch_op("aten::scatter.value", trace_only=True)
+def aten_scatter_value(
+    self: TReal,
+    dim: int,  # we have to use int here because ScatterElements() will use this attribute
+    index: TInt,
+    value: TReal,
+) -> TReal:
+    """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
+    # Ensure value is a scalar tensor and expand it to match index shape
+    scalar_tensor = op.CastLike(value, self)
+    src = op.Expand(scalar_tensor, op.Shape(index))
+    return op.ScatterElements(self, index, src, axis=dim)
 
 
 @torch_op("aten::scatter_add", trace_only=True)
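Both new functions lower to a single ONNX ScatterElements node; the old combined implementation always emitted an Expand first, which is unnecessary for a tensor src that already matches the index shape, so dropping it for scatter.src makes the exported graph cleaner. As a reference for what ScatterElements computes, here is a NumPy sketch of its semantics (my own illustration, not part of the PR):

import numpy as np

def scatter_elements(data, indices, updates, axis):
    # Reference semantics of ONNX ScatterElements: for every position in
    # `indices`, overwrite the output element whose coordinate along
    # `axis` is replaced by the index value stored there.
    out = data.copy()
    for pos in np.ndindex(indices.shape):
        target = list(pos)
        target[axis] = indices[pos]
        out[tuple(target)] = updates[pos]
    return out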

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 75 additions & 0 deletions
@@ -1365,6 +1365,65 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs)
     yield opinfo_core.SampleInput(input_, args=(src, *args))
 
 
+def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    make_arg = functools.partial(
+        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # Basic test cases for scatter.src
+    cases = [
+        # (self_shape, index_shape, src_shape, dim)
+        ((5, 5), (2, 3), (2, 3), 0),  # 2D scatter on dim=0
+        ((5, 5), (3, 2), (3, 2), 1),  # 2D scatter on dim=1
+        ((3, 4, 5), (2, 2, 3), (2, 2, 3), 0),  # 3D scatter on dim=0
+        ((3, 4, 5), (2, 2, 3), (2, 2, 3), 1),  # 3D scatter on dim=1
+        ((3, 4, 5), (2, 2, 3), (2, 2, 3), 2),  # 3D scatter on dim=2
+        ((10,), (3,), (3,), 0),  # 1D scatter
+    ]
+
+    for self_shape, index_shape, src_shape, dim in cases:
+        self_tensor = make_arg(self_shape)
+        # Create valid indices for the given dimension without duplication
+        index_buffer_shape = list(index_shape)
+        index_buffer_shape[dim] = self_shape[dim]
+        index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[
+            tuple(slice(None, d, None) for d in index_shape)
+        ]
+        src_tensor = make_arg(src_shape)
+        yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor))
+
+
+def sample_inputs_scatter_value(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+    make_arg = functools.partial(
+        torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # Basic test cases for scatter.value
+    cases = [
+        # (self_shape, index_shape, dim, value)
+        ((5, 5), (2, 3), 0, 1.0),  # 2D scatter on dim=0 with scalar value
+        ((5, 5), (3, 2), 1, -2.5),  # 2D scatter on dim=1 with scalar value
+        ((3, 4, 5), (2, 2, 3), 0, 0.0),  # 3D scatter on dim=0 with scalar value
+        ((3, 4, 5), (2, 2, 3), 1, 3.14),  # 3D scatter on dim=1 with scalar value
+        ((3, 4, 5), (2, 2, 3), 2, -1.0),  # 3D scatter on dim=2 with scalar value
+        ((10,), (3,), 0, 5.0),  # 1D scatter with scalar value
+    ]
+
+    for self_shape, index_shape, dim, value in cases:
+        self_tensor = make_arg(self_shape)
+        # Create valid indices for the given dimension without duplication
+        index_buffer_shape = list(index_shape)
+        index_buffer_shape[dim] = self_shape[dim]
+        index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[
+            tuple(slice(None, d, None) for d in index_shape)
+        ]
+        yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value))
+
+
 
 def sample_inputs__scaled_dot_product_flash_attention(
     op_info, device, dtype, requires_grad, **kwargs
 ):
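The rand-then-argsort idiom in both sample-input functions deserves a note: argsort of i.i.d. random values along dim yields a random permutation of 0..self_shape[dim]-1 in every 1-D slice along that dimension, so cropping the result down to index_shape gives valid indices with no duplicates along dim — the condition scatter needs for a deterministic result. A standalone sketch of the idea (shapes illustrative, not from the tests):

import torch

self_shape, index_shape, dim = (5, 5), (2, 3), 0
buffer_shape = list(index_shape)
buffer_shape[dim] = self_shape[dim]  # (5, 3): full extent along `dim`

# Each column of `perm` is a random permutation of 0..4, so any crop
# along dim=0 contains no repeated index within a column.
perm = torch.rand(buffer_shape).argsort(dim=dim)
index = perm[tuple(slice(None, d) for d in index_shape)]  # crop to (2, 3)
assert index.shape == index_shape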
@@ -2533,6 +2592,22 @@ def __init__(self):
         sample_inputs_func=sample_inputs_slice_scatter,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.scatter.src",
+        op=torch.ops.aten.scatter.src,
+        aten_name="scatter.src",
+        dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
+        sample_inputs_func=sample_inputs_scatter_src,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.scatter.value",
+        op=torch.ops.aten.scatter.value,
+        aten_name="scatter.value",
+        dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
+        sample_inputs_func=sample_inputs_scatter_value,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten._softmax",
         op=torch.ops.aten._softmax,  # pylint: disable=protected-access
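With the OpInfo entries above, the test suite drives the overloads through torch.ops.aten directly, selecting each overload by name. For a quick manual check outside the suite, the same calls work standalone (a sketch; shapes are illustrative):

import torch

x = torch.zeros(5, 5)
idx = torch.tensor([[0, 1, 2], [3, 4, 0]])

# scatter.src(Tensor self, int dim, Tensor index, Tensor src)
out_src = torch.ops.aten.scatter.src(x, 0, idx, torch.ones(2, 3))
# scatter.value(Tensor self, int dim, Tensor index, Scalar value)
out_val = torch.ops.aten.scatter.value(x, 0, idx, 2.5)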

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 0 deletions
@@ -2108,6 +2108,8 @@ def _where_input_wrangler(
         reason="onnxruntime does not support ml_dtypes.bfloat16",
     ),
     TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter),
+    TorchLibOpInfo("ops.aten.scatter.src", core_ops.aten_scatter_src),
+    TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value),
     TorchLibOpInfo("slice", core_ops.aten_slice),
     TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
     TorchLibOpInfo(
