diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 92b8abb36d..36176ef20b 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7280,12 +7280,211 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
     return op.Tile(self_expanded, repeats)
 
 
+@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
 def aten_repeat_interleave(
     repeats: TensorType, output_size: Optional[int] = None
 ) -> TensorType:
     """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor"""
 
-    raise NotImplementedError()
+    # Convert repeats to int64 for ONNX compatibility
+    repeats_int64 = op.Cast(repeats, to=INT64.dtype)
+
+    # Cumulative sum of repeats gives the boundary of each element's run
+    cumsum = op.CumSum(repeats_int64, axis=0)
+    # Gather with a scalar index yields the scalar total size Range needs
+    total_size = op.Gather(cumsum, op.Constant(value_int=-1), axis=0)
+
+    # Create output tensor indices (Range expects scalar start/limit/delta)
+    output_range = op.Range(
+        op.Constant(value_int=0), total_size, op.Constant(value_int=1)
+    )
+
+    # Find which original index each output position corresponds to,
+    # using the same approach as the self_Tensor version below
+    num_elements = op.Size(repeats_int64)
+
+    cumsum_expanded = op.Unsqueeze(cumsum, [0])  # [1, num_elements]
+    output_expanded = op.Unsqueeze(output_range, [1])  # [total_size, 1]
+
+    # Use LessOrEqual to find cumsum <= output_pos
+    mask = op.LessOrEqual(cumsum_expanded, output_expanded)  # [total_size, num_elements]
+
+    # Sum to get the count of cumsum values <= each position
+    result_indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)
+
+    # Clamp to the valid range [0, num_elements-1] (Clip expects scalar min/max)
+    max_index = op.Sub(num_elements, op.Constant(value_int=1))
+    result_indices = op.Clip(result_indices, op.Constant(value_int=0), max_index)
+
+    return result_indices
+
+
+@torch_op("aten::repeat_interleave.self_Tensor", trace_only=True)
+def aten_repeat_interleave_self_tensor(
+    self: TensorType,
+    repeats: TensorType,
+    dim: Optional[int] = None,
+    output_size: Optional[int] = None,
+) -> TensorType:
+    """repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
+
+    if dim is None:
+        # Flatten the tensor first
+        self_flat = op.Reshape(self, [-1])
+
+        # Convert repeats to int64 for ONNX compatibility
+        repeats_int64 = op.Cast(repeats, to=INT64.dtype)
+        num_elements = op.Size(repeats_int64)
+
+        # Data-dependent loops are unavailable here, so build gather indices
+        # by "unrolling" the repeats: position i of the index tensor holds
+        # the input element index for output position i
+
+        # Cumulative sum gives the boundary of each element's run of repeats
+        cumsum = op.CumSum(repeats_int64, axis=0)
+        total_size = op.Gather(cumsum, op.Constant(value_int=-1), axis=0)
+
+        # Range over all output positions (Range expects scalar inputs)
+        output_positions = op.Range(
+            op.Constant(value_int=0), total_size, op.Constant(value_int=1)
+        )
+
+        # For each output position, find the element it belongs to. Instead
+        # of ArgMax, use sum(cumsum <= output_pos): the count of elements
+        # whose cumsum is <= output_pos, which means output_pos belongs to
+        # the next element
+
+        # Expand for broadcasting
+        cumsum_expanded = op.Unsqueeze(cumsum, [0])  # [1, num_elements]
+        positions_expanded = op.Unsqueeze(output_positions, [1])  # [total_size, 1]
+
+        # Compare: cumsum <= output_pos (note: LessOrEqual instead of Less)
+        mask = op.LessOrEqual(
+            cumsum_expanded, positions_expanded
+        )  # [total_size, num_elements]
+
+        # Sum to get the count of cumsum values <= each position
+        indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)
+
+        # Clamp to the valid range [0, num_elements-1]
+        max_index = op.Sub(num_elements, op.Constant(value_int=1))
+        indices = op.Clip(indices, op.Constant(value_int=0), max_index)
+
+        # Gather elements from the flattened tensor
+        result = op.Gather(self_flat, indices, axis=0)
+        return result
+
+    else:
+        # Repeat along a specific dimension using the same index construction
+        repeats_int64 = op.Cast(repeats, to=INT64.dtype)
+
+        num_elements = op.Size(repeats_int64)
+        cumsum = op.CumSum(repeats_int64, axis=0)
+        total_size = op.Gather(cumsum, op.Constant(value_int=-1), axis=0)
+
+        output_positions = op.Range(
+            op.Constant(value_int=0), total_size, op.Constant(value_int=1)
+        )
+
+        cumsum_expanded = op.Unsqueeze(cumsum, [0])
+        positions_expanded = op.Unsqueeze(output_positions, [1])
+
+        mask = op.LessOrEqual(cumsum_expanded, positions_expanded)
+        indices = op.ReduceSum(op.Cast(mask, to=INT64.dtype), axes=[1], keepdims=False)
+
+        max_index = op.Sub(num_elements, op.Constant(value_int=1))
+        indices = op.Clip(indices, op.Constant(value_int=0), max_index)
+
+        result = op.Gather(self, indices, axis=dim)
+        return result
+
+
+@torch_op("aten::repeat_interleave.self_int", trace_only=True)
+def aten_repeat_interleave_self_int(
+    self: TensorType,
+    repeats: int,
+    dim: Optional[int] = None,
+    output_size: Optional[int] = None,
+) -> TensorType:
+    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor"""
+
+    if dim is None:
+        # Flatten the tensor first, then repeat each element 'repeats' times
+        self_flat = op.Reshape(self, [-1])
+
+        # Add a new dimension and tile to repeat each element
+        self_expanded = op.Unsqueeze(self_flat, [1])  # Shape: [num_elements, 1]
+        repeat_pattern = op.Constant(value_ints=[1, repeats])
+        tiled = op.Tile(self_expanded, repeat_pattern)  # Shape: [num_elements, repeats]
+        result = op.Reshape(tiled, [-1])  # Shape: [num_elements * repeats]
+        return result
+
+    else:
+        # Repeat along a specific dimension: insert a new axis after `dim`,
+        # tile it `repeats` times, then merge it back into `dim`
+        original_shape = op.Shape(self)
+
+        # Add a new dimension after the target dimension
+        self_expanded = op.Unsqueeze(self, [dim + 1])
+
+        # Build the tile pattern [1, ..., 1, repeats, 1, ..., 1] with `repeats`
+        # at position dim + 1. `dim` and `repeats` are static here, but the
+        # rank is not, so the trailing run of ones is built with Expand
+        # (ConstantOfShape takes its fill value as an attribute, not an input)
+        rank = op.Size(original_shape)
+        ones_before = op.Constant(value_ints=[1] * (dim + 1))
+        repeat_val = op.Constant(value_ints=[repeats])
+        ones_after = op.Expand(
+            op.Constant(value_ints=[1]),
+            op.Reshape(op.Sub(rank, op.Constant(value_int=dim + 1)), [1]),
+        )
+        tile_pattern = op.Concat(ones_before, repeat_val, ones_after, axis=0)
+
+        # Tile the expanded tensor
+        tiled = op.Tile(self_expanded, tile_pattern)
+
+        # Reshape to merge the repeated dimension: the new size at `dim` is
+        # target_dim_size * repeats
+        target_dim_size = op.Gather(original_shape, op.Constant(value_ints=[dim]))
+        new_target_size = op.Mul(target_dim_size, op.Constant(value_ints=[repeats]))
+
+        # Build the new shape by concatenating the parts around `dim`
+        shape_before = op.Slice(
+            original_shape, op.Constant(value_ints=[0]), op.Constant(value_ints=[dim])
+        )
+        shape_after = op.Slice(
+            original_shape,
+            op.Constant(value_ints=[dim + 1]),
+            op.Constant(value_ints=[2147483647]),
+        )
+        new_shape = op.Concat(
+            shape_before, op.Reshape(new_target_size, [1]), shape_after, axis=0
+        )
+
+        result = op.Reshape(tiled, new_shape)
+        return result
 
 
 @torch_op("aten::reshape")
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 73ea68116c..62cecad00e 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1249,6 +1249,7 @@ def _where_input_wrangler(
         core_ops.aten_remainder,
     ),
     TorchLibOpInfo("repeat", core_ops.aten_repeat),
+    TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_tensor),
     TorchLibOpInfo("reshape", core_ops.aten_reshape),
     TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj),
    TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg),
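
Review note: all three overloads lean on the same cumsum + LessOrEqual + ReduceSum trick to turn `repeats` into gather indices. A minimal NumPy sketch of that trick for sanity-checking against `np.repeat` (an editorial illustration, not part of the change; the helper name `repeat_interleave_indices` is made up):

import numpy as np

def repeat_interleave_indices(repeats: np.ndarray) -> np.ndarray:
    """The index for output position p is the number of cumsum entries <= p,
    clamped to [0, len(repeats) - 1], mirroring the ONNX graph above."""
    repeats = repeats.astype(np.int64)
    cumsum = np.cumsum(repeats)                   # e.g. [2, 3, 1] -> [2, 5, 6]
    positions = np.arange(cumsum[-1])             # output positions 0..total-1
    mask = cumsum[None, :] <= positions[:, None]  # [total_size, num_elements]
    indices = mask.sum(axis=1)
    return np.clip(indices, 0, len(repeats) - 1)

# [2, 3, 1] should expand element indices 0, 1, 2 into [0, 0, 1, 1, 1, 2]
expected = np.repeat(np.arange(3), [2, 3, 1])
assert (repeat_interleave_indices(np.array([2, 3, 1])) == expected).all()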
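
The Unsqueeze/Tile/Reshape path of `aten_repeat_interleave_self_int` can be checked the same way (continuing the NumPy sketch above; again an illustration, not code from the diff):

def repeat_interleave_int(x: np.ndarray, repeats: int, dim: int) -> np.ndarray:
    """Insert a new axis after `dim`, tile it `repeats` times, then merge it
    back into `dim`, matching the ONNX Unsqueeze/Tile/Reshape sequence."""
    expanded = np.expand_dims(x, dim + 1)    # Unsqueeze at dim + 1
    tile_pattern = [1] * expanded.ndim
    tile_pattern[dim + 1] = repeats          # [1, ..., 1, repeats, 1, ..., 1]
    tiled = np.tile(expanded, tile_pattern)  # Tile
    new_shape = list(x.shape)
    new_shape[dim] *= repeats                # merge the repeats into dim
    return tiled.reshape(new_shape)          # Reshape

x = np.arange(6).reshape(2, 3)
assert (repeat_interleave_int(x, repeats=2, dim=0) == np.repeat(x, 2, axis=0)).all()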