Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ async def benchmark(
ignore_eos: bool,
goodput_config_dict: Dict[str, float],
max_concurrency: Optional[int],
lora_modules: Optional[List[str]],
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
Expand All @@ -562,6 +563,7 @@ async def benchmark(
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
)

test_output = await request_func(request_func_input=test_input)
if not test_output.success:
raise ValueError(
Expand All @@ -570,6 +572,11 @@ async def benchmark(
else:
print("Initial test run completed. Starting main benchmark run...")

if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter(
[random.choice(lora_modules) for _ in range(len(input_requests))])

if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id,
Expand Down Expand Up @@ -616,8 +623,13 @@ async def limited_request_func(request_func_input, pbar):
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id,
model_name=model_name,
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module

request_func_input = RequestFuncInput(model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
Expand Down Expand Up @@ -900,6 +912,7 @@ def main(args: argparse.Namespace):
ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
))

# Save config and results to json
Expand Down Expand Up @@ -1237,5 +1250,12 @@ def main(args: argparse.Namespace):
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")

parser.add_argument("--lora-modules",
nargs='+',
default=None,
help="A subset of LoRA module names passed in when "
"launching the server. For each request, the "
"script chooses a LoRA module at random.")

args = parser.parse_args()
main(args)