@@ -9,8 +9,9 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)
 
-NUM_BLOCKS = 1024
+NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
+PARTITION_SIZE_ROCM = 256
 
 
 @torch.inference_mode()
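For context on the NUM_BLOCKS bump in this hunk: the benchmark's key and value caches are allocated per block, so their footprint grows linearly with NUM_BLOCKS. A rough back-of-the-envelope sketch, assuming an fp16 cache laid out as [num_blocks, block_size, num_kv_heads, head_size] and made-up shapes (none of these numbers come from this file):

```python
# Illustrative only: KV-cache footprint as a function of NUM_BLOCKS,
# assuming separate fp16 key and value caches of shape
# [num_blocks, block_size, num_kv_heads, head_size].
def kv_cache_bytes(num_blocks, block_size=16, num_kv_heads=8,
                   head_size=128, dtype_bytes=2):
    per_block = block_size * num_kv_heads * head_size * dtype_bytes
    return 2 * num_blocks * per_block  # x2: key cache + value cache

print(kv_cache_bytes(1024) / 2**20, "MiB")        # old default: 64.0 MiB
print(kv_cache_bytes(128 * 1024) / 2**30, "GiB")  # new default: 8.0 GiB
```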
|
@@ -78,6 +79,12 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if current_platform.is_rocm():
+            global PARTITION_SIZE
+            if not args.custom_paged_attn:
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
@@ -119,25 +126,48 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     v_scale,
                 )
             elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    seq_lens,
-                    block_size,
-                    max_seq_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    k_scale,
-                    v_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                    )
+                else:
+                    ops.paged_attention_rocm(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                        None,
+                        PARTITION_SIZE,
+                    )
             else:
                 raise ValueError(f"Invalid version: {version}")
         torch.cuda.synchronize()
@@ -191,6 +221,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         help="Data type for kv cache storage. If 'auto', will use model "
         "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
         "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
+    parser.add_argument("--custom-paged-attn",
+                        action="store_true",
+                        help="Use custom paged attention")
     args = parser.parse_args()
     print(args)
 