@@ -182,30 +182,45 @@ def prims_broadcast_in_dim(
182
182
) -> TensorType :
183
183
"""broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""
184
184
185
- # Get the shape of the input tensor
185
+ # Simplified approach that replaces ScatterElements with more basic operations
186
+ # while still leveraging compile-time knowledge of broadcast_dimensions
187
+
186
188
input_shape = op .Shape (a )
187
189
target_rank = op .Size (shape )
188
190
189
- # Create the intermediate shape by constructing it with the right dimensions
190
- # Start with a shape of all 1s
191
+ if not broadcast_dimensions :
192
+ # Special case: no broadcast dimensions - all target dims should be 1
193
+ ones = op .ConstantOfShape (op .Unsqueeze (target_rank , axes = [0 ]), value = op .Constant (value_int = 1 ))
194
+ reshaped = op .Reshape (a , ones )
195
+ return op .Expand (reshaped , shape )
196
+
197
+ # Build intermediate shape using a simpler approach than ScatterElements
198
+ # We'll construct it by concatenating the right values for each position
199
+
200
+ # Create base shape of all 1s
191
201
ones = op .ConstantOfShape (op .Unsqueeze (target_rank , axes = [0 ]), value = op .Constant (value_int = 1 ))
192
202
193
- # Since broadcast_dimensions is known at compile time , we can create the mapping directly
194
- # Convert broadcast_dimensions and input shape to tensors we can work with
195
- broadcast_dims_tensor = op . Constant ( value_ints = list ( broadcast_dimensions ))
203
+ # For each broadcast dimension , we'll replace the 1 with the actual input dimension
204
+ # Since broadcast_dimensions is compile-time known, we can do this with individual operations
205
+ intermediate_shape = ones
196
206
197
- # Scatter the input dimensions into the intermediate shape at the specified positions
198
- intermediate_shape = op .ScatterElements (
199
- ones ,
200
- op .Unsqueeze (broadcast_dims_tensor , axes = [0 ]),
201
- op .Unsqueeze (input_shape , axes = [0 ]),
202
- axis = 0
203
- )
207
+ for i , broadcast_dim in enumerate (broadcast_dimensions ):
208
+ # Get the input dimension value
209
+ input_dim_value = op .Gather (input_shape , op .Constant (value_int = i ))
210
+
211
+ # Create a one-hot mask for this position
212
+ indices = op .Range (op .Constant (value_int = 0 ), target_rank , op .Constant (value_int = 1 ))
213
+ mask = op .Equal (indices , op .Constant (value_int = broadcast_dim ))
214
+
215
+ # Use Where to replace the 1 with the input dimension value at this position
216
+ intermediate_shape = op .Where (
217
+ mask ,
218
+ op .Cast (input_dim_value , to = ir .TensorType .INT64 ),
219
+ intermediate_shape
220
+ )
204
221
205
- # Reshape the input tensor to the intermediate shape
222
+ # Reshape input to intermediate shape and expand to target
206
223
reshaped = op .Reshape (a , intermediate_shape )
207
-
208
- # Expand to the target shape
209
224
return op .Expand (reshaped , shape )
210
225
211
226
0 commit comments