-
Notifications
You must be signed in to change notification settings - Fork 88
Separated implementation of aten::scatter overloads #2605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
… for the value argument in aten::scatter.value Signed-off-by: Linsho Kaku <[email protected]>
|
@microsoft-github-policy-service agree |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2605 +/- ##
==========================================
+ Coverage 69.94% 69.95% +0.01%
==========================================
Files 222 222
Lines 26307 26311 +4
Branches 2604 2604
==========================================
+ Hits 18400 18406 +6
+ Misses 6995 6993 -2
Partials 912 912 ☔ View full report in Codecov by Sentry. |
|
Thanks for the high quality contribution! |
| def aten_scatter( | ||
| @torch_op("aten::scatter.src", trace_only=True) | ||
| def aten_scatter_src( | ||
| self: TReal, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing I just noticed:
| self: TReal, | |
| self: TensorType, |
We should update the accepted type to make it more permissive. Could you create a follow PR?
| return op.ScatterElements(self, index, update, axis=dim) | ||
| @torch_op("aten::scatter.value", trace_only=True) | ||
| def aten_scatter_value( | ||
| self: TReal, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| self: TReal, | |
| self: TensorType, |
| self: TReal, | ||
| dim: int, # we have to use int here because ScatterElements() will use this attribute | ||
| index: TInt, | ||
| value: TReal, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| value: TReal, | |
| value: float, |
| scalar_tensor = op.CastLike(value, self) | ||
| src = op.Expand(scalar_tensor, op.Shape(index)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be simplified:
| scalar_tensor = op.CastLike(value, self) | |
| src = op.Expand(scalar_tensor, op.Shape(index)) | |
| scalar_tensor = ir.tensor([value], dtype=self.dtype) | |
| src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor) |
|
@linshokaku I added some suggestions after the PR was merged. Could you create a follow PR to address them? Thanks a lot. |
…lue (#2612) follow #2605 --------- Signed-off-by: Linsho Kaku <[email protected]>
close #2601 #2602
This PR refactors the implementation of
aten::scatteroverloads, improving the clarity of the ONNX output generated byaten::scatter.src.I've also added new tests to verify the correctness of these changes. To make the added tests pass, I needed to also address the issue reported in #2602, which is included in this PR's diff.