Skip to content

Commit 96f78cd

Browse files
Copilot authored and justinchuby committed
Refactor feature_dropout to use op.Shape with start/end parameters instead of op.Gather
Co-authored-by: justinchuby <[email protected]>
1 parent af8d622 commit 96f78cd

File tree

1 file changed

+4
-5
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+4
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3548,15 +3548,14 @@ def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
35483548
if p == 0 or not train:
35493549
return input
35503550

3551-
# Get input shape
3552-
input_shape = op.Shape(input)
3553-
ndim = op.Size(input_shape)
3551+
# Get input dimensions
3552+
ndim = op.Size(op.Shape(input))
35543553

35553554
# Create mask shape for feature dropout
35563555
# For 2D tensors [N, C]: mask shape is [N, C]
35573556
# For higher dim tensors [N, C, ...]: mask shape is [N, C, 1, 1, ...]
3558-
batch_size = op.Gather(input_shape, [0])
3559-
channel_size = op.Gather(input_shape, [1])
3557+
batch_size = op.Shape(input, start=0, end=1)
3558+
channel_size = op.Shape(input, start=1, end=2)
35603559

35613560
# Create the appropriate mask shape based on tensor dimensions
35623561
is_2d = op.Equal(ndim, 2)

0 commit comments

Comments (0)