Skip to content

Commit ef731cd

Browse files
XuehaiPan authored and pytorchmergebot committed
[2/3] Update .pyi Python stub files: Prettify rnn.py by using type annotated NamedTuple (#95267)
Changes: - #95200 1. Recognize `.py.in` and `.pyi.in` files as Python in VS Code for a better development experience. 2. Fix deep setting merge in `tools/vscode_settings.py`. - => this PR: #95267 3. Use `NamedTuple` rather than `namedtuple + __annotations__` for `torch.nn.utils.rnn.PackedSequence_`: `namedtuple + __annotations__`: ```python PackedSequence_ = namedtuple('PackedSequence_', ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices']) # type annotation for PackedSequence_ to make it compatible with TorchScript PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor, 'sorted_indices': Optional[torch.Tensor], 'unsorted_indices': Optional[torch.Tensor]} ``` `NamedTuple`: Python 3.6+ ```python class PackedSequence_(NamedTuple): data: torch.Tensor batch_sizes: torch.Tensor sorted_indices: Optional[torch.Tensor] unsorted_indices: Optional[torch.Tensor] ``` - #95268 4. Sort import statements and remove unnecessary imports in `.pyi`, `.pyi.in` files. 5. Format `.pyi`, `.pyi.in` files and remove unnecessary ellipsis `...` in type stubs. Pull Request resolved: #95267 Approved by: https://github.com/janeyx99
1 parent a46e550 commit ef731cd

File tree

6 files changed

+48
-39
lines changed

6 files changed

+48
-39
lines changed

torch/ao/nn/quantized/dynamic/modules/rnn.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,11 @@ def forward_packed(
455455
self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
456456
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
457457
input_, batch_sizes, sorted_indices, unsorted_indices = input
458-
max_batch_size = batch_sizes[0]
459-
max_batch_size = int(max_batch_size)
458+
max_batch_size = int(batch_sizes[0])
460459

461460
output_, hidden = self.forward_impl(
462-
input_, hx, batch_sizes, max_batch_size, sorted_indices)
461+
input_, hx, batch_sizes, max_batch_size, sorted_indices
462+
)
463463

464464
output = PackedSequence(output_, batch_sizes,
465465
sorted_indices, unsorted_indices)
@@ -701,10 +701,10 @@ def forward_packed(
701701
self, input: PackedSequence, hx: Optional[Tensor] = None
702702
) -> Tuple[PackedSequence, Tensor]:
703703
input_, batch_sizes, sorted_indices, unsorted_indices = input
704-
max_batch_size = batch_sizes[0]
705-
max_batch_size = int(max_batch_size)
704+
max_batch_size = int(batch_sizes[0])
706705
output_, hidden = self.forward_impl(
707-
input_, hx, batch_sizes, max_batch_size, sorted_indices)
706+
input_, hx, batch_sizes, max_batch_size, sorted_indices
707+
)
708708

709709
output = PackedSequence(output_, batch_sizes,
710710
sorted_indices, unsorted_indices)

torch/ao/nn/quantized/reference/modules/rnn.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,7 @@ def forward(self, input, hx=None): # noqa: F811
412412
batch_sizes = None
413413
if isinstance(orig_input, PackedSequence):
414414
input, batch_sizes, sorted_indices, unsorted_indices = input
415-
max_batch_size = batch_sizes[0]
416-
max_batch_size = int(max_batch_size)
415+
max_batch_size = int(batch_sizes[0])
417416
else:
418417
batch_sizes = None
419418
is_batched = input.dim() == 3
@@ -544,8 +543,7 @@ def forward(self, input, hx=None): # noqa: F811
544543
# xxx: isinstance check needs to be in conditional for TorchScript to compile
545544
if isinstance(orig_input, PackedSequence):
546545
input, batch_sizes, sorted_indices, unsorted_indices = input
547-
max_batch_size = batch_sizes[0]
548-
max_batch_size = int(max_batch_size)
546+
max_batch_size = int(batch_sizes[0])
549547
else:
550548
batch_sizes = None
551549
assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"

torch/jit/quantized.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,15 @@ def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = No
406406
return output, self.permute_hidden(hidden, unsorted_indices)
407407

408408
@torch.jit.script_method
409-
def forward_packed(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
410-
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
411-
input, batch_sizes, sorted_indices, unsorted_indices = input
412-
max_batch_size = batch_sizes[0]
413-
max_batch_size = int(max_batch_size)
414-
415-
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
409+
def forward_packed(
410+
self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
411+
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
412+
input_, batch_sizes, sorted_indices, unsorted_indices = input
413+
max_batch_size = int(batch_sizes[0])
414+
415+
output, hidden = self.forward_impl(
416+
input_, hx, batch_sizes, max_batch_size, sorted_indices
417+
)
416418

417419
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
418420
return output, self.permute_hidden(hidden, unsorted_indices)
@@ -490,11 +492,12 @@ def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Te
490492

491493
@torch.jit.script_method
492494
def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
493-
input, batch_sizes, sorted_indices, unsorted_indices = input
494-
max_batch_size = batch_sizes[0]
495-
max_batch_size = int(max_batch_size)
495+
input_, batch_sizes, sorted_indices, unsorted_indices = input
496+
max_batch_size = int(batch_sizes[0])
496497

497-
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
498+
output, hidden = self.forward_impl(
499+
input_, hx, batch_sizes, max_batch_size, sorted_indices
500+
)
498501

499502
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
500503
return output, self.permute_hidden(hidden, unsorted_indices)

torch/nn/modules/rnn.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -766,8 +766,7 @@ def forward(self, input, hx=None): # noqa: F811
766766
batch_sizes = None
767767
if isinstance(orig_input, PackedSequence):
768768
input, batch_sizes, sorted_indices, unsorted_indices = input
769-
max_batch_size = batch_sizes[0]
770-
max_batch_size = int(max_batch_size)
769+
max_batch_size = int(batch_sizes[0])
771770
else:
772771
batch_sizes = None
773772
assert (input.dim() in (2, 3)), f"LSTM: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
@@ -961,8 +960,7 @@ def forward(self, input, hx=None): # noqa: F811
961960
# xxx: isinstance check needs to be in conditional for TorchScript to compile
962961
if isinstance(orig_input, PackedSequence):
963962
input, batch_sizes, sorted_indices, unsorted_indices = input
964-
max_batch_size = batch_sizes[0]
965-
max_batch_size = int(max_batch_size)
963+
max_batch_size = int(batch_sizes[0])
966964
else:
967965
batch_sizes = None
968966
assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"

torch/nn/utils/rnn.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
1-
from collections import namedtuple
21
import warnings
2+
from typing import Iterable, List, NamedTuple, Tuple, Union
33

44
import torch
55
from torch import Tensor
66
from ... import _VF
77
from ..._jit_internal import Optional
88

9-
from typing import List, Tuple, Union, Iterable
10-
119

1210
__all__ = ['PackedSequence', 'invert_permutation', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence',
1311
'unpad_sequence', 'pack_sequence', 'unpack_sequence']
1412

15-
PackedSequence_ = namedtuple('PackedSequence_',
16-
['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
1713

18-
# type annotation for PackedSequence_ to make it compatible with TorchScript
19-
PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor,
20-
'sorted_indices': Optional[torch.Tensor],
21-
'unsorted_indices': Optional[torch.Tensor]}
14+
class PackedSequence_(NamedTuple):
15+
data: torch.Tensor
16+
batch_sizes: torch.Tensor
17+
sorted_indices: Optional[torch.Tensor]
18+
unsorted_indices: Optional[torch.Tensor]
2219

2320

2421
def bind(optional, fn):

torch/nn/utils/rnn.pyi

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
1-
from collections import namedtuple
2-
from typing import Any, List, Optional, overload, Union, TypeVar, Tuple, Sequence
3-
from torch import Tensor
4-
from torch.types import _dtype, _device
1+
from typing import (
2+
Any,
3+
List,
4+
Optional,
5+
Sequence,
6+
Tuple,
7+
TypeVar,
8+
Union,
9+
NamedTuple,
10+
overload,
11+
)
512

6-
PackedSequence_ = namedtuple('PackedSequence_', ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
13+
from torch import Tensor
14+
from torch.types import _device, _dtype
715

16+
class PackedSequence_(NamedTuple):
17+
data: Tensor
18+
batch_sizes: Tensor
19+
sorted_indices: Optional[Tensor]
20+
unsorted_indices: Optional[Tensor]
821

922
def bind(optional: Any, fn: Any): ...
1023

0 commit comments

Comments (0)