diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index e934d228f7fd..1044bef59417 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -537,6 +537,7 @@ async def benchmark( ignore_eos: bool, goodput_config_dict: Dict[str, float], max_concurrency: Optional[int], + lora_modules: Optional[List[str]], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -562,6 +563,7 @@ async def benchmark( multi_modal_content=test_mm_content, ignore_eos=ignore_eos, ) + test_output = await request_func(request_func_input=test_input) if not test_output.success: raise ValueError( @@ -570,6 +572,11 @@ async def benchmark( else: print("Initial test run completed. Starting main benchmark run...") + if lora_modules: + # For each input request, choose a LoRA module at random. + lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + if profile: print("Starting profiler...") profile_input = RequestFuncInput(model=model_id, @@ -616,8 +623,13 @@ async def limited_request_func(request_func_input, pbar): tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request - request_func_input = RequestFuncInput(model=model_id, - model_name=model_name, + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, prompt=prompt, api_url=api_url, prompt_len=prompt_len, @@ -900,6 +912,7 @@ def main(args: argparse.Namespace): ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, )) # Save config and results to json @@ -1237,5 +1250,12 @@ def main(args: argparse.Namespace): "If not specified, the model name will be the " "same as the ``--model`` argument. ") + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.") + args = parser.parse_args() main(args)