
Commit 72f1ef4

alanhdu authored and markc-614 committed
Add minimal nn.functional.log_softmax support for NestedTensor (pytorch#159662)
This only works for the jagged layout and for the non-batch and non-jagged dimensions. I did this mostly by copy-pasting from the existing softmax implementation, but it seems fairly straightforward and I think it should work.

Pull Request resolved: pytorch#159662
Approved by: https://github.com/jbschlosser
1 parent 28397ee commit 72f1ef4
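For context, here is a minimal usage sketch of the support this commit adds, assuming a 3-D jagged-layout NestedTensor whose ragged dimension is dim 1 (the component shapes below are purely illustrative):

import torch

# Two variable-length rows packed into a jagged-layout NestedTensor:
# logical shape (B=2, j1, 4), where dim 1 is the ragged dimension.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)

# Reducing over a regular (non-batch, non-ragged) dimension is what this commit enables.
out = torch.nn.functional.log_softmax(nt, dim=2)
assert out.is_nested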

2 files changed: +58 -9 lines

test/test_nestedtensor.py

Lines changed: 18 additions & 9 deletions
@@ -4444,12 +4444,18 @@ def test_jagged_op_different_output_shape_dim(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
+    @parametrize(
+        "func",
+        [torch.nn.functional.softmax, torch.nn.functional.log_softmax],
+        name_fn=lambda func: func.__name__,
+    )
     def test_softmax_dim(
         self,
         device,
         dtype,
         requires_grad,
         components_require_grad,
+        func,
     ):
         """
         Softmax passes when reducing on valid reduction dimensions.
@@ -4468,7 +4474,7 @@ def test_softmax_dim(
 
         for reduce_dim, _ in reduce_dims:
             nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
-            out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
+            out_actual = func(nt, dim=reduce_dim)
             torch._dynamo.disable(self.assertEqual)(
                 len(out_actual.shape), len(output_shape)
             )  # disable if running on dynamo
@@ -4498,12 +4504,10 @@ def test_softmax_dim(
             reduce_dim, reduce_dim_expected = reduce_dim_tuple
 
             if nt.dim() > reduce_dim:
-                out_actual = torch.nn.functional.softmax(
-                    nt, dim=reduce_dim
-                )  # nested tensor
-                out_expected = torch.nn.functional.softmax(
-                    nt.values(), dim=reduce_dim_expected
-                )  # dense tensor of dimensions 1 less than out_actual
+                # nested tensor
+                out_actual = func(nt, dim=reduce_dim)
+                # dense tensor of dimensions 1 less than out_actual
+                out_expected = func(nt.values(), dim=reduce_dim_expected)
                 self.assertTrue(
                     torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
                 )
@@ -4601,8 +4605,13 @@ def test_softmax_dim_reduce_ragged_idx_1(
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])
     @parametrize("components_require_grad", [False, True])
+    @parametrize(
+        "func",
+        [torch.nn.functional.softmax, torch.nn.functional.log_softmax],
+        name_fn=lambda func: func.__name__,
+    )
     def test_softmax_reduce_batch_dim(
-        self, device, dtype, requires_grad, components_require_grad
+        self, device, dtype, requires_grad, components_require_grad, func
     ):
         """
         Softmax on NestedTensor fails when trying to reduce across batch dimension.
@@ -4627,7 +4636,7 @@ def test_softmax_reduce_batch_dim(
                 RuntimeError,
                 "not supported when reducing across the batch dimension for NestedTensor",
             ):
-                out = torch.nn.functional.softmax(nt, dim=reduce_dim)
+                out = func(nt, dim=reduce_dim)
 
     @dtypes(torch.float32)
     @parametrize("requires_grad", [False, True])

torch/nested/_internal/ops.py

Lines changed: 40 additions & 0 deletions
@@ -841,6 +841,46 @@ def _softmax_default(func, *args, **kwargs):
     return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
+@register_jagged_func(
+    torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any"
+)
+def _log_softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    if isinstance(new_kwargs["dim"], tuple):
+        raise RuntimeError(
+            "log_softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
+        )
+
+    inp = new_kwargs.pop("input")
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        _reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(), (new_kwargs["dim"],), "log_softmax", inp._ragged_idx
+    )
+
+    if reduce_on_batch:
+        raise RuntimeError(
+            "log_softmax(): not supported when reducing across the batch dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged:
+        raise RuntimeError(
+            "log_softmax(): not supported when reducing along the ragged dimension for NestedTensor"
+        )
+
+    # torch.log_softmax takes in the reduction dimension as an integer
+    new_kwargs["dim"] = new_kwargs["dim"][0]
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
 @register_jagged_func(
     torch.ops.aten._softmax_backward_data.default,
     "grad_output: jt, output: jt, dim: any, input_dtype: any",
