Implements repeat_interleave #2477
Changes from all commits
@@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
     return op.Tile(self_expanded, repeats)


-def aten_repeat_interleave(
-    repeats: TensorType, output_size: Optional[int] = None
-) -> TensorType:
-    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
-
-    raise NotImplementedError()
+@torch_op("aten::repeat_interleave.self_int", trace_only=True)
+def aten_repeat_interleave_self_int(
+    self: TensorType, repeats: int, dim: Optional[int] = None
+) -> TensorType:
+    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+
+    The trick is to insert a new axis next to `dim`, tile along that axis,
+    and merge it back into `dim` with a reshape.
+
+    .. code-block:: python
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        x.repeat_interleave(2, dim=0)
+
+    is equivalent to:
+
+    .. code-block:: python
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        x.repeat((1, 2)).reshape((-1, x.shape[1]))
+    """
+    if dim is None:
+        raise NotImplementedError("No conversion available yet when dim is None.")
+
+    self_rank = len(self.shape)
+    pos_dim = (dim + self_rank) % self_rank  # normalize a possibly negative dim
+    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])  # new axis right after dim
+    tiles = [1] * (self_rank + 1)
+    tiles[pos_dim + 1] = repeats  # tile only along the inserted axis
+    tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
+    tiled = op.Tile(unsqueezed, tile_repeat)
+    if self_rank == 1:
+        return op.Reshape(tiled, op.Constant(value_ints=[-1]))
+    final_shape = op.Concat(
+        op.Shape(self, start=0, end=pos_dim),
+        op.Constant(value_ints=[-1]),  # dim merged with the inserted axis
+        op.Shape(self, start=pos_dim + 1),
+        axis=0,
+    )
+    return op.Reshape(tiled, final_shape)
Collaborator: I suggest using …

Collaborator: Would be good to add test cases for negative dim, including -1.

Author: done
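As a sanity check on this decomposition (and on the negative-dim cases requested in review), here is a minimal PyTorch sketch of the same trick. It is illustrative only, not the PR's test suite; `tile_then_reshape` is a hypothetical helper name.

```python
import torch

def tile_then_reshape(x: torch.Tensor, repeats: int, dim: int) -> torch.Tensor:
    # Mirror of the ONNX decomposition above, in plain PyTorch:
    # insert a new axis after `dim`, tile along it, then merge it into `dim`.
    rank = x.ndim
    pos_dim = (dim + rank) % rank  # normalize negative dims
    tiles = [1] * (rank + 1)
    tiles[pos_dim + 1] = repeats
    tiled = x.unsqueeze(pos_dim + 1).repeat(tiles)
    final_shape = list(x.shape[:pos_dim]) + [-1] + list(x.shape[pos_dim + 1:])
    return tiled.reshape(final_shape)

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
for dim in (0, 1, -1):  # -1 covers the negative-dim case raised in review
    expected = x.repeat_interleave(2, dim=dim)
    assert torch.equal(tile_then_reshape(x, 2, dim), expected)
```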
+
+
+@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
+def aten_repeat_interleave_Tensor(
+    self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
+) -> TensorType:
+    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor
+
+    When `repeats` is a tensor, every row is repeated a different number of
+    times. There are multiple possible decompositions; here is one.
+
+    .. code-block:: python
+
+        import torch
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        times = torch.tensor([2, 3], dtype=torch.int64)
+        y = x.repeat_interleave(times, dim=0)
+        print("repeat_interleave")
+        print(y)
+
+        ci = times.cumsum(dim=0)
+        rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
+        srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
+        indices = srows.reshape((-1,))
+        print("decomposed")
+        print(x[indices, :])
+    """
+    if repeats is None:
+        # Overload called with repeats only: interleave the indices 0..n-1.
+        repeats = self
+        self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
+    if dim is None:
+        # flatten
+        self = op.Reshape(self, [-1])
+        rk = 1
+    else:
+        rk = len(self.shape)
+
+    if rk > 2:
+        # Collapse the trailing dimensions so the gather below sees 2-D data.
+        shape_x0 = op.Shape(self, start=0, end=1)
+        shape_x = op.Shape(self, start=1)
+        self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
+    elif rk == 1:
+        shape_x = None
+        self = op.Reshape(self, [-1, 1])
+    else:
+        if rk != 2:
+            raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
+        shape_x = None
+
+    # Row j of the output comes from the input row whose index equals the
+    # number of cumulated-repeat entries that are <= j.
+    ci = op.CumSum(repeats, [0])
+    last_ci = op.Gather(ci, [-1])
+    trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
+    rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
+    srows = op.Sub(
+        op.Shape(self, start=0, end=1),
+        op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
+    )
+    indices = op.Reshape(srows, [-1])
+    values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
+    if rk == 2:
+        return values
+    # rk == 1: flatten back; rk > 2: restore the trailing dimensions.
+    return op.Reshape(
+        values,
+        op.Concat([-1], shape_x, axis=0) if shape_x is not None else [-1],
+    )
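To see why the `CumSum`/`Less`/`ReduceSum` combination produces gather indices, here is the same computation written out in NumPy. Variable names mirror the ones above; this is only an illustration of the decomposition, not code from the PR.

```python
import numpy as np

repeats = np.array([2, 3], dtype=np.int64)
x = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.int64)

ci = np.cumsum(repeats)                  # [2, 5]           (op.CumSum)
trange = np.arange(ci[-1])               # [0, 1, 2, 3, 4]  (op.Range)
rows = trange < ci.reshape((-1, 1))      # shape (2, 5)     (op.Less)
srows = x.shape[0] - rows.astype(np.int64).sum(axis=0)  # (op.Sub / op.ReduceSum)
indices = srows.reshape((-1,))           # [0, 0, 1, 1, 1]
print(x[indices])                        # row 0 twice, row 1 three times
```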
+
+
 @torch_op("aten::reshape")
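For completeness, a sketch of how the new overloads get exercised end to end. This assumes a recent PyTorch (2.5+) with the dynamo-based ONNX exporter and onnxruntime installed; it is a usage illustration, not part of this PR.

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.repeat_interleave(3, dim=-1)

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
# The dynamo-based exporter dispatches aten::repeat_interleave.self_int
# to the torch_lib implementation added above.
onnx_program = torch.onnx.export(M(), (x,), dynamo=True)
onnx_program.save("repeat_interleave.onnx")

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("repeat_interleave.onnx")
(got,) = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})
np.testing.assert_allclose(got, M()(x).numpy())
```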