
Commit d539011

py linter fixes
committed · 1 parent afd6689

File tree (3 files changed: +19, -14 lines changed)

  benchmarks/kernels/benchmark_paged_attention.py
  tests/kernels/test_attention_custom.py
  vllm/attention/ops/paged_attn.py

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from vllm._custom_C import paged_attention_custom
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

-NUM_BLOCKS = 1024*1024
+NUM_BLOCKS = 1024 * 1024
 PARTITION_SIZE = 256

tests/kernels/test_attention_custom.py

Lines changed: 15 additions & 11 deletions
@@ -6,28 +6,29 @@

 from vllm import _custom_ops as ops
 from vllm._custom_C import paged_attention_custom
-from vllm.utils import get_max_shared_memory_bytes, is_hip
+from vllm.utils import is_hip

 from .allclose_default import get_default_atol, get_default_rtol

-MAX_SEQ_LEN = 32*1024
+MAX_SEQ_LEN = 32 * 1024
 # There may not be enough gpu memory due to large NUM_BLOCKS.
 # Reduce NUM_BLOCKS when it happens.
-NUM_BLOCKS = 128*1024+4321  # Arbitrary values for testing
+NUM_BLOCKS = 128 * 1024 + 4321  # Arbitrary values for testing
 PARTITION_SIZE = 256
-DTYPES = [torch.bfloat16,torch.half]
-NUM_GEN_SEQS = [1,17]  # Arbitrary values for testing
+DTYPES = [torch.bfloat16, torch.half]
+NUM_GEN_SEQS = [1, 17]  # Arbitrary values for testing
 NUM_HEADS = [(8 * x, 8) for x in range(1, 17)]  # Arbitrary values for testing

-HEAD_SIZES = [64,128]
-BLOCK_SIZES = [16,32]
-USE_ALIBI = [True,False]
+HEAD_SIZES = [64, 128]
+BLOCK_SIZES = [16, 32]
+USE_ALIBI = [True, False]
 KV_CACHE_DTYPE = ["auto"]
 SEEDS = [37]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 1)
 ]

+
 def ref_masked_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -279,10 +280,13 @@ def test_paged_attention(
     # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
     # so we use a relaxed tolerance for the test.
     atol, rtol = 1e-4, 1e-5
-    if dtype == torch.bfloat16: atol, rtol = 2e-4, 1e-5
+    if dtype == torch.bfloat16:
+        atol, rtol = 2e-4, 1e-5
     if use_alibi:
-        if dtype == torch.half: atol, rtol = 5e-4, 1e-5
-        if dtype == torch.bfloat16: atol, rtol = 1e-3, 1e-5
+        if dtype == torch.half:
+            atol, rtol = 5e-4, 1e-5
+        if dtype == torch.bfloat16:
+            atol, rtol = 1e-3, 1e-5
     if kv_cache_dtype == "fp8":
         atol, rtol = 1e-2, 1e-5
     assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
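For readers skimming the diff, the reformatted tolerance chain in test_paged_attention is equivalent to the standalone helper sketched below. The name select_tolerances is hypothetical and does not appear in the test file; the sketch only restates the if-blocks above in one place.

# Hypothetical helper (not part of this commit) mirroring the tolerance
# selection in test_paged_attention.
from typing import Tuple

import torch


def select_tolerances(dtype: torch.dtype, use_alibi: bool,
                      kv_cache_dtype: str) -> Tuple[float, float]:
    """Return (atol, rtol) for comparing kernel output against the reference."""
    atol, rtol = 1e-4, 1e-5
    if dtype == torch.bfloat16:
        atol, rtol = 2e-4, 1e-5
    if use_alibi:
        if dtype == torch.half:
            atol, rtol = 5e-4, 1e-5
        if dtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    return atol, rtol


# Example: bfloat16 with ALiBi uses the loosest non-FP8 tolerance.
assert select_tolerances(torch.bfloat16, True, "auto") == (1e-3, 1e-5)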

vllm/attention/ops/paged_attn.py

Lines changed: 3 additions & 2 deletions
@@ -118,9 +118,10 @@ def forward_decode(
         num_seqs, num_heads, head_size = query.shape
         gqa_ratio = num_heads // num_kv_heads
         use_custom = (custom_attn_available
-                      and (query.dtype == torch.half or query.dtype == torch.bfloat16)
+                      and (query.dtype == torch.half
+                           or query.dtype == torch.bfloat16)
                       and (head_size == 128 or head_size == 64)
-                      and (block_size == 16 or block_size==32)
+                      and (block_size == 16 or block_size == 32)
                       and kv_cache_dtype == "auto"
                       and (gqa_ratio >= 1 and gqa_ratio <= 16)
                       and max_seq_len <= 32768)
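The rewrapped use_custom expression in forward_decode is the gate that decides when the paged_attention_custom kernel is used instead of the default path. A minimal sketch of that gate as a standalone predicate follows; the function name and flattened signature are assumptions for illustration and are not part of vllm/attention/ops/paged_attn.py.

# Hypothetical predicate (not in the repo) restating the use_custom gate:
# the custom kernel is only taken for fp16/bf16 queries, head sizes 64/128,
# block sizes 16/32, an "auto" KV-cache dtype, GQA ratios 1..16, and
# maximum sequence lengths up to 32768 tokens.
import torch


def can_use_custom_paged_attention(custom_attn_available: bool,
                                   query_dtype: torch.dtype, head_size: int,
                                   block_size: int, kv_cache_dtype: str,
                                   gqa_ratio: int, max_seq_len: int) -> bool:
    return (custom_attn_available
            and query_dtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and kv_cache_dtype == "auto"
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)


# Example: a typical decode shape qualifies for the custom path.
print(can_use_custom_paged_attention(True, torch.half, 128, 16, "auto", 4, 4096))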
