@@ -74,25 +74,31 @@ def run_vllm(
     quantization_param_path: Optional[str],
     device: str,
     enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
     gpu_memory_utilization: float = 0.9,
     download_dir: Optional[str] = None,
 ) -> float:
     from vllm import LLM, SamplingParams
-    llm = LLM(model=model,
-              tokenizer=tokenizer,
-              quantization=quantization,
-              tensor_parallel_size=tensor_parallel_size,
-              seed=seed,
-              trust_remote_code=trust_remote_code,
-              dtype=dtype,
-              max_model_len=max_model_len,
-              gpu_memory_utilization=gpu_memory_utilization,
-              enforce_eager=enforce_eager,
-              kv_cache_dtype=kv_cache_dtype,
-              quantization_param_path=quantization_param_path,
-              device=device,
-              enable_prefix_caching=enable_prefix_caching,
-              download_dir=download_dir)
+    llm = LLM(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+    )
 
     # Add the requests to the engine.
     for prompt, _, output_len in requests:
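
Note: the two new keyword arguments are simply forwarded to the vLLM `LLM` constructor, so they can also be exercised directly outside the benchmark. The sketch below is illustrative only and is not part of the commit; the model name and the 2048-token budget are placeholder assumptions.

```python
# Minimal sketch of the two new knobs wired through run_vllm(); the model name
# and the 2048-token budget are illustrative assumptions, not from the commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # placeholder model
    enable_chunked_prefill=True,    # split long prefills across scheduler steps
    max_num_batched_tokens=2048,    # per-iteration token budget
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```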
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
                                    args.output_len)
 
     if args.backend == "vllm":
-        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
-                                args.quantization, args.tensor_parallel_size,
-                                args.seed, args.n, args.use_beam_search,
-                                args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype,
-                                args.quantization_param_path, args.device,
-                                args.enable_prefix_caching,
-                                args.gpu_memory_utilization, args.download_dir)
+        elapsed_time = run_vllm(
+            requests, args.model, args.tokenizer, args.quantization,
+            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+            args.trust_remote_code, args.dtype, args.max_model_len,
+            args.enforce_eager, args.kv_cache_dtype,
+            args.quantization_param_path, args.device,
+            args.enable_prefix_caching, args.enable_chunked_prefill,
+            args.max_num_batched_tokens, args.gpu_memory_utilization,
+            args.download_dir)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -335,6 +341,14 @@ def main(args: argparse.Namespace):
         "--enable-prefix-caching",
         action='store_true',
         help="enable automatic prefix caching for vLLM backend.")
+    parser.add_argument("--enable-chunked-prefill",
+                        action='store_true',
+                        help="enable chunked prefill for vLLM backend.")
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=None,
+                        help='maximum number of batched tokens per '
+                        'iteration')
     parser.add_argument('--download-dir',
                         type=str,
                         default=None,
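
The two new flags follow standard argparse semantics: `--enable-chunked-prefill` is a boolean switch and `--max-num-batched-tokens` defaults to None when omitted. A self-contained sketch of how they parse (example values are arbitrary, and only these two flags are reproduced):

```python
# Stand-alone sketch of the new CLI flags' parsing behavior; only these two
# flags are reproduced here, and the example values are arbitrary.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-chunked-prefill", action='store_true',
                    help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens', type=int, default=None,
                    help='maximum number of batched tokens per iteration')

args = parser.parse_args(["--enable-chunked-prefill",
                          "--max-num-batched-tokens", "2048"])
print(args.enable_chunked_prefill, args.max_num_batched_tokens)  # True 2048
```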