25 changes: 23 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/prims.py
@@ -176,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType:
    raise NotImplementedError()


+@torch_op("prims::broadcast_in_dim", trace_only=True)
def prims_broadcast_in_dim(
-    a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
+    a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int]
) -> TensorType:
    """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""

-    raise NotImplementedError()
+    target_rank = len(shape)
+
+    if not broadcast_dimensions:
+        # Special case: empty broadcast_dimensions means the input is a scalar; expand it directly
+        return op.Expand(a, common_ops.merge_dims(shape))
+
+    # Create a base shape of all 1s
+    ones = [1] * target_rank
+
+    # Replace each 1 with the corresponding input dimension.
+    # broadcast_dimensions is known at trace time, so this can be a plain Python loop
+    intermediate_shape = ones
+
+    for i, broadcast_dim in enumerate(broadcast_dimensions):
+        # Size of input dimension i, as a 1-D tensor
+        input_dim_value = op.Shape(a, start=i, end=i + 1)
+        intermediate_shape[broadcast_dim] = input_dim_value
+
+    # Reshape the input to the intermediate shape, then expand to the target shape
+    reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape))
+    return op.Expand(reshaped, shape)


def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType:
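For intuition, the sequence implemented above (build an all-1s intermediate shape, Reshape, then Expand) mirrors how broadcast_in_dim can be emulated in eager PyTorch. A minimal sketch for comparison; broadcast_in_dim_reference is an illustrative helper, not part of this change:

import torch

def broadcast_in_dim_reference(a, shape, broadcast_dimensions):
    # All-1s intermediate shape, with each input dim placed at the axis
    # named by broadcast_dimensions (mirrors the Reshape target above).
    intermediate = [1] * len(shape)
    for i, dim in enumerate(broadcast_dimensions):
        intermediate[dim] = a.shape[i]
    # Reshape, then broadcast to the target shape (mirrors Reshape + Expand).
    return torch.broadcast_to(a.reshape(intermediate), shape)

a = torch.arange(3.0).reshape(3, 1)                     # input shape (3, 1)
out = broadcast_in_dim_reference(a, (2, 3, 4), (1, 2))  # output shape (2, 3, 4)
expected = torch.ops.prims.broadcast_in_dim(a, (2, 3, 4), (1, 2))
torch.testing.assert_close(out, expected)

The final Expand relies on standard broadcasting from the size-1 axes, which is exactly what ONNX Expand does with the target shape.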
36 changes: 36 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -87,6 +87,35 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad
yield opinfo_core.SampleInput(t, kwargs={"p": p})


+def sample_inputs_broadcast_in_dim(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+
+    # cases: (input_shape, target_shape, broadcast_dimensions)
+    # broadcast_dimensions maps each input dim to an axis in target_shape
+    cases = (
+        # scalar -> 1-D tensor
+        ((), (3,), ()),
+        # identity (no-op broadcast)
+        ((3,), (3,), (0,)),
+        # rank-preserving broadcast where singleton dims expand
+        ((1, 3, 1), (2, 3, 4), (0, 1, 2)),
+        # input rank 2 -> output rank 3, input dims map to trailing axes
+        ((3, 1), (2, 3, 4), (1, 2)),
+        # add leading broadcast axis
+        ((3, 4), (1, 3, 4), (1, 2)),
+        # insert broadcasting in middle axis
+        ((3,), (2, 3, 1), (1,)),
+    )
+    make_arg = functools.partial(
+        torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    for shape, target_shape, broadcast_dimensions in cases:
+        tensor = make_arg(shape)
+        yield opinfo_core.SampleInput(tensor, args=(target_shape, broadcast_dimensions))
+
+
def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs):
del op_info
# input_shape, output_size, kernal, dilation, padding, stride
@@ -2687,6 +2716,13 @@ def __init__(self):
        sample_inputs_func=sample_inputs_upsample_trilinear3d_vec,
        supports_out=False,
    ),
+    opinfo_core.OpInfo(
+        "ops.prims.broadcast_in_dim.default",
+        op=torch.ops.prims.broadcast_in_dim.default,
+        dtypes=common_dtype.all_types(),
+        sample_inputs_func=sample_inputs_broadcast_in_dim,
+        supports_out=False,
+    ),
    opinfo_core.ReductionOpInfo(
        "ops.prims.var.default",
        nan_policy="propagate",
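Each (input_shape, target_shape, broadcast_dimensions) case above can be sanity-checked against the eager prim: the output shape should equal target_shape in every case. A small standalone sketch with the cases copied from sample_inputs_broadcast_in_dim; the loop is illustrative and not part of the test suite:

import torch

cases = (
    ((), (3,), ()),
    ((3,), (3,), (0,)),
    ((1, 3, 1), (2, 3, 4), (0, 1, 2)),
    ((3, 1), (2, 3, 4), (1, 2)),
    ((3, 4), (1, 3, 4), (1, 2)),
    ((3,), (2, 3, 1), (1,)),
)

for input_shape, target_shape, broadcast_dimensions in cases:
    t = torch.ones(input_shape)
    out = torch.ops.prims.broadcast_in_dim(t, target_shape, broadcast_dimensions)
    assert out.shape == target_shape, (input_shape, target_shape, broadcast_dimensions)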
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -2136,6 +2136,7 @@ def _where_input_wrangler(
"Our implementation is based on that for CUDA"
),
),
TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim),
TorchLibOpInfo(
"ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}
),