201 changes: 200 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
@@ -7280,12 +7280,211 @@
return op.Tile(self_expanded, repeats)


@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
def aten_repeat_interleave(
repeats: TensorType, output_size: Optional[int] = None
) -> TensorType:
"""repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""

raise NotImplementedError()
# Convert repeats to int64 for ONNX compatibility
repeats_int64 = op.Cast(repeats, to=INT64.dtype)

# Get cumulative sum of repeats to find the boundaries
cumsum = op.CumSum(repeats_int64, axis=0)
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)

# Create the range of all output positions
output_range = op.Range(
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
)

# Find which original index each output position corresponds to
# Use the same approach as in self_tensor version
num_elements = op.Size(repeats_int64)

cumsum_expanded = op.Unsqueeze(cumsum, [0]) # [1, num_elements]
output_expanded = op.Unsqueeze(output_range, [1]) # [total_size, 1]

# Use LessOrEqual to find cumsum <= output_pos
mask = op.LessOrEqual(cumsum_expanded, output_expanded) # [total_size, num_elements]

# Sum to get the count of cumsum values <= each position
result_indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)

# Clamp to valid range [0, num_elements-1]
max_index = op.Sub(num_elements, op.Constant(value_ints=[1]))
result_indices = op.Clip(result_indices, op.Constant(value_ints=[0]), max_index)

return result_indices
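
For reference (not part of the change), the repeats-only overload returns the source index for every output position; assuming PyTorch is available, the expected result can be checked directly:

import torch

# Each value i appears repeats[i] times: [2, 3, 1] -> [0, 0, 1, 1, 1, 2]
assert torch.repeat_interleave(torch.tensor([2, 3, 1])).tolist() == [0, 0, 1, 1, 1, 2]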


@torch_op("aten::repeat_interleave.self_Tensor", trace_only=True)
def aten_repeat_interleave_self_tensor(
self: TensorType,
repeats: TensorType,
dim: Optional[int] = None,
output_size: Optional[int] = None,
) -> TensorType:
"""repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""

if dim is None:
Member: No Tile here but the code should be very similar.

Contributor Author: Optimized the aten_repeat_interleave_self_tensor function with a cleaner code structure and an approach similar to the self_int version. While we can't use Tile directly with variable repeats, the implementation now follows the same pattern and style for consistency, in commit d291ae3.

Member: It is worse.

Contributor Author: You're right. I reverted the aten_repeat_interleave_self_tensor function back to the original clean implementation from commit 6d62024. The verbose comments and variable-name changes I added were unnecessary and made the code worse without providing any actual improvement; the original implementation was already clean and correct. Fixed in commit a22cbfb: aten_repeat_interleave_self_tensor now uses the original clean approach, while aten_repeat_interleave_self_int keeps the Tile optimization.

Member: You can't use ArgMax or ScatterND.

Contributor Author: Removed the ArgMax and ScatterND operations as requested. ArgMax was replaced with a ReduceSum-based approach for the searchsorted step and ScatterND with Concat-based pattern building, so the implementation now uses only the allowed ONNX operations while maintaining the same functionality. Fixed in commit 5409040.
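
To illustrate the point above (a NumPy sketch, not part of the change): counting how many cumulative-sum boundaries lie at or before each output position yields the same indices as a right-sided searchsorted, which is what the ReduceSum-based graph below encodes.

import numpy as np

repeats = np.array([2, 3, 1], dtype=np.int64)
cumsum = np.cumsum(repeats)                                      # [2, 5, 6]
positions = np.arange(cumsum[-1])                                # 0 .. 5
via_reducesum = (cumsum[None, :] <= positions[:, None]).sum(axis=1)
via_searchsorted = np.searchsorted(cumsum, positions, side="right")
assert (via_reducesum == via_searchsorted).all()                 # both are [0, 0, 1, 1, 1, 2]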

# Flatten the tensor first
self_flat = op.Reshape(self, [-1])

# Convert repeats to int64 for ONNX compatibility
repeats_int64 = op.Cast(repeats, to=INT64.dtype)

# Build an index tensor where position i holds the index of the input element
# that output position i comes from, then gather from the flattened input.

# Number of input elements
num_elements = op.Size(repeats_int64)

# Cumulative sum of repeats gives the (exclusive) end boundary of each block
cumsum = op.CumSum(repeats_int64, axis=0)
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)

# Range over all output positions
output_positions = op.Range(
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
)

# For each output position, find which input element it belongs to.
# Instead of ArgMax we use sum(cumsum <= output_pos): the number of block
# boundaries at or before a position is exactly the index of the block
# that position falls into.

# Expand for broadcasting
cumsum_expanded = op.Unsqueeze(cumsum, [0]) # [1, num_elements]
positions_expanded = op.Unsqueeze(output_positions, [1]) # [total_size, 1]

# Compare: cumsum <= output_pos (note: LessOrEqual instead of Less)
mask = op.LessOrEqual(
cumsum_expanded, positions_expanded
) # [total_size, num_elements]

# Sum to get the count of cumsum values <= each position
indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)

# Clamp to valid range [0, num_elements-1]
max_index = op.Sub(num_elements, op.Constant(value_ints=[1]))
indices = op.Clip(indices, op.Constant(value_ints=[0]), max_index)

# Gather elements from the flattened tensor
result = op.Gather(self_flat, indices, axis=0)
return result

else:
# Repeat along a specific dimension using the same index computation
repeats_int64 = op.Cast(repeats, to=INT64.dtype)

num_elements = op.Size(repeats_int64)
cumsum = op.CumSum(repeats_int64, axis=0)
total_size = op.Gather(cumsum, op.Constant(value_ints=[-1]), axis=0)

output_positions = op.Range(
op.Constant(value_ints=[0]), total_size, op.Constant(value_ints=[1])
)

cumsum_expanded = op.Unsqueeze(cumsum, [0])
positions_expanded = op.Unsqueeze(output_positions, [1])

mask = op.LessOrEqual(cumsum_expanded, positions_expanded)
indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)

max_index = op.Sub(num_elements, op.Constant(value_ints=[1]))
indices = op.Clip(indices, op.Constant(value_ints=[0]), max_index)

result = op.Gather(self, indices, axis=dim)
return result
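
As a reference for the dim-specified path (a NumPy sketch, not part of the change): gathering with the computed indices along the target axis is what reproduces torch.repeat_interleave(self, repeats, dim).

import numpy as np

x = np.array([[1, 2], [3, 4]])
indices = np.array([0, 1, 1, 1])      # computed from repeats=[1, 3] as shown above
out = np.take(x, indices, axis=0)     # analogous to Gather(self, indices, axis=dim)
assert out.tolist() == [[1, 2], [3, 4], [3, 4], [3, 4]]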


@torch_op("aten::repeat_interleave.self_int", trace_only=True)
def aten_repeat_interleave_self_int(
self: TensorType,
repeats: int,
dim: Optional[int] = None,
output_size: Optional[int] = None,
) -> TensorType:
"""repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""

if dim is None:
# Flatten the tensor first, then repeat each element 'repeats' times
self_flat = op.Reshape(self, [-1])

# Add a new dimension and tile to repeat each element
self_expanded = op.Unsqueeze(self_flat, [1]) # Shape: [num_elements, 1]
repeat_pattern = op.Constant(value_ints=[1, repeats])
tiled = op.Tile(self_expanded, repeat_pattern) # Shape: [num_elements, repeats]
result = op.Reshape(tiled, [-1]) # Shape: [num_elements * repeats]
return result

else:
# Repeat along a specific dimension using a Tile-based approach
# First, get the shape of the input tensor
original_shape = op.Shape(self)

# Use an approach similar to aten_repeat, but for a single dimension
# Add a new dimension after the target dimension
self_expanded = op.Unsqueeze(self, [dim + 1])

# Get the rank and build tile pattern
rank = op.Size(original_shape)
# ConstantOfShape takes its fill value as an attribute rather than an input,
# so broadcast a constant 1 with Expand to build the run of leading ones.
ones_before = op.Expand(
op.Constant(value_ints=[1]),
op.Reshape(
op.Add(op.Constant(value_ints=[dim]), op.Constant(value_ints=[1])), [1]
),
)
repeat_val = op.Constant(value_ints=[repeats])
ones_after = op.Expand(
op.Constant(value_ints=[1]),
op.Reshape(
op.Sub(
rank, op.Add(op.Constant(value_ints=[dim]), op.Constant(value_ints=[1]))
),
[1],
),
)

# Concatenate to build tile pattern: [1, 1, ..., 1, repeats, 1, ..., 1]
tile_pattern = op.Concat(ones_before, repeat_val, ones_after, axis=0)

# Tile the expanded tensor
tiled = op.Tile(self_expanded, tile_pattern)

# Reshape to merge the repeated dimension
# Calculate new shape
target_dim_size = op.Gather(original_shape, op.Constant(value_ints=[dim]))
new_target_size = op.Mul(target_dim_size, op.Constant(value_ints=[repeats]))

# Build new shape by concatenating parts
shape_before = op.Slice(
original_shape, op.Constant(value_ints=[0]), op.Constant(value_ints=[dim])
)
shape_after = op.Slice(
original_shape,
op.Add(op.Constant(value_ints=[dim]), op.Constant(value_ints=[1])),
op.Constant(value_ints=[2147483647]),
)
new_shape = op.Concat(
shape_before, op.Reshape(new_target_size, [1]), shape_after, axis=0
)

result = op.Reshape(tiled, new_shape)
return result
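
A small sketch of the constant-repeats fast path above (NumPy stand-in, not part of the change): inserting an axis, tiling along it, then flattening interleaves each element `repeats` times.

import numpy as np

x = np.array([10, 20, 30])
repeats = 2
tiled = np.tile(x[:, None], (1, repeats))   # shape (3, 2): [[10, 10], [20, 20], [30, 30]]
assert tiled.reshape(-1).tolist() == [10, 10, 20, 20, 30, 30]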


@torch_op("aten::reshape")
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -1249,6 +1249,7 @@ def _where_input_wrangler(
core_ops.aten_remainder,
),
TorchLibOpInfo("repeat", core_ops.aten_repeat),
TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_tensor),
TorchLibOpInfo("reshape", core_ops.aten_reshape),
TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),