|
6 | 6 |
|
7 | 7 | from vllm import _custom_ops as ops |
8 | 8 | from vllm._custom_C import paged_attention_custom |
9 | | -from vllm.utils import get_max_shared_memory_bytes, is_hip |
| 9 | +from vllm.utils import is_hip |
10 | 10 |
|
11 | 11 | from .allclose_default import get_default_atol, get_default_rtol |
12 | 12 |
|
13 | | -MAX_SEQ_LEN = 32*1024 |
| 13 | +MAX_SEQ_LEN = 32 * 1024 |
14 | 14 | # There may not be enough gpu memory due to large NUM_BLOCKS. |
15 | 15 | # Reduce NUM_BLOCKS when it happens. |
16 | | -NUM_BLOCKS = 128*1024+4321 # Arbitrary values for testing |
| 16 | +NUM_BLOCKS = 128 * 1024 + 4321 # Arbitrary values for testing |
17 | 17 | PARTITION_SIZE = 256 |
18 | | -DTYPES = [torch.bfloat16,torch.half] |
19 | | -NUM_GEN_SEQS = [1,17] # Arbitrary values for testing |
| 18 | +DTYPES = [torch.bfloat16, torch.half] |
| 19 | +NUM_GEN_SEQS = [1, 17] # Arbitrary values for testing |
20 | 20 | NUM_HEADS = [(8 * x, 8) for x in range(1, 17)] # Arbitrary values for testing |
21 | 21 |
|
22 | | -HEAD_SIZES = [64,128] |
23 | | -BLOCK_SIZES = [16,32] |
24 | | -USE_ALIBI = [True,False] |
| 22 | +HEAD_SIZES = [64, 128] |
| 23 | +BLOCK_SIZES = [16, 32] |
| 24 | +USE_ALIBI = [True, False] |
25 | 25 | KV_CACHE_DTYPE = ["auto"] |
26 | 26 | SEEDS = [37] |
27 | 27 | CUDA_DEVICES = [ |
28 | 28 | f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1) |
29 | 29 | ] |
30 | 30 |
|
| 31 | + |
31 | 32 | def ref_masked_attention( |
32 | 33 | query: torch.Tensor, |
33 | 34 | key: torch.Tensor, |
@@ -279,10 +280,13 @@ def test_paged_attention( |
279 | 280 | # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, |
280 | 281 | # so we use a relaxed tolerance for the test. |
281 | 282 | atol, rtol = 1e-4, 1e-5 |
282 | | - if dtype == torch.bfloat16: atol, rtol = 2e-4, 1e-5 |
| 283 | + if dtype == torch.bfloat16: |
| 284 | + atol, rtol = 2e-4, 1e-5 |
283 | 285 | if use_alibi: |
284 | | - if dtype == torch.half: atol, rtol = 5e-4, 1e-5 |
285 | | - if dtype == torch.bfloat16: atol, rtol = 1e-3, 1e-5 |
| 286 | + if dtype == torch.half: |
| 287 | + atol, rtol = 5e-4, 1e-5 |
| 288 | + if dtype == torch.bfloat16: |
| 289 | + atol, rtol = 1e-3, 1e-5 |
286 | 290 | if kv_cache_dtype == "fp8": |
287 | 291 | atol, rtol = 1e-2, 1e-5 |
288 | 292 | assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) |
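For reference, the tolerance ladder above can be read as a single selection step. The sketch below is a hypothetical pick_tolerances helper, not part of this test file; it only restates the thresholds from the hunk so the widening from fp16 to bf16, then ALiBi, then fp8 KV cache is easy to follow.

import torch


def pick_tolerances(dtype: torch.dtype, use_alibi: bool,
                    kv_cache_dtype: str) -> tuple[float, float]:
    # Mirrors the tolerance selection in test_paged_attention above.
    atol, rtol = 1e-4, 1e-5
    if dtype == torch.bfloat16:
        atol = 2e-4
    if use_alibi:
        if dtype == torch.half:
            atol = 5e-4
        if dtype == torch.bfloat16:
            atol = 1e-3
    if kv_cache_dtype == "fp8":
        # FP8 KV cache introduces quantization error, so the comparison is relaxed further.
        atol = 1e-2
    return atol, rtol


# Usage: atol, rtol = pick_tolerances(dtype, use_alibi, kv_cache_dtype)
#        assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)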