Skip to content

Commit 6a72aaa

Browse files
zeshengzongpytorchmergebot
authored andcommitted
Fix torch.max optional args dim, keepdim description (pytorch#147177)
[`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html#torch.max) optional args `dim`, `keepdim` not described in document, but users can ignore them. ```python >>> import torch >>> a = torch.randn(3,1,3) >>> a.max() tensor(1.9145) >>> a.max(dim=1) torch.return_types.max( values=tensor([[ 1.1436, -0.0728, 1.3312], [-0.4049, 0.1792, -1.2247], [ 0.8767, -0.7888, 1.9145]]), indices=tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])) ``` ## Changes - Add `optional` description for `dim`, `keepdim` - Add example of using `dim`, `keepdim` ## Test Result ### Before ![image](https://github.com/user-attachments/assets/3391bc45-b636-4e64-9406-04d80af0c087) ### After ![image](https://github.com/user-attachments/assets/1d70e282-409c-4573-b276-b8219fd6ef0a) Pull Request resolved: pytorch#147177 Approved by: https://github.com/colesbury
1 parent 452315c commit 6a72aaa

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

torch/_torch_docs.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def merge_dicts(*dicts):
7171
"opt_dim": """
7272
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
7373
If ``None``, all dimensions are reduced.
74+
"""
75+
},
76+
{
77+
"opt_keepdim": """
78+
keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
7479
"""
7580
},
7681
)
@@ -6483,8 +6488,8 @@ def merge_dicts(*dicts):
64836488
64846489
Args:
64856490
{input}
6486-
{dim}
6487-
{keepdim} Default: ``False``.
6491+
{opt_dim}
6492+
{opt_keepdim}
64886493
64896494
Keyword args:
64906495
out (tuple, optional): the result tuple of two output tensors (max, max_indices)
@@ -6499,13 +6504,22 @@ def merge_dicts(*dicts):
64996504
[-0.6172, 1.0036, -0.6060, -0.2432]])
65006505
>>> torch.max(a, 1)
65016506
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
6507+
>>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
6508+
>>> a.max(dim=1, keepdim=True)
6509+
torch.return_types.max(
6510+
values=tensor([[2.], [4.]]),
6511+
indices=tensor([[1], [1]]))
6512+
>>> a.max(dim=1, keepdim=False)
6513+
torch.return_types.max(
6514+
values=tensor([2., 4.]),
6515+
indices=tensor([1, 1]))
65026516
65036517
.. function:: max(input, other, *, out=None) -> Tensor
65046518
:noindex:
65056519
65066520
See :func:`torch.maximum`.
65076521
6508-
""".format(**single_dim_common),
6522+
""".format(**multi_dim_common),
65096523
)
65106524

65116525
add_docstr(

0 commit comments

Comments
 (0)