Merged
110 changes: 106 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
return op.Tile(self_expanded, repeats)


-def aten_repeat_interleave(
-    repeats: TensorType, output_size: Optional[int] = None
-) -> TensorType:
-    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
-    raise NotImplementedError()
@torch_op("aten::repeat_interleave.self_int", trace_only=True)
def aten_repeat_interleave_self_int(
    self: TensorType, repeats: int, dim: Optional[int] = None
) -> TensorType:
    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
    The trick is to insert an extra axis right after `dim`, tile along it, and then reshape.

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat_interleave(2, dim=0)

is equivalent to:

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.repeat((1, 2)).reshape((-1, x.shape[1]))
"""
if dim is None:
raise NotImplementedError("No conversion available yet when dim is None.")

self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
tiled = op.Tile(unsqueezed, tile_repeat)
if self_rank == 1:
return op.Identity(tiled)
final_shape = op.Concat(
        op.Shape(self, start=0, end=pos_dim),
op.Constant(value_ints=[-1]),
        op.Shape(self, start=pos_dim + 1),
Review comment (Collaborator): I suggest using pos_dim instead of dim here; otherwise dim + 1 can cause problems when dim == -1.
Review comment (Collaborator): Would be good to add test cases for negative dim, including -1.
Author: done
axis=0,
)
return op.Reshape(tiled, final_shape)
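
A minimal eager-mode sketch of the same Unsqueeze/Tile/Reshape decomposition, assuming an integer repeats and a non-negative dim (illustration only, not taken from the PR):

import torch

# Editorial sketch: mirror the unsqueeze -> tile -> reshape trick used above.
x = torch.tensor([[0, 1, 2], [3, 4, 5]])
repeats, dim = 2, 0

tiles = [1] * (x.dim() + 1)
tiles[dim + 1] = repeats
tiled = x.unsqueeze(dim + 1).repeat(*tiles)   # corresponds to op.Unsqueeze + op.Tile
result = tiled.reshape(*x.shape[:dim], -1, *x.shape[dim + 1:])
assert torch.equal(result, torch.repeat_interleave(x, repeats, dim=dim))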


@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
def aten_repeat_interleave_Tensor(
self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
) -> TensorType:
"""repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor

    When `repeats` is a tensor, each row of `self` is repeated
    a different number of times.
    There are multiple possible strategies; here is one.

.. code-block:: python

import torch

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
times = torch.tensor([2, 3], dtype=torch.int64)
y = x.repeat_interleave(times, dim=0)
print("repeat_interleave")
print(y)

ci = times.cumsum(dim=0)
rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
indices = srows.reshape((-1, ))
print("decomposed")
print(x[indices, :])
"""
if repeats is None:
repeats = self
self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
if dim is None:
# flatten
self = op.Reshape(self, [-1])
rk = 1
Review comment (Collaborator): Suggested change: rename rk to rank.
Author: done
else:
rk = len(self.shape)

if rk > 2:
shape_x0 = op.Shape(self, start=0, end=1)
shape_x = op.Shape(self, start=1)
self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
elif rk == 1:
shape_x = None
self = op.Reshape(self, [-1, 1])
else:
if rk != 2:
raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")
shape_x = None

ci = op.CumSum(repeats, [0])
last_ci = op.Gather(ci, [-1])
trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
srows = op.Sub(
op.Shape(self, start=0, end=1),
op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
)
indices = op.Reshape(srows, [-1])
values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
if rk == 2:
return values
# shape_x is None at this stage.
assert shape_x is None # for mypy
return op.Reshape(
values,
op.Concat([-1], shape_x, axis=0) if shape_x else [-1],
)
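
The CumSum/Range/Less/ReduceSum/GatherND sequence above implements the index construction sketched in the docstring. A minimal eager-mode check of that construction, assuming a 2-D input and integer repeats along dim=0 (illustration only):

import torch

# Editorial sketch: rebuild the gather indices the same way the ONNX graph does.
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
times = torch.tensor([2, 3], dtype=torch.int64)

ci = times.cumsum(dim=0)                                        # [2, 5]
rows = torch.arange(int(ci[-1]), dtype=torch.int64) < ci.reshape(-1, 1)
indices = times.shape[0] - rows.to(torch.int64).sum(dim=0)      # [0, 0, 1, 1, 1]

assert torch.equal(x[indices, :], torch.repeat_interleave(x, times, dim=0))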


@torch_op("aten::reshape")
61 changes: 61 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -76,6 +76,67 @@
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_integer_1(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 3, dim=1)

onnx_program = torch.onnx.export(
Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_integer_2(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 3, dim=1)

onnx_program = torch.onnx.export(
Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
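
The two integer tests differ only in the rank of the input; an eager-mode shape check of what they exercise (illustration only, not part of the test file):

import torch

# With an integer repeats, the size of `dim` is multiplied by repeats.
assert torch.repeat_interleave(torch.randn(2, 3), 3, dim=1).shape == (2, 9)
assert torch.repeat_interleave(torch.randn(2, 3, 4), 3, dim=1).shape == (2, 9, 4)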

def test_repeat_interleave_tensor(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind, dim=0)

onnx_program = torch.onnx.export(
Model(),
(
torch.arange(6, dtype=torch.float32).reshape((2, 3)),
torch.tensor([1, 2], dtype=torch.int64),
),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_tensor_none(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind)

inputs = (
torch.arange(4, dtype=torch.float32).reshape((2, 2)),
torch.tensor([1, 2, 3, 2], dtype=torch.int64),
)
onnx_program = torch.onnx.export(

Check warning (Code scanning / CodeQL): Variable defined multiple times. This assignment to 'onnx_program' is unnecessary as it is redefined before this value is used.
Model(),
inputs,
dynamo=True,
optimize=False,
)
onnx_program = torch.onnx.export(
Model(),
inputs,
input_names=["x", "ind"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)
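
To see which ops a conversion produced, the exported program can be inspected directly; a hedged sketch reusing Model and inputs from the test above (the tests themselves only assert numerical equivalence):

# Editorial sketch: list the node types in the exported graph. Depending on
# optimization settings these may be the translated ONNX ops or calls to
# local functions wrapping them.
onnx_program = torch.onnx.export(Model(), inputs, dynamo=True, optimize=False)
print([node.op_type for node in onnx_program.model_proto.graph.node])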

def test_sdpa_with_bool_attn_mask(self):
class ScaledDotProductAttention(torch.nn.Module):
def forward(self, query, key, value, attn_mask):
34 changes: 34 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1250,6 +1250,40 @@ def _where_input_wrangler(
core_ops.aten_remainder,
),
TorchLibOpInfo("repeat", core_ops.aten_repeat),
TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int)
.skip(
matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int),
reason=("ignore cases when repeasts is a Tensor"),
)
.skip(
dtypes=(torch.bool,),
reason="bool not supported",
)
.skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="fixme: conversion not implemented if dim is None",
)
.skip(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: conversion not implemented when input tensor is empty",
),
TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor)
.skip(
matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int),
reason=("ignore cases when repeasts is an int"),
)
.skip(
dtypes=(torch.bool,),
reason="bool not supported",
)
.skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="fixme: conversion not implemented if dim is None",
)
.skip(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: conversion not implemented when input tensor is empty",
),
TorchLibOpInfo("reshape", core_ops.aten_reshape),
TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),
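
The two registrations above split the repeat_interleave OpInfo samples by the type of `repeats`: each entry skips the samples the other one handles. A small stand-in illustration of that matcher logic (the _Sample class below is hypothetical, for demonstration only):

class _Sample:
    """Hypothetical stand-in for an OpInfo sample with a kwargs dict."""
    def __init__(self, **kwargs):
        self.kwargs = kwargs

int_sample = _Sample(repeats=3, dim=1)          # kept by the self_int entry, skipped by the Tensor entry
tensor_sample = _Sample(repeats=[1, 2], dim=0)  # kept by the Tensor entry, skipped by the self_int entry

assert isinstance(int_sample.kwargs.get("repeats", None), int)
assert not isinstance(tensor_sample.kwargs.get("repeats", None), int)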