Implements repeat_interleave #2477

Conversation
Codecov Report ❌

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #2477      +/-   ##
==========================================
- Coverage   70.00%   69.90%   -0.11%
==========================================
  Files         215      215
  Lines       25992    26035      +43
  Branches     2606     2614       +8
==========================================
+ Hits        18196    18199       +3
- Misses       6896     6936      +40
  Partials      900      900

View full report in Codecov by Sentry.
Pull Request Overview
This PR implements the repeat_interleave operation for the torch library, adding support for both scalar and tensor variants. The implementation handles different cases where the repeat count is either an integer or a tensor, with optional dimension specification.
- Adds two new functions: `aten_repeat_interleave_int` for scalar repeats and `aten_repeat_interleave_Tensor` for tensor repeats (see the sketch after this list)
- Includes end-to-end tests covering integer repeats, tensor repeats, and tensor repeats with no dimension specified
- Temporarily comments out the TorchLibOpInfo entries, with a note about splitting them into two cases
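For context, a minimal sketch of the two call patterns these variants target; the values below are illustrative, not taken from the PR's tests:

```python
import torch

x = torch.tensor([1, 2, 3])

# Scalar repeats: every element is repeated the same number of times.
torch.repeat_interleave(x, 2)                        # tensor([1, 1, 2, 2, 3, 3])

# Tensor repeats: element i is repeated repeats[i] times.
torch.repeat_interleave(x, torch.tensor([1, 2, 3]))  # tensor([1, 2, 2, 3, 3, 3])
```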
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/core.py | Implements the core repeat_interleave functionality with two variants for scalar and tensor inputs |
| tests/function_libs/torch_lib/e2e_ops_tests.py | Adds three comprehensive test cases covering different repeat_interleave scenarios |
| tests/function_libs/torch_lib/ops_test_data.py | Comments out TorchLibOpInfo entries with explanation about splitting into separate cases |
Signed-off-by: xadupre <[email protected]>
    torch.arange(4, dtype=torch.float32).reshape((2, 2)),
    torch.tensor([1, 2, 3, 2], dtype=torch.int64),
)
onnx_program = torch.onnx.export(
Check warning (Code scanning / CodeQL): Variable defined multiple times. `onnx_program` is redefined in this test.
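As a hedged illustration of the warning (the model and argument names below are placeholders, not the PR's actual test objects), binding the export result to the same name twice in one test shadows the first program; distinct names, or separate test methods, avoid the alert:

```python
# Hypothetical sketch: model_a, model_b, args_a, args_b are assumed names.
onnx_program_int = torch.onnx.export(model_a, args_a, dynamo=True)
onnx_program_tensor = torch.onnx.export(model_b, args_b, dynamo=True)
```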
I think we can simplify the index computation logic as previous comments suggested.

And then Expand can be leveraged to simplify the graph.
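A minimal sketch of the Expand idea for scalar repeats along a known dim, written in eager PyTorch for clarity rather than as the PR's graph code; Unsqueeze/Expand/Reshape would play the same roles in the ONNX graph:

```python
import torch

def repeat_interleave_via_expand(x: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
    dim = dim % x.ndim                        # normalize a negative dim
    x = x.unsqueeze(dim + 1)                  # (..., d, 1, ...)
    sizes = list(x.shape)
    sizes[dim + 1] = repeats                  # broadcast the new axis
    x = x.expand(sizes)                       # (..., d, repeats, ...)
    out_shape = list(x.shape[:dim]) + [-1] + list(x.shape[dim + 2:])
    return x.reshape(out_shape)               # merge: (..., d * repeats, ...)

x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
assert torch.equal(repeat_interleave_via_expand(x, 3, dim=0),
                   torch.repeat_interleave(x, 3, dim=0))
```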
Co-authored-by: Justin Chu <[email protected]>
final_shape = op.Concat(
    op.Shape(self, start=0, end=dim),
    op.Constant(value_ints=[-1]),
    op.Shape(self, start=dim + 1),
I suggest using pos_dim instead of dim here; otherwise, dim + 1 can cause problems when dim == -1.
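A short sketch of that normalization, in plain Python for clarity (in the graph the rank would come from the input's shape):

```python
def normalize_dim(dim: int, rank: int) -> int:
    # Map a possibly-negative dim into [0, rank) before using dim + 1,
    # e.g. in op.Shape(self, start=pos_dim + 1). With dim == -1 the raw
    # dim + 1 would be 0 and select the full shape instead of an empty slice.
    return dim if dim >= 0 else dim + rank

assert normalize_dim(-1, 2) == 1
assert normalize_dim(1, 2) == 1
```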
It would be good to add test cases for negative dim, including -1.
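Something like this would exercise that boundary (a hypothetical sketch, not the test the PR actually adds):

```python
import torch

x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
expected = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                         [2.0, 2.0, 3.0, 3.0]])
# dim=-1 hits the dim + 1 slicing boundary discussed above
assert torch.equal(torch.repeat_interleave(x, 2, dim=-1), expected)
```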
done
if dim is None:
    # dim omitted: flatten the input first, then treat it as rank-1
    self = op.Reshape(self, [-1])
    rk = 1
Suggested change:

- rk = 1
+ rank = 1
done
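For reference on the `dim is None` branch reviewed above, PyTorch's own semantics: when dim is omitted the input is flattened first, so the result is always 1-D:

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
print(torch.repeat_interleave(x, 2))
# tensor([1, 1, 2, 2, 3, 3, 4, 4])
```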
Similar to #2464. This does not support all cases yet, but the remaining ones can be added in follow-up PRs.