|
4 | 4 | This module is experimental and subject to change. |
5 | 5 | """ |
6 | 6 |
|
7 | | -from typing import Optional, Union |
| 7 | +from typing import Union |
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | from torch.nn.attention.flex_attention import ( |
@@ -197,8 +197,8 @@ def assign( |
197 | 197 | def convert_logical_block_mask( |
198 | 198 | self, |
199 | 199 | block_mask: BlockMask, |
200 | | - batch_idx: Optional[torch.Tensor] = None, |
201 | | - kv_len: Optional[torch.Tensor] = None, |
| 200 | + batch_idx: torch.Tensor | None = None, |
| 201 | + kv_len: torch.Tensor | None = None, |
202 | 202 | ) -> BlockMask: |
203 | 203 | """ |
204 | 204 | Converts a logical block mask by mapping its logical kv indices to the corresponding |
@@ -279,8 +279,8 @@ def convert_logical_block_mask( |
279 | 279 |
|
280 | 280 | def get_mask_mod( |
281 | 281 | self, |
282 | | - mask_mod: Optional[_mask_mod_signature], |
283 | | - kv_len: Optional[torch.Tensor] = None, |
| 282 | + mask_mod: _mask_mod_signature | None, |
| 283 | + kv_len: torch.Tensor | None = None, |
284 | 284 | ) -> _mask_mod_signature: |
285 | 285 | """ |
286 | 286 | Converts a mask_mod based on mapping from the physical block index to the logical |
@@ -316,8 +316,8 @@ def new_mask_mod( |
316 | 316 |
|
317 | 317 | def get_score_mod( |
318 | 318 | self, |
319 | | - score_mod: Optional[_score_mod_signature], |
320 | | - kv_len: Optional[torch.Tensor] = None, |
| 319 | + score_mod: _score_mod_signature | None, |
| 320 | + kv_len: torch.Tensor | None = None, |
321 | 321 | ) -> _score_mod_signature: |
322 | 322 | """ |
323 | 323 | Converts a score_mod based on mapping from the physical block index to the logical |
|
0 commit comments