@@ -4444,12 +4444,18 @@ def test_jagged_op_different_output_shape_dim(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
+    @parametrize(
+        "func",
+        [torch.nn.functional.softmax, torch.nn.functional.log_softmax],
+        name_fn=lambda func: func.__name__,
+    )
     def test_softmax_dim(
         self,
         device,
         dtype,
         requires_grad,
         components_require_grad,
+        func,
     ):
         """
         Softmax passes when reducing on valid reduction dimensions.
@@ -4468,7 +4474,7 @@ def test_softmax_dim(
 
         for reduce_dim, _ in reduce_dims:
             nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
-            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
+            out_actual = func(nt, dim=reduce_dim)
             torch._dynamo.disable(self.assertEqual)(
                 len(out_actual.shape), len(output_shape)
             )  # disable if running on dynamo
@@ -4498,12 +4504,10 @@ def test_softmax_dim(
             reduce_dim, reduce_dim_expected = reduce_dim_tuple
 
             if nt.dim() > reduce_dim:
-                out_actual = torch.nn.functional.softmax(
-                    nt, dim=reduce_dim
-                )  # nested tensor
-                out_expected = torch.nn.functional.softmax(
-                    nt.values(), dim=reduce_dim_expected
-                )  # dense tensor of dimensions 1 less than out_actual
+                # nested tensor
+                out_actual = func(nt, dim=reduce_dim)
+                # dense tensor of dimensions 1 less than out_actual
+                out_expected = func(nt.values(), dim=reduce_dim_expected)
                 self.assertTrue(
                     torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
                 )
@@ -4601,8 +4605,13 @@ def test_softmax_dim_reduce_ragged_idx_1(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
+    @parametrize(
+        "func",
+        [torch.nn.functional.softmax, torch.nn.functional.log_softmax],
+        name_fn=lambda func: func.__name__,
+    )
     def test_softmax_reduce_batch_dim(
-        self, device, dtype, requires_grad, components_require_grad
+        self, device, dtype, requires_grad, components_require_grad, func
     ):
         """
         Softmax on NestedTensor fails when trying to reduce across batch dimension.
@@ -4627,7 +4636,7 @@ def test_softmax_reduce_batch_dim(
                 RuntimeError,
                 "not supported when reducing across the batch dimension for NestedTensor",
             ):
-                out = torch.nn.functional.softmax(nt, dim=reduce_dim)
+                out = func(nt, dim=reduce_dim)
 
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
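
Side note (not part of the diff): a minimal standalone sketch of the property these parametrized tests assert, assuming a PyTorch build where both ops are supported for the jagged layout. The tensor shapes and the dim choices below are illustrative assumptions, not taken from the PR.

# Sketch: softmax / log_softmax over a non-batch, non-ragged dim of a jagged
# NestedTensor should match the same op applied to the packed values() buffer,
# with the dim index shifted down by one (batch and ragged dims are folded
# together in values()).
import torch
import torch.nn.functional as F

ts = [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)]   # illustrative shapes
nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)      # shape (B, j1, 5)

for func in (F.softmax, F.log_softmax):
    out_nt = func(nt, dim=2)               # reduce over the trailing dim of the NJT
    out_dense = func(nt.values(), dim=1)   # same reduction on the (sum(j1), 5) buffer
    assert torch.allclose(out_nt.values().view(-1), out_dense.view(-1))

The name_fn passed to @parametrize presumably just makes the generated test IDs carry a readable suffix such as _softmax / _log_softmax instead of an index.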