@@ -3545,10 +3545,6 @@ def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
3545
3545
# Feature dropout applies dropout to entire feature maps/channels
3546
3546
# rather than individual elements
3547
3547
3548
- # Use ONNX operations to handle control flow
3549
- # In inference mode or when p=0, return input unchanged
3550
- should_dropout = op .And (train , p > 0.0 )
3551
-
3552
3548
# Get input shape
3553
3549
input_shape = op .Shape (input )
3554
3550
ndim = op .Size (input_shape )
@@ -3576,27 +3572,13 @@ def aten_feature_dropout(input: TFloat, p: FLOAT, train: BOOL) -> TFloat:
3576
3572
# Select appropriate mask shape
3577
3573
mask_shape = op .Where (is_2d , mask_shape_2d , mask_shape_nd )
3578
3574
3579
- # Generate random uniform values between 0 and 1
3580
- random_vals = op .RandomUniformLike (
3581
- op .ConstantOfShape (mask_shape , value = 0.0 ),
3582
- dtype = 1 , # float32
3583
- low = 0.0 ,
3584
- high = 1.0
3585
- )
3586
-
3587
- # Create binary mask: 1 where random_vals >= p, 0 otherwise
3588
- mask = op .Cast (random_vals >= p , to = input .dtype )
3589
-
3590
- # Scale by 1/(1-p) to maintain expected value
3591
- scale = op .Div (1.0 , op .Sub (1.0 , p ))
3592
- scaled_mask = op .Mul (mask , scale )
3593
-
3594
- # Apply dropout only if we should dropout, otherwise use all-ones mask
3595
- ones_mask = op .ConstantOfShape (mask_shape , value = 1.0 )
3596
- final_mask = op .Where (should_dropout , scaled_mask , ones_mask )
3575
+ # Create a dummy tensor of ones with the mask shape and apply dropout to it
3576
+ # This leverages op.Dropout to handle training mode, scaling, and random generation
3577
+ dummy_tensor = op .ConstantOfShape (mask_shape , value = 1.0 )
3578
+ mask , _ = op .Dropout (dummy_tensor , p , train )
3597
3579
3598
3580
# Apply mask to input (broadcasting will handle different shapes)
3599
- result = op .Mul (input , final_mask )
3581
+ result = op .Mul (input , mask )
3600
3582
3601
3583
return result
3602
3584
0 commit comments