Skip to content

[torchlib] Implement torch.ops.prims.broadcast_in_dim.default #2382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
46 changes: 43 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import Optional, Sequence

from onnxscript import INT64
from onnxscript import INT64, ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import RealType, TTensor
Expand Down Expand Up @@ -176,12 +176,52 @@
raise NotImplementedError()


@torch_op("prims::broadcast_in_dim", trace_only=True)
def prims_broadcast_in_dim(
a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int]
) -> TensorType:
"""broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)"""

raise NotImplementedError()

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Simplified approach that replaces ScatterElements with more basic operations
# while still leveraging compile-time knowledge of broadcast_dimensions

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

input_shape = op.Shape(a)
target_rank = len(shape)

Check warning on line 189 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L188-L189

Added lines #L188 - L189 were not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
if not broadcast_dimensions:
# Special case: no broadcast dimensions - all target dims should be 1

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

ones = op.ConstantOfShape(op.Constant(value_ints=[target_rank]), value=op.Constant(value_int=1))
reshaped = op.Reshape(a, ones)
return op.Expand(reshaped, shape)

Check warning on line 195 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L193-L195

Added lines #L193 - L195 were not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Build intermediate shape using a simpler approach than ScatterElements
# We'll construct it by concatenating the right values for each position

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# Create base shape of all 1s
ones = [1] * target_rank

Check warning on line 201 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L201

Added line #L201 was not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# For each broadcast dimension, we'll replace the 1 with the actual input dimension
# Since broadcast_dimensions is compile-time known, we can do this with individual operations
intermediate_shape = ones

Check warning on line 205 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L205

Added line #L205 was not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
for i, broadcast_dim in enumerate(broadcast_dimensions):
# Get the input dimension value
input_dim_value = op.Gather(input_shape, op.Constant(value_int=i))

Check warning on line 209 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L209

Added line #L209 was not covered by tests

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# Create a one-hot mask for this position
indices = op.Range(op.Constant(value_int=0), op.Constant(value_int=target_rank), op.Constant(value_int=1))
mask = op.Equal(indices, op.Constant(value_int=broadcast_dim))

Check warning on line 213 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L212-L213

Added lines #L212 - L213 were not covered by tests

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Use Where to replace the 1 with the input dimension value at this position
intermediate_shape = op.Where(

Check warning on line 216 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L216

Added line #L216 was not covered by tests
mask,
op.Cast(input_dim_value, to=ir.TensorType.INT64),
intermediate_shape
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
# Reshape input to intermediate shape and expand to target
reshaped = op.Reshape(a, intermediate_shape)
return op.Expand(reshaped, shape)

Check warning on line 224 in onnxscript/function_libs/torch_lib/ops/prims.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/prims.py#L223-L224

Added lines #L223 - L224 were not covered by tests


def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType:
Expand Down
Loading