diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index ab992e0580..2e6bf9530c 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
     return op.Tile(self_expanded, repeats)
 
 
-def aten_repeat_interleave(
-    repeats: TensorType, output_size: Optional[int] = None
+@torch_op("aten::repeat_interleave.self_int", trace_only=True)
+def aten_repeat_interleave_self_int(
+    self: TensorType, repeats: int, dim: Optional[int] = None
 ) -> TensorType:
-    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
+    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
 
-    raise NotImplementedError()
+    The trick is to repeat along an axis orthogonal to ``dim`` and then
+    fold the repeated axis back in with a reshape.
+
+    .. code-block:: python
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        x.repeat_interleave(2, dim=0)
+
+    is equivalent to:
+
+    .. code-block:: python
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        x.repeat((1, 2)).reshape((-1, x.shape[1]))
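+
+    Both produce::
+
+        tensor([[0, 1, 2],
+                [0, 1, 2],
+                [3, 4, 5],
+                [3, 4, 5]])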
+    """
+    if dim is None:
+        raise NotImplementedError("No conversion available yet when dim is None.")
+
+    self_rank = len(self.shape)
+    pos_dim = (dim + self_rank) % self_rank
+    # Insert a new axis right after pos_dim, tile it `repeats` times,
+    # then merge it back into pos_dim with a reshape.
+    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
+    tiles = [1] * (self_rank + 1)
+    tiles[pos_dim + 1] = repeats
+    tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
+    tiled = op.Tile(unsqueezed, tile_repeat)
+    if self_rank == 1:
+        return op.Identity(tiled)
+    final_shape = op.Concat(
+        op.Shape(self, start=0, end=pos_dim),
+        op.Constant(value_ints=[-1]),
+        op.Shape(self, start=pos_dim + 1),
+        axis=0,
+    )
+    return op.Reshape(tiled, final_shape)
+
+
+@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
+def aten_repeat_interleave_Tensor(
+    self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
+) -> TensorType:
+    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor
+
+    When `repeats` is a tensor, each row is repeated a different number of times.
+    There are multiple possible strategies; this is one of them.
+
+    .. code-block:: python
+
+        import torch
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
+        times = torch.tensor([2, 3], dtype=torch.int64)
+        y = x.repeat_interleave(times, dim=0)
+        print("repeat_interleave")
+        print(y)
+
+        ci = times.cumsum(dim=0)
+        rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
+        srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
+        indices = srows.reshape((-1, ))
+        print("decomposed")
+        print(x[indices, :])
+    """
+    if repeats is None:
+        repeats = self
+        self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
+    if dim is None:
+        # Flatten; the repetition then happens along the single remaining axis.
+        self = op.Reshape(self, [-1])
+        rk = 1
+    else:
+        rk = len(self.shape)
+
+    if rk > 2:
+        # Collapse trailing dimensions so the gather below runs on a 2-D tensor;
+        # shape_x keeps what is needed to restore them afterwards.
+        shape_x0 = op.Shape(self, start=0, end=1)
+        shape_x = op.Shape(self, start=1)
+        self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
+    elif rk == 1:
+        shape_x = None
+        self = op.Reshape(self, [-1, 1])
+    else:
+        if rk != 2:
+            raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
+        shape_x = None
+
+    # ci[i] = repeats[0] + ... + repeats[i]
+    ci = op.CumSum(repeats, [0])
+    last_ci = op.Gather(ci, [-1])
+    # trange = [0, 1, ..., sum(repeats) - 1], one entry per output row.
+    trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
+    # rows[i][j] is True when output row j is drawn from an input row <= i;
+    # summing over i and subtracting from the row count gives the source index.
+    rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
+    srows = op.Sub(
+        op.Shape(self, start=0, end=1),
+        op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
+    )
+    indices = op.Reshape(srows, [-1])
+    values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
+    if rk == 2:
+        return values
+    # rk == 1: flatten back to 1-D; rk > 2: restore the trailing dimensions.
+    return op.Reshape(
+        values,
+        op.Concat([-1], shape_x, axis=0) if shape_x is not None else [-1],
+    )
 
 
 @torch_op("aten::reshape")
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index ab58bbc1a1..a0d0a0d880 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -76,6 +76,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_repeat_interleave_integer_1(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.repeat_interleave(x, 3, dim=1)
+
+        onnx_program = torch.onnx.export(
+            Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_repeat_interleave_integer_2(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.repeat_interleave(x, 3, dim=1)
+
+        onnx_program = torch.onnx.export(
+            Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_repeat_interleave_tensor(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind):
+                return torch.repeat_interleave(x, ind, dim=0)
+
+        onnx_program = torch.onnx.export(
+            Model(),
+            (
+                torch.arange(6, dtype=torch.float32).reshape((2, 3)),
+                torch.tensor([1, 2], dtype=torch.int64),
+            ),
+            dynamo=True,
+            optimize=False,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_repeat_interleave_tensor_none(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind):
+                return torch.repeat_interleave(x, ind)
+
+        inputs = (
+            torch.arange(4, dtype=torch.float32).reshape((2, 2)),
+            torch.tensor([1, 2, 3, 2], dtype=torch.int64),
+        )
+        onnx_program = torch.onnx.export(
+            Model(),
+            inputs,
+            dynamo=True,
+            optimize=False,
+        )
+        onnx_program = torch.onnx.export(
+            Model(),
+            inputs,
+            input_names=["x", "ind"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
     def test_sdpa_with_bool_attn_mask(self):
         class ScaledDotProductAttention(torch.nn.Module):
             def forward(self, query, key, value, attn_mask):
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 7af7413185..01db7161b5 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1250,6 +1250,40 @@ def _where_input_wrangler(
         core_ops.aten_remainder,
     ),
     TorchLibOpInfo("repeat", core_ops.aten_repeat),
+    TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int)
+    .skip(
+        matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int),
+        reason="ignore cases when repeats is a Tensor",
+    )
+    .skip(
+        dtypes=(torch.bool,),
+        reason="bool not supported",
+    )
+    .skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is None,
+        reason="fixme: conversion not implemented if dim is None",
+    )
+    .skip(
+        matcher=lambda sample: sample.input.numel() == 0,
+        reason="fixme: conversion not implemented when input tensor is empty",
+    ),
+    TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor)
+    .skip(
+        matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int),
+        reason="ignore cases when repeats is an int",
+    )
+    .skip(
+        dtypes=(torch.bool,),
+        reason="bool not supported",
+    )
+    .skip(
+        matcher=lambda sample: sample.kwargs.get("dim") is None,
+        reason="fixme: conversion not implemented if dim is None",
+    )
+    .skip(
+        matcher=lambda sample: sample.input.numel() == 0,
+        reason="fixme: conversion not implemented when input tensor is empty",
+    ),
     TorchLibOpInfo("reshape", core_ops.aten_reshape),
     TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
     TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),
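
For reference, a minimal standalone sketch (not part of the patch) that replays the index-construction strategy used by `aten_repeat_interleave_Tensor` in plain PyTorch and checks it against `torch.repeat_interleave`; the variable names are illustrative only.

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape((2, 3))
times = torch.tensor([2, 3], dtype=torch.int64)

# Reference result.
expected = torch.repeat_interleave(x, times, dim=0)

# Decomposition mirroring the ONNX graph: CumSum -> Range -> Less -> ReduceSum -> GatherND.
ci = times.cumsum(dim=0)                              # [2, 5]
trange = torch.arange(ci[-1], dtype=torch.int64)      # [0, 1, 2, 3, 4]
rows = trange < ci.reshape((-1, 1))                   # rows[i, j] = (j < ci[i])
srows = x.shape[0] - rows.to(torch.int64).sum(dim=0)  # source row index per output row
decomposed = x[srows.reshape((-1,)), :]

assert torch.equal(decomposed, expected)
```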