Skip to content

Commit 29f1968

Browse files
Copilotjustinchuby
andcommitted
Refactor aten_feature_dropout to use op.Dropout for cleaner implementation
Co-authored-by: justinchuby <[email protected]>
1 parent 658f966 commit 29f1968

File tree

1 file changed

+5
-23
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+5
-23
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3545,10 +3545,6 @@ def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
35453545
# Feature dropout applies dropout to entire feature maps/channels
35463546
# rather than individual elements
35473547

3548-
# Use ONNX operations to handle control flow
3549-
# In inference mode or when p=0, return input unchanged
3550-
should_dropout = op.And(train, p > 0.0)
3551-
35523548
# Get input shape
35533549
input_shape = op.Shape(input)
35543550
ndim = op.Size(input_shape)
@@ -3576,27 +3572,13 @@ def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
35763572
# Select appropriate mask shape
35773573
mask_shape = op.Where(is_2d, mask_shape_2d, mask_shape_nd)
35783574

3579-
# Generate random uniform values between 0 and 1
3580-
random_vals = op.RandomUniformLike(
3581-
op.ConstantOfShape(mask_shape, value=0.0),
3582-
dtype=1, # float32
3583-
low=0.0,
3584-
high=1.0
3585-
)
3586-
3587-
# Create binary mask: 1 where random_vals >= p, 0 otherwise
3588-
mask = op.Cast(random_vals >= p, to=input.dtype)
3589-
3590-
# Scale by 1/(1-p) to maintain expected value
3591-
scale = op.Div(1.0, op.Sub(1.0, p))
3592-
scaled_mask = op.Mul(mask, scale)
3593-
3594-
# Apply dropout only if we should dropout, otherwise use all-ones mask
3595-
ones_mask = op.ConstantOfShape(mask_shape, value=1.0)
3596-
final_mask = op.Where(should_dropout, scaled_mask, ones_mask)
3575+
# Create a dummy tensor of ones with the mask shape and apply dropout to it
3576+
# This leverages op.Dropout to handle training mode, scaling, and random generation
3577+
dummy_tensor = op.ConstantOfShape(mask_shape, value=1.0)
3578+
mask, _ = op.Dropout(dummy_tensor, p, train)
35973579

35983580
# Apply mask to input (broadcasting will handle different shapes)
3599-
result = op.Mul(input, final_mask)
3581+
result = op.Mul(input, mask)
36003582

36013583
return result
36023584

0 commit comments

Comments
 (0)