Skip to content

Commit 5c35fe1

Browse files
Copilot authored and justinchuby committed
Simplify broadcast_in_dim to avoid ScatterElements by using Where operations
Co-authored-by: justinchuby <[email protected]>
1 parent 6567cfb commit 5c35fe1

File tree

1 file changed

+31
-16
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+31
-16
lines changed

onnxscript/function_libs/torch_lib/ops/prims.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -182,30 +182,45 @@ def prims_broadcast_in_dim(
182182
) -> TensorType:
183183
"""broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""
184184

185-
# Get the shape of the input tensor
185+
# Simplified approach that replaces ScatterElements with more basic operations
186+
# while still leveraging compile-time knowledge of broadcast_dimensions
187+
186188
input_shape = op.Shape(a)
187189
target_rank = op.Size(shape)
188190

189-
# Create the intermediate shape by constructing it with the right dimensions
190-
# Start with a shape of all 1s
191+
if not broadcast_dimensions:
192+
# Special case: no broadcast dimensions - all target dims should be 1
193+
ones = op.ConstantOfShape(op.Unsqueeze(target_rank, axes=[0]), value=op.Constant(value_int=1))
194+
reshaped = op.Reshape(a, ones)
195+
return op.Expand(reshaped, shape)
196+
197+
# Build intermediate shape using a simpler approach than ScatterElements
198+
# We'll construct it by concatenating the right values for each position
199+
200+
# Create base shape of all 1s
191201
ones = op.ConstantOfShape(op.Unsqueeze(target_rank, axes=[0]), value=op.Constant(value_int=1))
192202

193-
# Since broadcast_dimensions is known at compile time, we can create the mapping directly
194-
# Convert broadcast_dimensions and input shape to tensors we can work with
195-
broadcast_dims_tensor = op.Constant(value_ints=list(broadcast_dimensions))
203+
# For each broadcast dimension, we'll replace the 1 with the actual input dimension
204+
# Since broadcast_dimensions is compile-time known, we can do this with individual operations
205+
intermediate_shape = ones
196206

197-
# Scatter the input dimensions into the intermediate shape at the specified positions
198-
intermediate_shape = op.ScatterElements(
199-
ones,
200-
op.Unsqueeze(broadcast_dims_tensor, axes=[0]),
201-
op.Unsqueeze(input_shape, axes=[0]),
202-
axis=0
203-
)
207+
for i, broadcast_dim in enumerate(broadcast_dimensions):
208+
# Get the input dimension value
209+
input_dim_value = op.Gather(input_shape, op.Constant(value_int=i))
210+
211+
# Create a one-hot mask for this position
212+
indices = op.Range(op.Constant(value_int=0), target_rank, op.Constant(value_int=1))
213+
mask = op.Equal(indices, op.Constant(value_int=broadcast_dim))
214+
215+
# Use Where to replace the 1 with the input dimension value at this position
216+
intermediate_shape = op.Where(
217+
mask,
218+
op.Cast(input_dim_value, to=ir.TensorType.INT64),
219+
intermediate_shape
220+
)
204221

205-
# Reshape the input tensor to the intermediate shape
222+
# Reshape input to intermediate shape and expand to target
206223
reshaped = op.Reshape(a, intermediate_shape)
207-
208-
# Expand to the target shape
209224
return op.Expand(reshaped, shape)
210225

211226

0 commit comments

Comments (0)