Commit bcd73cb

vllmellm and sanyalington authored and committed
Added ROCm custom paged attention.
Ported ROCm/vllm changes to upstream vLLM.

This commit manually ports changes from ROCm/vllm (ROCm#372) to upstream vLLM. The original work was done by sanyalington.

Co-authored-by: sanyalington <[email protected]>
Signed-off-by: vllmellm <[email protected]>
1 parent 68ad4e3 commit bcd73cb

File tree

7 files changed (+1074, −360 lines)

benchmarks/kernels/benchmark_paged_attention.py (53 additions, 20 deletions)
```diff
@@ -9,8 +9,9 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)
 
-NUM_BLOCKS = 1024
+NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
+PARTITION_SIZE_ROCM = 256
 
 
 @torch.inference_mode()
```
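For scale, the new constant gives 128 * 1024 = 131,072 cache blocks; at 16 tokens per block (one of the benchmark's supported block sizes) that is roughly 2.1 million cached token slots, compared with about 16 thousand under the previous value of 1024 blocks.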
```diff
@@ -78,6 +79,12 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if current_platform.is_rocm():
+            global PARTITION_SIZE
+            if not args.custom_paged_attn:
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
```
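As a worked example of what this branch changes downstream, the sketch below (not part of the benchmark; the helper, the boolean inputs, and the sequence length are illustrative stand-ins for current_platform.is_rocm() and args.custom_paged_attn) prints the partition size each configuration selects and the resulting num_partitions used to shape tmp_output:

```python
# Illustrative sketch only; mirrors the selection logic added above.
PARTITION_SIZE = 512        # default, unchanged from the original benchmark
PARTITION_SIZE_ROCM = 256   # used with the ROCm custom paged attention kernel

def pick_partition_size(is_rocm: bool, custom_paged_attn: bool) -> int:
    # Same branch as the hunk above: 1024 for the stock v2 kernel on ROCm,
    # PARTITION_SIZE_ROCM for the custom kernel, default elsewhere.
    if not is_rocm:
        return PARTITION_SIZE
    return PARTITION_SIZE_ROCM if custom_paged_attn else 1024

max_seq_len = 4096  # example sequence length
for is_rocm, custom in [(False, False), (True, False), (True, True)]:
    size = pick_partition_size(is_rocm, custom)
    num_partitions = (max_seq_len + size - 1) // size
    print(f"rocm={is_rocm}, custom={custom}: "
          f"partition_size={size}, num_partitions={num_partitions}")
```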
```diff
@@ -119,25 +126,48 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     v_scale,
                 )
             elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    seq_lens,
-                    block_size,
-                    max_seq_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    k_scale,
-                    v_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                    )
+                else:
+                    ops.paged_attention_rocm(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                        None,
+                        PARTITION_SIZE,
+                    )
             else:
                 raise ValueError(f"Invalid version: {version}")
         torch.cuda.synchronize()
```
```diff
@@ -191,6 +221,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         help="Data type for kv cache storage. If 'auto', will use model "
         "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
         "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
+    parser.add_argument("--custom-paged-attn",
+                        action="store_true",
+                        help="Use custom paged attention")
     args = parser.parse_args()
     print(args)
```
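With the flag wired into the parser, the custom kernel path can be selected from the command line, for example `python benchmarks/kernels/benchmark_paged_attention.py --version v2 --custom-paged-attn` (other arguments left at their defaults). Per the hunks above, the flag swaps the v2 path over to ops.paged_attention_rocm and shrinks PARTITION_SIZE to 256 on ROCm, so it is intended for AMD GPU builds of vLLM.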