
Commit 4c5e920

cyyever authored and Silv3S committed
[4/N] Use Python 3.10 typing (pytorch#167458)
This PR applies new Union and Optional typing syntax to some files.

Pull Request resolved: pytorch#167458
Approved by: https://github.com/albanD
1 parent 80b541b commit 4c5e920

30 files changed (+145, -162 lines)
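For context, the change this PR applies is the PEP 604 spelling of unions: starting with Python 3.10, X | Y is native syntax, so Optional[T] can be written T | None with no typing import. A minimal before/after sketch (illustrative only, not code from this PR; the function names are made up):

# Before: pre-3.10 spelling, needs imports from typing.
from typing import Optional, Union

def head_scale(scale: Optional[float] = None) -> Union[float, int]:
    return scale if scale is not None else 1

# After: PEP 604 spelling, native on Python 3.10+.
def head_scale_new(scale: float | None = None) -> float | int:
    return scale if scale is not None else 1

On older interpreters the new spelling still works in annotations under from __future__ import annotations (annotations are then stored as strings, not evaluated), but evaluating them, e.g. via typing.get_type_hints, requires 3.10+.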

torch/nativert/backends/_lowered_aoti_module.py

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 from torch.export import ExportedProgram
 
@@ -10,7 +8,7 @@ def __init__(
         original_exported_program: ExportedProgram,
         backend_id: str,
         *,
-        module_name: Optional[str] = None,
+        module_name: str | None = None,
     ) -> None:
         super().__init__()
         self._backend_id = backend_id
@@ -22,7 +20,7 @@ def backend_id(self) -> str:
         return self._backend_id
 
     @property
-    def module_name(self) -> Optional[str]:
+    def module_name(self) -> str | None:
         return self._module_name
 
     @property

torch/nested/__init__.py

Lines changed: 7 additions & 7 deletions
@@ -26,8 +26,8 @@
 
 def as_nested_tensor(
     ts: Union[Tensor, list[Tensor], tuple[Tensor, ...]],
-    dtype: Optional[DType] = None,
-    device: Optional[Device] = None,
+    dtype: DType | None = None,
+    device: Device | None = None,
     layout=None,
 ) -> Tensor:
     r"""
@@ -358,11 +358,11 @@ def narrow(
 
 def nested_tensor_from_jagged(
     values: Tensor,
-    offsets: Optional[Tensor] = None,
-    lengths: Optional[Tensor] = None,
-    jagged_dim: Optional[int] = None,
-    min_seqlen: Optional[int] = None,
-    max_seqlen: Optional[int] = None,
+    offsets: Tensor | None = None,
+    lengths: Tensor | None = None,
+    jagged_dim: int | None = None,
+    min_seqlen: int | None = None,
+    max_seqlen: int | None = None,
 ) -> Tensor:
     r"""
     Constructs a jagged layout nested tensor from the given jagged components. The jagged layout

torch/nested/_internal/ops.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,6 @@
 import math
 import operator
 from typing import *  # noqa: F403
-from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -249,7 +248,7 @@ def inner(*args, **kwargs):
 register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
 
 
-def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
+def lookup_jagged(func, *args, **kwargs) -> Callable | None:
     dispatch_func = JAGGED_OPS_TABLE.get(func, None)
     if dispatch_func is not None:
         return dispatch_func
@@ -1138,7 +1137,7 @@ def unbind_int(func, *args, **kwargs):
     lengths = inp.lengths()
     ragged_idx = inp._ragged_idx
 
-    def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None) -> None:
+    def _torch_check(_lengths: list[int], _offsets: list[int] | None = None) -> None:
         # This torch._check are needed for torch.compile
         # symbolic shapes processing.
         # offsets and lengths are symbolic variables during compilation,

torch/nested/_internal/sdpa.py

Lines changed: 4 additions & 5 deletions
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import logging
-from typing import Optional
 
 import torch
 import torch.nn
@@ -27,7 +26,7 @@ def _validate_sdpa_input(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
+    attn_mask: torch.Tensor | None = None,
     dropout_p=0.0,
     is_causal=False,
     scale=None,
@@ -668,8 +667,8 @@ def _autocast(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    attn_mask: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
     """
     [Autocasting SDPA for NJT]
 
@@ -714,7 +713,7 @@ def jagged_scaled_dot_product_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
+    attn_mask: torch.Tensor | None = None,
     dropout_p=0.0,
     is_causal=False,
     scale=None,

torch/nn/attention/_utils.py

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,6 @@
 """Defines utilities for interacting with scaled_dot_product_attention"""
 
 import math
-from typing import Optional
 
 import torch
 
@@ -22,7 +21,7 @@ def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.
     return inpt_tensor
 
 
-def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
+def _calculate_scale(head_dim_size: int, scale: float | None) -> float:
     """
     For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output
     by the original head size and not the padded.
@@ -36,7 +35,7 @@ def _validate_sdpa_input(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
+    attn_mask: torch.Tensor | None = None,
     dropout_p=0.0,
     is_causal=False,
     scale=None,

torch/nn/attention/bias.py

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,6 @@
 """Defines bias subclasses that work with scaled_dot_product_attention"""
 
 from enum import auto, IntEnum
-from typing import Optional
 from warnings import warn
 
 import torch
@@ -155,7 +154,7 @@ def _lower_right(self, device: torch.device) -> torch.Tensor:
         )
 
     # pyrefly: ignore [bad-return]
-    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
+    def _materialize(self, device: torch.device | None = None) -> torch.Tensor:
         """
         Materializes the causal bias into a tensor form.
 
@@ -183,7 +182,7 @@ def _dispatch(
         attn_mask: "CausalBias",
         dropout_p: float = 0.0,
         is_causal: bool = False,
-        scale: Optional[float] = None,
+        scale: float | None = None,
         enable_gqa: bool = False,
     ) -> torch.Tensor:
         r"""

torch/nn/attention/experimental/_paged_attention.py

Lines changed: 7 additions & 7 deletions
@@ -4,7 +4,7 @@
 This module is experimental and subject to change.
 """
 
-from typing import Optional, Union
+from typing import Union
 
 import torch
 from torch.nn.attention.flex_attention import (
@@ -197,8 +197,8 @@ def assign(
     def convert_logical_block_mask(
         self,
         block_mask: BlockMask,
-        batch_idx: Optional[torch.Tensor] = None,
-        kv_len: Optional[torch.Tensor] = None,
+        batch_idx: torch.Tensor | None = None,
+        kv_len: torch.Tensor | None = None,
     ) -> BlockMask:
         """
         Converts a logical block mask by mapping its logical kv indices to the corresponding
@@ -279,8 +279,8 @@ def convert_logical_block_mask(
 
     def get_mask_mod(
         self,
-        mask_mod: Optional[_mask_mod_signature],
-        kv_len: Optional[torch.Tensor] = None,
+        mask_mod: _mask_mod_signature | None,
+        kv_len: torch.Tensor | None = None,
     ) -> _mask_mod_signature:
         """
         Converts a mask_mod based on mapping from the physical block index to the logical
@@ -316,8 +316,8 @@ def new_mask_mod(
 
     def get_score_mod(
         self,
-        score_mod: Optional[_score_mod_signature],
-        kv_len: Optional[torch.Tensor] = None,
+        score_mod: _score_mod_signature | None,
+        kv_len: torch.Tensor | None = None,
     ) -> _score_mod_signature:
         """
         Converts a score_mod based on mapping from the physical block index to the logical
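A quick way to convince yourself this kind of mechanical rewrite is behavior-preserving for code that introspects annotations: on Python 3.10+ the two spellings compare equal at runtime. A small sanity-check sketch (hypothetical, not part of the PR):

from typing import Optional, Union, get_type_hints

# str | None builds a types.UnionType; Optional[str] builds typing.Union[str, None].
# On 3.10+ the two compare equal, so annotation introspection sees no difference.
assert (str | None) == Optional[str]
assert (int | str) == Union[int, str]

def f(x: str | None = None) -> None: ...
def g(x: Optional[str] = None) -> None: ...
assert get_type_hints(f) == get_type_hints(g)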
