Skip to content

Conversation

@linshokaku
Copy link
Contributor

@linshokaku linshokaku commented Oct 6, 2025

close #2601 #2602

This PR refactors the implementation of aten::scatter overloads, improving the clarity of the ONNX output generated by aten::scatter.src.
I've also added new tests to verify the correctness of these changes. To make the added tests pass, I needed to also address the issue reported in #2602, which is included in this PR's diff.

… for the value argument in aten::scatter.value

Signed-off-by: Linsho Kaku <[email protected]>
@linshokaku
Copy link
Contributor Author

@microsoft-github-policy-service agree

@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Oct 6, 2025
@codecov
Copy link

codecov bot commented Oct 6, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 69.95%. Comparing base (30ae54b) to head (26ba21e).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2605      +/-   ##
==========================================
+ Coverage   69.94%   69.95%   +0.01%     
==========================================
  Files         222      222              
  Lines       26307    26311       +4     
  Branches     2604     2604              
==========================================
+ Hits        18400    18406       +6     
+ Misses       6995     6993       -2     
  Partials      912      912              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 6, 2025
@justinchuby
Copy link
Collaborator

Thanks for the high quality contribution!

@justinchuby justinchuby merged commit 897345d into microsoft:main Oct 6, 2025
32 checks passed
def aten_scatter(
@torch_op("aten::scatter.src", trace_only=True)
def aten_scatter_src(
self: TReal,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I just noticed:

Suggested change
self: TReal,
self: TensorType,

We should update the accepted type to make it more permissive. Could you create a follow PR?

return op.ScatterElements(self, index, update, axis=dim)
@torch_op("aten::scatter.value", trace_only=True)
def aten_scatter_value(
self: TReal,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self: TReal,
self: TensorType,

self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
index: TInt,
value: TReal,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
value: TReal,
value: float,

Comment on lines +7759 to +7760
scalar_tensor = op.CastLike(value, self)
src = op.Expand(scalar_tensor, op.Shape(index))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simplified:

Suggested change
scalar_tensor = op.CastLike(value, self)
src = op.Expand(scalar_tensor, op.Shape(index))
scalar_tensor = ir.tensor([value], dtype=self.dtype)
src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)

@justinchuby
Copy link
Collaborator

@linshokaku I added some suggestions after the PR was merged. Could you create a follow PR to address them? Thanks a lot.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

Separation of implementations for aten::scatter.value and aten::scatter.src

3 participants