diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py
index ed870b0d7..af41c246e 100644
--- a/onnxscript/function_libs/torch_lib/ops/prims.py
+++ b/onnxscript/function_libs/torch_lib/ops/prims.py
@@ -14,7 +14,7 @@
 
 from typing import Optional, Sequence
 
-from onnxscript import INT64
+from onnxscript import INT64, ir
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor
@@ -176,12 +176,51 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("prims::broadcast_in_dim", trace_only=True)
 def prims_broadcast_in_dim(
     a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
 ) -> TensorType:
     """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""
-
-    raise NotImplementedError()
+
+    # Build the intermediate shape with basic ops instead of ScatterElements,
+    # exploiting the compile-time knowledge of broadcast_dimensions
+
+    input_shape = op.Shape(a)
+    target_rank = len(shape)
+
+    if not broadcast_dimensions:
+        # Special case: no broadcast dimensions, so every intermediate dim is 1
+        ones = op.Constant(value_ints=[1] * target_rank)
+        reshaped = op.Reshape(a, ones)
+        return op.Expand(reshaped, shape)
+
+    # Start from an all-ones shape and, for each broadcast dimension, mask in
+    # the corresponding input dimension at that position. Since
+    # broadcast_dimensions is compile-time known, each step is a few basic ops.
+    intermediate_shape = op.Constant(value_ints=[1] * target_rank)
+    indices = op.Range(
+        op.Constant(value_int=0),
+        op.Constant(value_int=target_rank),
+        op.Constant(value_int=1),
+    )
+
+    for i, broadcast_dim in enumerate(broadcast_dimensions):
+        # Get the i-th input dimension value
+        input_dim_value = op.Gather(input_shape, op.Constant(value_int=i))
+
+        # One-hot mask selecting position broadcast_dim in the target shape
+        mask = op.Equal(indices, op.Constant(value_int=broadcast_dim))
+
+        # Use Where to replace the 1 at broadcast_dim with the input dimension
+        intermediate_shape = op.Where(
+            mask,
+            op.Cast(input_dim_value, to=ir.DataType.INT64),
+            intermediate_shape,
+        )
+
+    # Reshape input to the intermediate shape and expand to the target
+    reshaped = op.Reshape(a, intermediate_shape)
+    return op.Expand(reshaped, shape)
 
 
 def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType:
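
For reference, the reshape-then-expand decomposition the new implementation
performs can be checked against plain numpy. The sketch below is illustrative
only (it is not part of the diff), and the helper name broadcast_in_dim_ref is
made up:

    import numpy as np

    def broadcast_in_dim_ref(a, shape, broadcast_dimensions):
        # Intermediate shape: all ones, with the input dims placed at the
        # positions named by broadcast_dimensions (mirrors the Where loop).
        intermediate = [1] * len(shape)
        for i, dim in enumerate(broadcast_dimensions):
            intermediate[dim] = a.shape[i]
        # Reshape to the intermediate shape, then broadcast (Expand) to target.
        return np.broadcast_to(a.reshape(intermediate), shape)

    out = broadcast_in_dim_ref(np.arange(3), (2, 3, 4), broadcast_dimensions=(1,))
    assert out.shape == (2, 3, 4)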