diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7b74bc9c3520..f60aeaf93c2f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -204,6 +204,7 @@ steps: - VLLM_USE_V1=1 pytest -v -s v1/engine - VLLM_USE_V1=1 pytest -v -s v1/sample - VLLM_USE_V1=1 pytest -v -s v1/worker + - VLLM_USE_V1=1 pytest -v -s v1/structured_output - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py # TODO: accuracy does not match, whether setting diff --git a/.gitignore b/.gitignore index 89dab8f13bab..e40752f4dea0 100644 --- a/.gitignore +++ b/.gitignore @@ -197,7 +197,7 @@ _build/ hip_compat.h # Benchmark dataset -benchmarks/*.json +benchmarks/**/*.json # Linting actionlint diff --git a/benchmarks/benchmark_guided.py b/benchmarks/benchmark_guided.py deleted file mode 100644 index 2e0f6c6b5d20..000000000000 --- a/benchmarks/benchmark_guided.py +++ /dev/null @@ -1,507 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Benchmark guided decoding throughput.""" -import argparse -import dataclasses -import json -import os -import random -import time - -import datasets -import pandas as pd -import uvloop -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) -from vllm.sampling_params import GuidedDecodingParams -from vllm.utils import FlexibleArgumentParser, merge_async_iterators - - -@dataclasses.dataclass -class SampleRequest: - """A class representing a single inference request for benchmarking. - - Attributes: - prompt: The input text prompt for the model. - multi_modal_data: Optional dictionary containing multi-modal data (e.g. - images). - prompt_len: The length of the prompt in tokens. - expected_output_len: The expected length of the output in tokens. - """ - prompt: str - prompt_len: int - expected_output_len: int - schema: dict - structure_type: str = 'json' - completion: str = None - - -def run_vllm(requests: list[SampleRequest], - engine_args: EngineArgs, - n: int, - guided_decoding_rate: float = 1.0, - warmup: bool = False) -> float: - from vllm import LLM, SamplingParams - llm = LLM(**vars(engine_args)) - assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") - - # Add the requests to the engine. 
- prompts: list[str] = [] - sampling_params: list[SamplingParams] = [] - # create a list containing random selected true or false - guided_decoding_req_idx = random.sample( - range(len(requests)), int(len(requests) * guided_decoding_rate)) - - if warmup: - print(">>>>> Running warmup prompt, for the first 5") - # We setup the first 5 requests to warmup FSM - # if using xgrammar dataset, we will skip warmup - warmup_requests = requests[:5] - for i, request in enumerate(warmup_requests): - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - guided_decoding=GuidedDecodingParams(json=request.schema) - if guided_decoding_rate > 0 else None, - )) - llm.generate(prompts, sampling_params, use_tqdm=False) - - print(">>>>> Benchmark started...") - prompts = [] - sampling_params = [] - for i, request in enumerate(requests): - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - guided_decoding=GuidedDecodingParams( - **{request.structure_type: request.schema}) - if i in guided_decoding_req_idx else None, - )) - - start = time.perf_counter() - outputs = llm.generate(prompts, sampling_params, use_tqdm=False) - ret = [] - for output, request in zip(outputs, requests): - generated_text = output.outputs[0].text - ret.append({ - "generated": generated_text, - "expected": request.completion - }) - end = time.perf_counter() - return end - start, ret - - -async def run_vllm_async( - requests: list[SampleRequest], - engine_args: AsyncEngineArgs, - n: int, - guided_decoding_rate: float = 1.0, - warmup: bool = False, - disable_frontend_multiprocessing: bool = False) -> float: - from vllm import SamplingParams - - async with build_async_engine_client_from_engine_args( - engine_args, disable_frontend_multiprocessing) as llm: - - assert all( - llm.model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") - - # Add the requests to the engine. 
- prompts: list[str] = [] - sampling_params: list[SamplingParams] = [] - guided_decoding_req_idx = random.sample( - range(len(requests)), int(len(requests) * guided_decoding_rate)) - - if warmup: - print(">>>>>> Running warmup prompt, for the first 5") - # We setup the first 5 requests to warmup FSM - # if using xgrammar dataset, we will skip warmup - warmup_requests = requests[:5] - for i, request in enumerate(warmup_requests): - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - guided_decoding=GuidedDecodingParams( - json=request.schema) - if guided_decoding_rate > 0 else None, - )) - generators = [] - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) - all_gens = merge_async_iterators(*generators) - async for i, res in all_gens: - pass - - print(">>>>> Benchmark started...") - prompts = [] - sampling_params = [] - for i, request in enumerate(requests): - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - guided_decoding=GuidedDecodingParams(json=request.schema) - if i in guided_decoding_req_idx else None, - )) - - generators = [] - start_time = [] - latencies = [] - start = time.perf_counter() - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - generator = llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) - start_time.append(time.perf_counter()) - latencies.append([]) - all_gens = merge_async_iterators(*generators) - generated_texts = [''] * len(requests) - async for i, res in all_gens: - generated_texts[i] = res.outputs[0].text - lat = time.perf_counter() - start_time[i] - latencies[i].append(lat) - ret = [{ - 'generated': gt, - 'expected': req.completion - } for gt, req in zip(generated_texts, requests)] - end = time.perf_counter() - first_latency = pd.Series([lat[0] * 1000 for lat in latencies]) - next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000 - for lat in latencies]) - return end - start, ret, (first_latency, next_latency) - - -def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> list[SampleRequest]: - if args.dataset == 'json': - if args.json_schema_path is None: - dir_path = os.path.dirname(os.path.realpath(__file__)) - args.json_schema_path = os.path.join(dir_path, - "structured_schemas", - "structured_schema_1.json") - with open(args.json_schema_path) as f: - schema = json.load(f) - prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 - input_len = len(tokenizer(prompt).input_ids) - print(f"Input length of the prompt: {input_len} tokens") - requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type) - for _ in range(args.num_prompts) - ] - - elif args.dataset == "grammar": - schema = """ - ?start: select_statement - - ?select_statement: "SELECT " column_list " FROM " table_name - - ?column_list: column_name ("," column_name)* - - ?table_name: identifier - - ?column_name: identifier - - ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ - """ - prompt = "Generate an SQL query to show the 'username' \ - and 'email' from the 'users' table." 
- - input_len = len(tokenizer(prompt).input_ids) - print(f"Input length of the prompt: {input_len} tokens") - requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type) - for _ in range(args.num_prompts) - ] - - elif args.dataset == "regex": - regex = r"\w+@\w+\.com\n" - args.regex = regex - prompt = "Generate an email address for Alan Turing, \ - who works in Enigma. End in .com and new line. \ - Example result: alan.turing@enigma.com\n" - - input_len = len(tokenizer(prompt).input_ids) - print(f"Input length of the prompt: {input_len} tokens") - requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=regex, - structure_type=args.structure_type) - for _ in range(args.num_prompts) - ] - - elif args.dataset == "choice": - choice = ["Positive", "Negative"] - args.choice = choice - prompt = "Classify this sentiment: vLLM is wonderful!" - input_len = len(tokenizer(prompt).input_ids) - print(f"Input length of the prompt: {input_len} tokens") - requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=choice, - structure_type=args.structure_type) - for _ in range(args.num_prompts) - ] - - elif args.dataset == "xgrammar_bench": - args.warmup = False - requests: list[SampleRequest] = [] - dataset = datasets.load_dataset("NousResearch/json-mode-eval", - split="train") - print(f"dataset has {len(dataset)} entries") - len_dataset = len(dataset) - for data_point_idx in range(args.num_prompts): - idx = data_point_idx - while idx >= len_dataset: - idx -= len_dataset - schema = dataset["schema"][idx] - prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], - tokenize=False) - input_len = len(tokenizer(prompt).input_ids) - completion = dataset["completion"][idx] - - requests.append( - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - completion=completion)) - - return requests - - -def evaluate(ret, args): - - def _eval_correctness_json(expected, actual): - # extract json string from string using regex - import re - actual = actual.replace('\n', '').replace(' ', '').strip() - try: - actual = re.search(r'\{.*\}', actual).group() - actual = json.loads(actual) - except Exception: - return False - - return True - - def _eval_correctness_choice(expected, actual): - return actual in args.choice - - def _eval_correctness_regex(expected, actual): - import re - return re.match(args.regex, actual) is not None - - def _eval_correctness(expected, actual): - if args.structure_type == 'json': - return _eval_correctness_json(expected, actual) - elif args.structure_type == 'regex': - return _eval_correctness_regex(expected, actual) - elif args.structure_type == 'choice': - return _eval_correctness_choice(expected, actual) - else: - return None - - scores = [] - for res in ret: - score = _eval_correctness(res['expected'], res['generated']) - res['correctness'] = score - scores.append(score) - - not_none_scores = [score for score in scores if score is not None] - - return (sum(not_none_scores) / len(not_none_scores) * - 100) if len(not_none_scores) > 0 else None - - -def main(args: argparse.Namespace): - print(args) - random.seed(args.seed) - - # async engine is working for 'regex', 'choice' and 'grammar' - if args.dataset == 'grammar': - args.structure_type = 'grammar' - args.async_engine = False - elif args.dataset == 'regex': - 
args.structure_type = 'regex' - args.async_engine = False - elif args.dataset == 'choice': - args.structure_type = 'choice' - args.async_engine = False - else: - args.structure_type = 'json' - - if args.no_guided_decoding: - args.guided_decoding_ratio = 0 - if args.save_results: - result_file_name = f'{args.guided_decoding_ratio}guided' - result_file_name += f"_{args.model.split('/')[-1]}" - result_file_name += f"_{args.dataset}" - result_file_name += f"_{args.num_prompts}" - result_file_name += f"_out{args.output_len}" - result_file_name += f"_async{args.async_engine}" - result_file_name += f"_warmup{args.warmup}" - result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}" - result_file_name += ".txt" - else: - result_file_name = None - - # Synthesize a prompt with the given input length. - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) - requests = sample_requests(tokenizer, args) - - if args.async_engine: - engine_args = AsyncEngineArgs.from_cli_args(args) - elapsed_time, ret, (first_latency, next_latency) = uvloop.run( - run_vllm_async(requests, engine_args, args.n, - args.guided_decoding_ratio, args.warmup, - args.disable_frontend_multiprocessing)) - else: - engine_args = EngineArgs.from_cli_args(args) - elapsed_time, ret = run_vllm(requests, engine_args, args.n, - args.guided_decoding_ratio, args.warmup) - first_latency, next_latency = None, None - - score = evaluate(ret, args) - total_num_tokens = sum(request.prompt_len + request.expected_output_len - for request in requests) - total_output_tokens = sum(request.expected_output_len - for request in requests) - if first_latency is not None: - latency_breakdown = "\nFirst token latency(msecs):\n" - latency_breakdown += f"{first_latency.describe()}" - latency_breakdown += "\nNext token latency(msecs):\n" - latency_breakdown += f"{next_latency.describe()}" - print( - f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s", - f"Correct rate is {score} %", - f"{latency_breakdown if first_latency is not None else ''}") - - # Output JSON results if specified - if args.output_json or result_file_name: - results = { - "elapsed_time": elapsed_time, - "num_requests": len(requests), - "total_num_tokens": total_num_tokens, - "total_output_tokens": total_output_tokens, - "requests_per_second": len(requests) / elapsed_time, - "tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}", - "output_tokens_per_second": - f"{total_output_tokens / elapsed_time:.2f}", - "correct_rate(%)": score - } - results = {"outputs": ret, **results} - if first_latency is not None: - results["first_token_latency(msecs)"] = first_latency.describe( - ).to_dict() - results["next_token_latency(msecs)"] = next_latency.describe( - ).to_dict() - if args.output_json: - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - elif result_file_name: - with open(result_file_name, "w") as f: - json.dump(results, f, indent=4) - - -if __name__ == "__main__": - parser = FlexibleArgumentParser(description="Benchmark guided decoding.") - parser = AsyncEngineArgs.add_cli_args(parser) - - parser.add_argument("--output-len", - type=int, - default=512, - help="Output length for each request. 
Overrides the " - "output length from the dataset.") - parser.add_argument( - "--dataset", - default='json', - choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) - parser.add_argument("--json_schema_path", - type=str, - default=None, - help="Path to json schema.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=10, - help="Number of prompts to process.") - parser.add_argument( - '--output-json', - type=str, - default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument("--async-engine", - action='store_true', - default=False, - help="Use vLLM async engine rather than LLM class.") - parser.add_argument("--no-guided-decoding", - action='store_true', - default=False, - help="Whether to disable JSON decoding or not.") - parser.add_argument("--guided-decoding-ratio", - type=float, - default=1.0, - help="Ratio of Guided Decoding requests") - parser.add_argument("--disable-frontend-multiprocessing", - action='store_true', - default=False, - help="Disable decoupled async engine frontend.") - parser.add_argument("--warmup", - action="store_true", - default=False, - help="Run warmup prompts before benchmark.") - parser.add_argument("--save-results", - action="store_true", - default=False, - help="save output results.") - args = parser.parse_args() - if args.tokenizer is None: - args.tokenizer = args.model - main(args) diff --git a/benchmarks/benchmark_serving_guided.py b/benchmarks/benchmark_serving_structured_output.py similarity index 94% rename from benchmarks/benchmark_serving_guided.py rename to benchmarks/benchmark_serving_structured_output.py index 6c132d05f1b6..3d43e04598f5 100644 --- a/benchmarks/benchmark_serving_guided.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -r"""Benchmark online serving throughput with guided decoding. +r"""Benchmark online serving throughput with structured outputs. 
On the server side, run one of the following commands: (vLLM OpenAI API server) @@ -9,12 +9,12 @@ ./launch_tgi_server.sh On the client side, run: - python benchmarks/benchmark_serving_guided.py \ + python benchmarks/benchmark_serving_structured_output.py \ --backend \ --model \ --dataset json \ - --guided-decoding-ratio 1.0 \ - --guided-decoding-backend xgrammar \ + --structured-output-ratio 1.0 \ + --structured-output-backend xgrammar \ --request-rate 10 \ --num-prompts 1000 @@ -52,6 +52,9 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser +from vllm.v1.structured_output.utils import ( + has_xgrammar_unsupported_json_features) + MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, requests: list[SampleRequest] = [] dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") - print(f"dataset has {len(dataset)} entries") + full_dataset_len = len(dataset) + + def _filter_func(item): + import json + schema = json.loads(item["schema"]) + return not has_xgrammar_unsupported_json_features(schema) + + dataset = dataset.filter(_filter_func) + num_filtered_out = full_dataset_len - len(dataset) + print(f"dataset has {len(dataset)} entries after filtering " + f"out {num_filtered_out} entries with unsupported features") len_dataset = len(dataset) for data_point_idx in range(args.num_prompts): idx = data_point_idx @@ -220,21 +233,21 @@ async def get_request( burstiness: float = 1.0, ) -> AsyncGenerator[tuple[int, SampleRequest], None]: """ - Asynchronously generates requests at a specified rate + Asynchronously generates requests at a specified rate with OPTIONAL burstiness. - + Args: - input_requests: + input_requests: A list of input requests, each represented as a tuple. - request_rate: + request_rate: The rate at which requests are generated (requests/s). - burstiness (optional): - The burstiness factor of the request generation. + burstiness (optional): + The burstiness factor of the request generation. Only takes effect when request_rate is not inf. Default value is 1, which follows a Poisson process. Otherwise, the request intervals follow a gamma distribution. - A lower burstiness value (0 < burstiness < 1) results - in more bursty requests, while a higher burstiness value + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. 
""" input_requests = iter(input_requests) @@ -378,8 +391,8 @@ async def benchmark( selected_percentiles: list[str], ignore_eos: bool, max_concurrency: Optional[int], - guided_decoding_ratio: float, - guided_decoding_backend: str, + structured_output_ratio: float, + structured_output_backend: str, goodput_config_dict: Optional[dict[str, float]] = None, ): if backend in ASYNC_REQUEST_FUNCS: @@ -391,16 +404,18 @@ def prepare_extra_body(request) -> dict: extra_body = {} # Add the schema to the extra_body extra_body[request.structure_type] = request.schema - # Add the specific guided_decoding_backend - extra_body["guided_decoding_backend"] = guided_decoding_backend + # Add the specific structured_output_backend + extra_body["guided_decoding_backend"] = structured_output_backend return extra_body print("Starting initial single prompt test run...") - guided_decoding_req_idx = random.sample( + structured_output_req_idx = random.sample( range(len(input_requests)), - int(len(input_requests) * guided_decoding_ratio)) + int(len(input_requests) * structured_output_ratio)) test_request = input_requests[0] + test_req_extra_body = (prepare_extra_body(test_request) + if 0 in structured_output_req_idx else None) test_input = RequestFuncInput( model=model_id, prompt=test_request.prompt, @@ -408,7 +423,7 @@ def prepare_extra_body(request) -> dict: prompt_len=test_request.prompt_len, output_len=test_request.expected_output_len, ignore_eos=ignore_eos, - extra_body=prepare_extra_body(test_request), + extra_body=test_req_extra_body, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -427,7 +442,7 @@ def prepare_extra_body(request) -> dict: prompt_len=test_request.prompt_len, output_len=test_request.expected_output_len, ignore_eos=ignore_eos, - extra_body=prepare_extra_body(test_request), + extra_body=test_req_extra_body, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -465,7 +480,7 @@ async def limited_request_func(request_func_input, pbar): async for i, request in get_request(input_requests, request_rate, burstiness): extra_body = prepare_extra_body( - request) if i in guided_decoding_req_idx else None + request) if i in structured_output_req_idx else None request_func_input = RequestFuncInput( model=model_id, prompt=request.prompt, @@ -708,10 +723,10 @@ def main(args: argparse.Namespace): else: args.structure_type = 'guided_json' - if args.no_guided_decoding: - args.guided_decoding_ratio = 0 + if args.no_structured_output: + args.structured_output_ratio = 0 if args.save_results: - result_file_name = f'{args.guided_decoding_ratio}guided' + result_file_name = f'{args.structured_output_ratio}guided' result_file_name += f"_{backend}" result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.model.split('/')[-1]}" @@ -744,8 +759,8 @@ def main(args: argparse.Namespace): ], ignore_eos=args.ignore_eos, max_concurrency=args.max_concurrency, - guided_decoding_ratio=args.guided_decoding_ratio, - guided_decoding_backend=args.guided_decoding_backend, + structured_output_ratio=args.structured_output_ratio, + structured_output_backend=args.structured_output_backend, goodput_config_dict=goodput_config_dict, )) @@ -943,19 +958,19 @@ def main(args: argparse.Namespace): "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve") - parser.add_argument("--no-guided-decoding", + parser.add_argument("--no-structured-output", action='store_true', 
default=False, help="Whether to disable JSON decoding or not.") - parser.add_argument("--guided-decoding-ratio", + parser.add_argument("--structured-output-ratio", type=float, default=1.0, - help="Ratio of Guided Decoding requests") - parser.add_argument("--guided-decoding-backend", + help="Ratio of Structured Outputs requests") + parser.add_argument("--structured-output-backend", type=str, choices=["outlines", "lm-format-enforcer", "xgrammar"], default="xgrammar", - help="Backend to use for guided decoding") + help="Backend to use for structured outputs") args = parser.parse_args() main(args) diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh new file mode 100755 index 000000000000..8a777320f735 --- /dev/null +++ b/benchmarks/run_structured_output_benchmark.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# Define the model to use +MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} + +# Define the backend to use +BACKEND=${2:-"vllm"} + +# Define the dataset to use +DATASET=${3:-"xgrammar_bench"} + +# Define the guided decoding backend +GUIDED_BACKEND=${4:-"xgrammar"} + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"} + +GUIDED_RATIO=${6:-0.5} + +# Create output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +# Define QPS values to test +QPS_VALUES=(70 60 50 25 20 15 10) + +# Common parameters +COMMON_PARAMS="--backend $BACKEND \ + --model $MODEL \ + --dataset $DATASET \ + --structured-output-backend $GUIDED_BACKEND \ + --structured-output-ratio $GUIDED_RATIO \ + --save-results \ + --result-dir $OUTPUT_DIR" + +echo "Starting structured output benchmark with model: $MODEL" +echo "Backend: $BACKEND" +echo "Dataset: $DATASET" +echo "Structured output backend: $GUIDED_BACKEND" +echo "Results will be saved to: $OUTPUT_DIR" +echo "----------------------------------------" + +# Run benchmarks with different QPS values +for qps in "${QPS_VALUES[@]}"; do + echo "Running benchmark with QPS: $qps" + + # Get git hash and branch for the filename + GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown") + GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") + + # Construct filename for this run + FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" + + # Run the benchmark + python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ + --request-rate $qps \ + --result-filename "$FILENAME" \ + --port ${PORT:-8000} + + echo "Completed benchmark with QPS: $qps" + echo "----------------------------------------" +done + +echo "All benchmarks completed!" +echo "Results saved to: $OUTPUT_DIR" diff --git a/benchmarks/structured_schemas/structured_schema_1.json b/benchmarks/structured_schemas/structured_schema_1.json index 6003698469e8..1bd189c9e704 100644 --- a/benchmarks/structured_schemas/structured_schema_1.json +++ b/benchmarks/structured_schemas/structured_schema_1.json @@ -1,113 +1,25 @@ { - "$schema": - "https://json-schema.org/draft/2020-12/schema", - "title": - "User Profile", - "type": - "object", + "type": "array", + "items": { + "type": "object", "properties": { - "userId": { - "type": "string", - "description": "Unique identifier for the user." - }, - "personalInfo": { - "type": "object", - "properties": { - "firstName": { - "type": "string", - "description": "The user's first name." - }, - "lastName": { - "type": "string", - "description": "The user's last name." 
- }, - "age": { - "type": "integer", - "minimum": 0, - "description": "The user's age." - }, - "phoneNumbers": { - "type": - "array", - "items": { - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["home", "work", "mobile"], - "description": "Type of phone number." - }, - "number": { - "type": "string", - "pattern": "^\\+?[1-9]\\d{1,14}$", - "description": "Phone number in E.164 format." - } - }, - "required": ["type", "number"] - }, - "description": - "List of phone numbers associated with the user." - } - }, - "required": ["firstName", "lastName"] - }, - "address": { - "type": "object", - "properties": { - "street": { - "type": "string", - "description": "Street address." - }, - "city": { - "type": "string", - "description": "City name." - }, - "state": { - "type": "string", - "description": "State or province." - }, - "postalCode": { - "type": "string", - "pattern": "^\\d{5}(-\\d{4})?$", - "description": "Postal code." - }, - "country": { - "type": "string", - "description": "Country name." - } - }, - "required": ["street", "city", "state", "postalCode", "country"] - }, - "preferences": { - "type": "object", - "properties": { - "newsletterSubscribed": { - "type": - "boolean", - "description": - "Indicates if the user is subscribed to the newsletter." - }, - "favoriteCategories": { - "type": "array", - "items": { - "type": "string" - }, - "description": "List of user's favorite categories." - } - }, - "required": ["newsletterSubscribed"] - }, - "accountStatus": { - "type": "string", - "enum": ["active", "inactive", "suspended"], - "description": "Current status of the user's account." - }, - "registrationDate": { - "type": "string", - "format": "date-time", - "description": "ISO 8601 formatted date-time of user registration." 
- } + "name": { "type": "string" }, + "race": { "type": "string" }, + "class": { "type": "string" }, + "level": { "type": "integer" }, + "background": { "type": "string" }, + "alignment": { "type": "string" }, + "backstory": { "type": "string" } }, - "required": - ["userId", "personalInfo", "address", "accountStatus", "registrationDate"] -} \ No newline at end of file + "required": [ + "name", + "race", + "class", + "level", + "background", + "alignment", + "backstory" + ] + } +} + diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f45c21ab75ba..738ab2ef03de 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager EOS_TOKEN_ID = 50256 @@ -36,13 +37,21 @@ def create_scheduler( swap_space=0, cache_dtype="auto", ) + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + ) cache_config.num_gpu_blocks = 10000 - return Scheduler(scheduler_config, - model_config, - cache_config, - speculative_config=None, - lora_config=None, - log_stats=True) + return Scheduler( + scheduler_config, + model_config, + cache_config, + speculative_config=None, + lora_config=None, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) def create_requests( @@ -249,7 +258,9 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[]) + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -299,7 +310,9 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[]) + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -347,7 +360,9 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[]) + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], @@ -392,7 +407,9 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[]) + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index b00e168db9d3..6d4278b4c871 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -29,6 +29,7 @@ def sample_regex(): r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +# Note: Ensure this only uses attributes compatible with xgrammar @pytest.fixture def sample_json_schema(): return { @@ -44,9 +45,7 @@ def sample_json_schema(): 
"type": "array", "items": { "type": "string", - "maxLength": 10 - }, - "minItems": 3 + } }, "work_history": { "type": "array", @@ -71,8 +70,9 @@ def sample_json_schema(): } +# A schema unsupported by xgrammar @pytest.fixture -def sample_complex_json_schema(): +def unsupported_json_schema(): return { "type": "object", "properties": { @@ -150,7 +150,19 @@ def sample_guided_choice(): @pytest.fixture -def sample_sql_statements(): +def sample_sql_ebnf(): + return """ +root ::= select_statement +select_statement ::= "SELECT" column "from" table "where" condition +column ::= "col_1" | "col_2" +table ::= "table_1" | "table_2" +condition ::= column "=" number +number ::= "1" | "2" +""" + + +@pytest.fixture +def sample_sql_lark(): return (""" start: select_statement select_statement: "SELECT" column "from" table "where" condition diff --git a/tests/v1/entrypoints/llm/__init__.py b/tests/v1/entrypoints/llm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py new file mode 100644 index 000000000000..871739bcf164 --- /dev/null +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json + +import jsonschema +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.outputs import RequestOutput +from vllm.sampling_params import GuidedDecodingParams, SamplingParams + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" +GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_completion(monkeypatch, sample_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_json_schema, + backend=guided_decoding_backend)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_object(monkeypatch, guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=100, + n=2, + guided_decoding=GuidedDecodingParams( + json_object=True, + backend=guided_decoding_backend)) + + outputs = llm.generate( + prompts=("Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old."), + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + for i in range(2): + generated_text = output.outputs[i].text + 
print(generated_text) + assert generated_text is not None + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=unsupported_json_schema, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="The provided JSON schema contains features " + "not supported by xgrammar."): + llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {unsupported_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_ebnf, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_lark(monkeypatch, sample_sql_lark, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_lark, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_lark) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup 
+@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_grammar_ebnf_invalid(monkeypatch, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar="not a grammar", + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="Failed to convert the grammar " + "from Lark to EBNF."): + llm.generate( + prompts=("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1"), + sampling_params=sampling_params, + use_tqdm=True, + ) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend=guided_decoding_backend)) + with pytest.raises(ValueError, + match="Regex guided decoding is not supported."): + llm.generate(prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + # Once regex is supported -- + #assert outputs is not None + #for output in outputs: + # assert output is not None + # assert isinstance(output, RequestOutput) + # prompt = output.prompt + # generated_text = output.outputs[0].text + # print(generated_text) + # assert generated_text is not None + # assert re.fullmatch(sample_regex, generated_text) is not None + # print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize("guided_decoding_backend", + GUIDED_DECODING_BACKENDS_V1) +def test_guided_choice_completion(monkeypatch, sample_guided_choice, + guided_decoding_backend: str): + monkeypatch.setenv("VLLM_USE_V1", "1") + llm = LLM(model=MODEL_NAME, max_model_len=1024) + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + choice=sample_guided_choice, + backend=guided_decoding_backend)) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/v1/structured_output/__init__.py b/tests/v1/structured_output/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py new file mode 100644 index 000000000000..3aa86cbec533 --- /dev/null +++ b/tests/v1/structured_output/test_utils.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from vllm.v1.structured_output.utils import ( + has_xgrammar_unsupported_json_features) + + +@pytest.fixture +def unsupported_string_schemas(): + return [ + { + "type": "string", + "pattern": "^[a-zA-Z]+$" + }, + { + "type": "string", + "enum": ["active", 
"inactive", "pending"] + }, + { + "type": "string", + "minLength": 1 + }, + { + "type": "string", + "maxLength": 100 + }, + { + "type": "string", + "format": "email" + }, + ] + + +@pytest.fixture +def unsupported_integer_schemas(): + return [ + { + "type": "integer", + "minimum": 0 + }, + { + "type": "integer", + "maximum": 120 + }, + { + "type": "integer", + "exclusiveMinimum": 120 + }, + { + "type": "integer", + "exclusiveMaximum": 120 + }, + { + "type": "integer", + "multipleOf": 120 + }, + ] + + +@pytest.fixture +def unsupported_number_schemas(): + return [ + { + "type": "number", + "minimum": 0 + }, + { + "type": "number", + "maximum": 120 + }, + { + "type": "number", + "exclusiveMinimum": 120 + }, + { + "type": "number", + "exclusiveMaximum": 120 + }, + { + "type": "number", + "multipleOf": 120 + }, + ] + + +@pytest.fixture +def unsupported_array_schemas(): + return [ + { + "type": "array", + "uniqueItems": True + }, + { + "type": "array", + "contains": { + "type": "string" + } + }, + { + "type": "array", + "minContains": 1 + }, + { + "type": "array", + "maxContains": 5 + }, + { + "type": "array", + "minItems": 1 + }, + { + "type": "array", + "maxItems": 10 + }, + ] + + +@pytest.fixture +def unsupported_object_schemas(): + return [ + { + "type": "object", + "minProperties": 1 + }, + { + "type": "object", + "maxProperties": 5 + }, + { + "type": "object", + "propertyNames": { + "pattern": "^[a-z]+$" + } + }, + { + "type": "object", + "patternProperties": { + "^S": { + "type": "string" + } + } + }, + ] + + +@pytest.fixture +def supported_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "status": { + "type": "string" + }, + "scores": { + "type": "array", + "items": { + "type": "number" + } + }, + "address": { + "type": "object", + "properties": { + "street": { + "type": "string" + }, + "city": { + "type": "string" + } + } + } + } + } + + +@pytest.mark.parametrize("schema_type", [ + "unsupported_string_schemas", "unsupported_integer_schemas", + "unsupported_number_schemas", "unsupported_array_schemas", + "unsupported_object_schemas" +]) +def test_unsupported_json_features_by_type(schema_type, request): + schemas = request.getfixturevalue(schema_type) + for schema in schemas: + assert has_xgrammar_unsupported_json_features( + schema), f"Schema should be unsupported: {schema}" + + +def test_supported_json_features(supported_schema): + assert not has_xgrammar_unsupported_json_features( + supported_schema), "Schema should be supported" diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index ff4058a3b923..345519a07e41 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, ) @@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner): num_common_prefix_blocks=0, finished_req_ids={req_id}, free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, ) metadata_before = model_runner.input_batch.sampling_metadata @@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner): num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, ) model_runner._update_states(scheduler_output) 
@@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner.input_batch.sampling_metadata
@@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner.input_batch.sampling_metadata
@@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner._update_states(scheduler_output)
diff --git a/vllm/utils.py b/vllm/utils.py
index 1de2180deb50..3e2f6d24bc33 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
 import argparse
 import asyncio
 import concurrent
@@ -8,6 +10,7 @@
 import enum
 import gc
 import getpass
+import importlib
 import importlib.metadata
 import importlib.util
 import inspect
@@ -23,6 +26,7 @@
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 import weakref
@@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream:
     return _current_stream


-def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
+def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
     """
@@ -1977,7 +1981,7 @@ def measure(self):
         self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()

-    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
+    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
         return MemorySnapshot(
             torch_peak=self.torch_peak - other.torch_peak,
             cuda_memory=self.cuda_memory - other.cuda_memory,
@@ -2306,3 +2310,54 @@ def wrapped_init(self, *args, **kwargs) -> None:
         type.__setattr__(cls, '__init__', wrapped_init)

     return cls
+
+
+class LazyLoader(types.ModuleType):
+    """
+    LazyLoader module borrowed from Tensorflow
+    https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
+    with an addition of "module caching".
+
+    Lazily import a module, mainly to avoid pulling in large dependencies.
+    Modules such as `xgrammar` might have additional side effects, so we
+    only want to use this when it is needed, delaying all eager effects.
+    """
+
+    def __init__(
+        self,
+        local_name: str,
+        parent_module_globals: dict[str, Any],
+        name: str,
+    ):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+        self._module: types.ModuleType | None = None
+
+        super().__init__(str(name))
+
+    def _load(self) -> types.ModuleType:
+        # Import the target module and insert it into the parent's namespace
+        try:
+            module = importlib.import_module(self.__name__)
+            self._parent_module_globals[self._local_name] = module
+            # The additional add to sys.modules
+            # ensures the library is actually loaded.
+            sys.modules[self._local_name] = module
+        except ModuleNotFoundError as err:
+            raise err from None
+
+        # Update this object's dict so that if someone keeps a
+        # reference to the LazyLoader, lookups are efficient
+        # (__getattr__ is only called on lookups that fail).
+        self.__dict__.update(module.__dict__)
+        return module
+
+    def __getattr__(self, item: Any) -> Any:
+        if self._module is None:
+            self._module = self._load()
+        return getattr(self._module, item)
+
+    def __dir__(self) -> list[str]:
+        if self._module is None:
+            self._module = self._load()
+        return dir(self._module)
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index db14c9455a1f..70e36e2dc152 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
 import time
 from collections import deque
 from collections.abc import Iterable
@@ -18,6 +20,7 @@
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager

 logger = init_logger(__name__)

@@ -32,12 +35,14 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         speculative_config: Optional[SpeculativeConfig],
         log_stats: bool,
+        structured_output_manager: StructuredOutputManager,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         self.lora_config = lora_config
         self.speculative_config = speculative_config
         self.log_stats = log_stats
+        self.structured_output_manager = structured_output_manager

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -97,7 +102,7 @@ def __init__(
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)

-    def schedule(self) -> "SchedulerOutput":
+    def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
         # Each request just has the num_computed_tokens and
@@ -114,6 +119,14 @@ def schedule(self) -> SchedulerOutput:
         scheduled_running_reqs: list[Request] = []
         preempted_reqs: list[Request] = []

+        # NOTE: structured_output_request_ids maps
+        # the request_id of each request that uses structured output
+        # to its index among the running requests.
+        # This lets us slice the grammar bitmask
+        # and apply a valid mask only to requests that
+        # use structured decoding.
+        structured_output_request_ids: dict[str, int] = {}
+
         req_to_new_block_ids: dict[str, list[int]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
@@ -184,6 +197,12 @@ def schedule(self) -> SchedulerOutput:
                 # Schedule the request.
                 scheduled_running_reqs.append(request)
                 self.scheduled_req_ids.add(request.request_id)
+                if request.use_structured_output:
+                    # PERF: with chunked prefill, a request
+                    # might not include any new tokens.
+                    # Therefore, we might spend some additional
+                    # cycles filling in a bitmask that is effectively a no-op.
+                    structured_output_request_ids[request.request_id] = req_index
                 req_to_new_block_ids[request.request_id] = [
                     b.block_id for b in new_blocks
                 ]
@@ -219,6 +238,10 @@ def schedule(self) -> SchedulerOutput:
                 if req.lora_request and req.lora_request.lora_int_id > 0)
             assert len(requested_loras) <= self.lora_config.max_loras

+        # Use a temporary deque to collect requests that need to be skipped
+        # and put back at the head of the waiting queue later
+        waiting_for_fsm: deque[Request] = deque()
+
         # Next, schedule the WAITING requests.
if not preempted_reqs: while self.waiting and token_budget > 0: @@ -227,6 +250,16 @@ def schedule(self) -> "SchedulerOutput": request = self.waiting[0] + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + waiting_structured_output_req = self.waiting.popleft() + waiting_for_fsm.appendleft( + waiting_structured_output_req) + continue + # Check that adding the request still respects the max_loras # constraint. if self.lora_config and request.lora_request: @@ -281,6 +314,10 @@ def schedule(self) -> "SchedulerOutput": break self.waiting.popleft() + if request.use_structured_output: + structured_output_request_ids[ + request.request_id] = req_index + req_index += 1 self.running.append(request) self.scheduled_req_ids.add(request.request_id) self.request_scheduled(request, scheduled_timestamp) @@ -311,6 +348,10 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Put back any skipped requests at the head of the waiting queue + if waiting_for_fsm: + self.waiting.extendleft(waiting_for_fsm) + # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens @@ -331,6 +372,11 @@ def schedule(self) -> "SchedulerOutput": self.kv_cache_manager.get_num_common_prefix_blocks( any_request, len(self.running))) + grammar_bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + len(self.running), + ) # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request(req, @@ -369,6 +415,8 @@ def schedule(self) -> "SchedulerOutput": # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, ) self.finished_req_ids = set() @@ -381,7 +429,7 @@ def _make_cached_request_data( num_scheduled_spec_tokens: int, new_block_ids: list[int], resumed_from_preemption: bool, - ) -> "CachedRequestData": + ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. num_computed_tokens = request.num_computed_tokens @@ -474,8 +522,8 @@ def _try_schedule_encoder_inputs( def update_from_output( self, - scheduler_output: "SchedulerOutput", - model_runner_output: "ModelRunnerOutput", + scheduler_output: SchedulerOutput, + model_runner_output: ModelRunnerOutput, ) -> EngineCoreOutputs: sampled_token_ids = model_runner_output.sampled_token_ids spec_token_ids = model_runner_output.spec_token_ids @@ -565,6 +613,15 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) + if new_token_ids and request.use_structured_output: + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + request.request_id, + new_token_ids, + ) + # Transmit partial if chunked prefill & prompt logprobs is enabled if new_token_ids or prompt_logprobs_tensors is not None: # Add EngineCoreOutput for this Request. 
diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py index b6caa8b4ebf7..bb883acdb44b 100644 --- a/vllm/v1/core/scheduler_output.py +++ b/vllm/v1/core/scheduler_output.py @@ -1,9 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from dataclasses import dataclass from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange @@ -17,20 +22,20 @@ class NewRequestData: req_id: str prompt_token_ids: list[int] prompt: Optional[str] - mm_inputs: list["MultiModalKwargs"] + mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] - mm_positions: list["PlaceholderRange"] - sampling_params: "SamplingParams" + mm_positions: list[PlaceholderRange] + sampling_params: SamplingParams block_ids: list[int] num_computed_tokens: int - lora_request: Optional["LoRARequest"] + lora_request: Optional[LoRARequest] @classmethod def from_request( cls, - request: "Request", + request: Request, block_ids: list[int], - ) -> "NewRequestData": + ) -> NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, @@ -60,11 +65,11 @@ class CachedRequestData: @classmethod def from_request( cls, - request: "Request", + request: Request, resumed_from_preemption: bool, new_token_ids: list[int], new_block_ids: list[int], - ) -> "CachedRequestData": + ) -> CachedRequestData: return cls( req_id=request.request_id, resumed_from_preemption=resumed_from_preemption, @@ -111,3 +116,9 @@ class SchedulerOutput: # list of (req_id, encoder_input_index) tuples. # Used to free the encoder cache. free_encoder_input_ids: list[tuple[str, int]] + + # Dict of request ids to their index within the batch + # for filling the next token bitmask + structured_output_request_ids: dict[str, int] + # the bitmask for the whole batch + grammar_bitmask: Optional[npt.NDArray[np.int32]] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4c9d4cb467ae..32cbc10e16f6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -72,9 +72,7 @@ def __init__( # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - lora_config=vllm_config.lora_config, + vllm_config=vllm_config, tokenizer=self.tokenizer, input_registry=input_registry, ) @@ -194,8 +192,8 @@ async def generate( * 3) Adding the Request to the Detokenizer. * 4) Adding the Request to the EngineCore (separate process). - A separate output_handler loop runs in a background AsyncIO task, - pulling outputs from EngineCore and putting them into the + A separate output_handler loop runs in a background AsyncIO task, + pulling outputs from EngineCore and putting them into the per-request AsyncStream. 
The caller of generate() iterates the returned AsyncGenerator, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 671a72e2112d..e60aa5d45810 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -29,6 +29,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.structured_output import StructuredOutputManager from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -61,6 +62,8 @@ def __init__( vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + self.structured_output_manager = StructuredOutputManager(vllm_config) + # Setup scheduler. self.scheduler = Scheduler( scheduler_config=vllm_config.scheduler_config, @@ -69,6 +72,7 @@ def __init__( lora_config=vllm_config.lora_config, speculative_config=vllm_config.speculative_config, log_stats=self.log_stats, + structured_output_manager=self.structured_output_manager, ) # Setup MM Input Mapper. @@ -131,6 +135,9 @@ def add_request(self, request: EngineCoreRequest): request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) + if req.use_structured_output: + # Start grammar compilation asynchronously + self.structured_output_manager.populate_cache(req) self.scheduler.add_request(req) @@ -148,11 +155,24 @@ def step(self) -> EngineCoreOutputs: if not self.scheduler.has_unfinished_requests(): return EngineCoreOutputs( - outputs=[], scheduler_stats=self.scheduler.make_stats()) + outputs=[], + scheduler_stats=self.scheduler.make_stats(), + ) scheduler_output = self.scheduler.schedule() + + # This case may occur when the only unfinished requests are + # structured output requests where the grammar has not finished + # compiling yet, so there's nothing to run. 
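The `total_num_scheduled_tokens == 0` early return above is what keeps the engine loop from launching an empty forward pass while grammars are still compiling. A toy sketch of that guard (the `Toy*` classes are stand-ins invented here, not vLLM types):

```python
from dataclasses import dataclass, field

@dataclass
class ToySchedulerOutput:
    total_num_scheduled_tokens: int

@dataclass
class ToyScheduler:
    pending_tokens: list[int] = field(default_factory=list)

    def has_unfinished_requests(self) -> bool:
        return bool(self.pending_tokens)

    def schedule(self) -> ToySchedulerOutput:
        return ToySchedulerOutput(sum(self.pending_tokens))

def toy_step(scheduler: ToyScheduler) -> str:
    if not scheduler.has_unfinished_requests():
        return "idle"
    if scheduler.schedule().total_num_scheduled_tokens == 0:
        return "skipped forward pass"  # e.g. grammars still compiling
    return "ran forward pass"

print(toy_step(ToyScheduler()))        # idle
print(toy_step(ToyScheduler([0, 0])))  # skipped forward pass
print(toy_step(ToyScheduler([8, 4])))  # ran forward pass
```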
+ if scheduler_output.total_num_scheduled_tokens == 0: + return EngineCoreOutputs( + outputs=[], + scheduler_stats=self.scheduler.make_stats(), + ) + output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, output) # type: ignore + return engine_core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 99b97ac8e6c4..213faaa45160 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -66,9 +66,7 @@ def __init__( self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - lora_config=vllm_config.lora_config, + self.processor = Processor(vllm_config=vllm_config, tokenizer=self.tokenizer, input_registry=input_registry, mm_registry=mm_registry) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a75f0946b4ce..b3226a280d8b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from typing import Optional, Union -from vllm.config import CacheConfig, LoRAConfig, ModelConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs @@ -19,39 +19,41 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MMInputCacheClient +from vllm.v1.structured_output.utils import validate_structured_output_request class Processor: def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.decoding_config = vllm_config.decoding_config self.tokenizer = tokenizer - self.generation_config_fields = model_config.try_get_generation_config( - ) - self.input_preprocessor = InputPreprocessor(model_config, + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) + self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer, mm_registry) self.input_processor = input_registry.create_input_processor( - model_config) + self.model_config) # Multi-modal (huggingface) input mapper - self.mm_input_cache_client = MMInputCacheClient(model_config) + self.mm_input_cache_client = MMInputCacheClient(self.model_config) # Multi-modal hasher (for images) - self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ - cache_config.enable_prefix_caching + self.use_hash = ( + not self.model_config.disable_mm_preprocessor_cache) or \ + self.cache_config.enable_prefix_caching def _validate_logprobs( self, @@ -80,6 +82,8 @@ def _validate_sampling_params( self, params: SamplingParams, ) -> None: + self._validate_structured_output(params) + if params.allowed_token_ids is None: return if not params.allowed_token_ids: @@ -125,6 +129,21 @@ def 
_validate_lora(self, lora_request: Optional[LoRARequest]) -> None: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + def _validate_structured_output(self, params: SamplingParams) -> None: + if not params.guided_decoding or not self.decoding_config: + return + if self.decoding_config.guided_decoding_backend != "xgrammar": + raise ValueError( + "Only xgrammar structured output is supported in V1.") + if (params.guided_decoding.backend + and params.guided_decoding.backend != 'xgrammar'): + raise ValueError( + "Only xgrammar structured output is supported in V1.") + if self.vllm_config.speculative_config: + raise ValueError("Structured output is not supported with " + "speculative decoding.") + validate_structured_output_request(params) + def process_inputs( self, request_id: str, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 99df54734836..29609d313306 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,13 +3,15 @@ import enum from typing import TYPE_CHECKING, Optional, Union -from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) +from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.utils import ConstantList if TYPE_CHECKING: + + from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange @@ -27,15 +29,19 @@ def __init__( sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, - lora_request: Optional[LoRARequest] = None, + lora_request: Optional["LoRARequest"] = None, + structured_output_request: Optional["StructuredOutputRequest"] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id self.lora_request = lora_request + self.structured_output_request = structured_output_request - self.status = RequestStatus.WAITING + self.status = (RequestStatus.WAITING_FOR_FSM + if sampling_params.guided_decoding is not None else + RequestStatus.WAITING) self.events: list[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None assert sampling_params.max_tokens is not None @@ -78,6 +84,8 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": eos_token_id=request.eos_token_id, arrival_time=request.arrival_time, lora_request=request.lora_request, + structured_output_request=StructuredOutputRequest( + sampling_params=request.sampling_params), ) def queued(self, timestamp: Optional[float] = None) -> None: @@ -134,18 +142,23 @@ def get_num_encoder_tokens(self, input_id: int) -> int: num_tokens = self.mm_positions[input_id]["length"] return num_tokens + @property + def use_structured_output(self) -> bool: + return self.sampling_params.guided_decoding is not None + class RequestStatus(enum.IntEnum): """Status of a request.""" - WAITING = 0 - RUNNING = 1 - PREEMPTED = 2 - # Note: anything after PREEMPTED (2) will be considered + WAITING = enum.auto() + WAITING_FOR_FSM = enum.auto() + RUNNING = enum.auto() + PREEMPTED = enum.auto() + # Note: anything after PREEMPTED will be considered # as a finished status. 
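For reference, this is the kind of request the new `_validate_structured_output` accepts in V1, assuming the engine is configured with `guided_decoding_backend="xgrammar"`; the commented-out variants are ones the checks above reject. Only the fields shown come from this diff, and other `SamplingParams` arguments are left at their defaults:

```python
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# A JSON-schema guided request with the xgrammar backend and no regex.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
params = SamplingParams(
    max_tokens=128,
    guided_decoding=GuidedDecodingParams(json=schema),
)

# These, by contrast, would be rejected by the V1 processor:
#   GuidedDecodingParams(json=schema, backend="outlines")  # non-xgrammar backend
#   GuidedDecodingParams(regex=r"\d+")                     # regex unsupported
```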
- FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() @staticmethod def is_finished(status: "RequestStatus") -> bool: diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py new file mode 100644 index 000000000000..0c2e0ac2aa73 --- /dev/null +++ b/vllm/v1/structured_output/__init__.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import copy +import multiprocessing +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Optional + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import LazyLoader +from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, + StructuredOutputOptions) + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + import xgrammar as xgr + + from vllm.v1.request import Request +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + + +class StructuredOutputManager: + + def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500): + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) # type: ignore[arg-type] + tokenizer_group.ping() + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.vllm_config = vllm_config + + tokenizer = tokenizer_group.get_lora_tokenizer(None) + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, vocab_size=self.vocab_size) + self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + + self.max_cache_size = max_cache_size + self.request_key_to_grammar: OrderedDict[StructuredOutputKey, + Grammar] = OrderedDict() + + # The default max_workers if not specified is the number of CPUs * 5, + # which is way too high since these tasks are CPU-bound, not I/O bound. + # We also know we would never dominate CPU usage with just grammar + # compilation, so we set it to half the number of CPUs. 
+ max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self._grammar_bitmask = xgr.allocate_token_bitmask( + self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) + + def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]: + # We need to pop and re-insert the grammar here for LRU cache + # of request_key_to_grammar + if key in self.request_key_to_grammar: + # Move accessed item to the end (most recently used) + value = self.request_key_to_grammar.pop(key) + if value is not None: + self.request_key_to_grammar[key] = value + return value + return None + + def populate_cache(self, request: Request) -> None: + if request.structured_output_request is None: + return + + grammar = self.request_key_to_grammar.get( + request.structured_output_request.structured_output_key) + if grammar: + request.structured_output_request.grammar = copy.copy(grammar) + return + request.structured_output_request.grammar = self.cache(request) + + def cache(self, request: Request): + return self.executor.submit(self._executor_loop, request) + + def _executor_loop(self, request: Request) -> Grammar: + # NOTE: The structured_output_request should never be + # None in this case, but mypy can't infer this + # correctly, so we need to ignore the error here. + key = request.structured_output_request.structured_output_key # type: ignore[union-attr] + grammar = self.request_key_to_grammar.get(key) + if grammar is not None: + return copy.copy(grammar) + grammar = self.initialize_grammar(key) + # If cache is full, remove the least recently used item + if len(self.request_key_to_grammar) >= self.max_cache_size: + self.request_key_to_grammar.popitem(last=False) + self.request_key_to_grammar[key] = grammar + return copy.copy(grammar) + + def initialize_grammar(self, key: StructuredOutputKey) -> Grammar: + # Note that the request was validated in the engine core client, + # so at this point we know it is a supported type of request. + # + # TODO: we still need to handle xgrammar compilation failures + request_type, grammar_spec = key + + if request_type == StructuredOutputOptions.JSON: + # TODO -- allow any_whitespace to be configurable + # pending merge of https://github.com/vllm-project/vllm/pull/12744 + ctx = self.compiler.compile_json_schema(grammar_spec, + any_whitespace=False) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + ctx = self.compiler.compile_builtin_json_grammar() + elif request_type == StructuredOutputOptions.GRAMMAR: + ctx = self.compiler.compile_grammar(grammar_spec) + else: + logger.error("Validation should have already occurred. " + "Please file an issue.") + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})") + + return Grammar( + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vocab_size, + ctx=ctx, + ) + + def grammar_bitmask( + self, + requests: dict[str, Request], + structured_output_request_ids: dict[str, int], + batch_len: int, + ) -> Optional[npt.NDArray[np.int32]]: + # Prepare the structured output bitmask for this batch. + if not structured_output_request_ids: + return None + + # Fill the bitmask using the index of each request equal to its + # position in the batch. Resize the bitmask down to the size of + # the batch. 
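Two generic patterns are combined in the manager above: an `OrderedDict` used as an LRU cache (hits are popped and re-inserted to become most recently used, and `popitem(last=False)` evicts the oldest entry when full), plus a small `ThreadPoolExecutor` so compilation never blocks `add_request`. A minimal self-contained sketch of just the cache half:

```python
from collections import OrderedDict
from typing import Hashable, Optional

class LRUCache:
    """Minimal OrderedDict-based LRU, mirroring the pattern used for
    request_key_to_grammar above (illustrative, not vLLM code)."""

    def __init__(self, max_size: int = 500) -> None:
        self.max_size = max_size
        self._data: OrderedDict[Hashable, object] = OrderedDict()

    def get(self, key: Hashable) -> Optional[object]:
        if key not in self._data:
            return None
        # Pop and re-insert so the key becomes the most recently used.
        value = self._data.pop(key)
        self._data[key] = value
        return value

    def put(self, key: Hashable, value: object) -> None:
        if key in self._data:
            self._data.pop(key)
        elif len(self._data) >= self.max_size:
            # Evict the least recently used entry (insertion-order head).
            self._data.popitem(last=False)
        self._data[key] = value

cache = LRUCache(max_size=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")     # touch "a" so "b" is now the oldest
cache.put("c", 3)  # evicts "b"
assert list(cache._data) == ["a", "c"]
```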
+ bitmask_tensor = self._grammar_bitmask + for req_id, batch_index in structured_output_request_ids.items(): + request = requests[req_id].structured_output_request + assert request is not None and request.grammar is not None + if not request.grammar.matcher.is_terminated(): + request.grammar.fill_bitmask(bitmask_tensor, batch_index) + if batch_len < self._grammar_bitmask.shape[0]: + bitmask_tensor = self._grammar_bitmask[:batch_len] + + # After finishing with the xgrammar operations, we convert to + # np.ndarray, because that is much more efficient for serialization + # and deserialization when sending this to the GPU workers. + return bitmask_tensor.numpy() diff --git a/vllm/v1/structured_output/grammar.py b/vllm/v1/structured_output/grammar.py new file mode 100644 index 000000000000..0e9b2b172261 --- /dev/null +++ b/vllm/v1/structured_output/grammar.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +from vllm.logger import init_logger +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + import xgrammar as xgr +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + + +class StructuredOutputOptions(enum.Enum): + JSON = enum.auto() + JSON_OBJECT = enum.auto() + REGEX = enum.auto() + GRAMMAR = enum.auto() + CHOICE = enum.auto() + + +StructuredOutputKey = tuple[StructuredOutputOptions, str] + + +@dataclass +class Grammar: + # NOTE: This would be a generic-enough class for + # supporting different backends, in the future. + # For now, just xgrammar. + # + # TODO: support max_rollback_tokens + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding + + vocab_size: int + matcher: xgr.GrammarMatcher = field(hash=False) + ctx: xgr.CompiledGrammar = field(hash=False) + num_processed_tokens: int = field(default_factory=lambda: 0, + repr=False, + hash=False, + init=False) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """Accepts a list of tokens and advances the FSM. + + Returns True if the FSM was advanced successfully. + Returns False if the FSM failed to advance. + """ + for token in tokens: + if not self.matcher.accept_token(token): + logger.error( + "Failed to advance FSM for request %s " + "for tokens %s. 
Please file an issue.", request_id, token) + return False + self.num_processed_tokens += 1 + return True + + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: + return self.matcher.fill_next_token_bitmask(bitmask, idx) + + def reset(self): + self.num_processed_tokens = 0 + self.matcher.reset() + + def __copy__(self): + return Grammar( + matcher=xgr.GrammarMatcher(self.ctx), + vocab_size=self.vocab_size, + ctx=self.ctx, + ) diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py new file mode 100644 index 000000000000..fbcfd541df54 --- /dev/null +++ b/vllm/v1/structured_output/request.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import dataclasses +import functools +import json +from concurrent.futures import Future +from concurrent.futures._base import TimeoutError +from typing import Optional, Union, cast + +from vllm.sampling_params import SamplingParams +from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, + StructuredOutputOptions) + + +@dataclasses.dataclass +class StructuredOutputRequest: + + sampling_params: SamplingParams + _grammar: Optional[Union[Future[Grammar], Grammar]] = None + + def _check_grammar_completion(self) -> bool: + # NOTE: We have to lazy import to gate circular imports + from vllm.v1.request import RequestStatus + + if isinstance(self._grammar, Future): + try: + # We will check whether the future is ready within 100 us + self._grammar = self._grammar.result(timeout=0.0001) + self.status = RequestStatus.WAITING + except TimeoutError: + return False + return True + + @property + def is_grammar_ready(self) -> bool: + return self._check_grammar_completion() + + @property + def grammar(self) -> Optional[Grammar]: + completed = self._check_grammar_completion() + return cast(Optional[Grammar], self._grammar) if completed else None + + @grammar.setter + def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None: + self._grammar = grammar + + @functools.cached_property + def structured_output_key(self) -> StructuredOutputKey: + params = self.sampling_params.guided_decoding + assert params is not None, "params can't be None." 
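`StructuredOutputRequest` stores either a resolved `Grammar` or the `Future` returned by the executor, and probes it with a ~100 µs `result()` timeout so callers never block on compilation. A generic sketch of that idiom (`compile_slowly` and `LazyResult` are illustrative names, not vLLM code):

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError

def compile_slowly(spec: str) -> str:
    # Stand-in for grammar compilation.
    time.sleep(0.2)
    return f"compiled<{spec}>"

class LazyResult:
    def __init__(self, fut: Future) -> None:
        self._value: object = fut

    def ready(self) -> bool:
        if isinstance(self._value, Future):
            try:
                # Give the worker ~100 us; never block the caller for long.
                self._value = self._value.result(timeout=0.0001)
            except TimeoutError:
                return False
        return True

with ThreadPoolExecutor(max_workers=1) as pool:
    lazy = LazyResult(pool.submit(compile_slowly, 'root ::= "a"'))
    print(lazy.ready())   # almost certainly False right away
    time.sleep(0.3)
    print(lazy.ready())   # True once the worker has finished
```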
+ if params.json is not None: + if not isinstance(params.json, str): + json_str = json.dumps(params.json) + else: + json_str = params.json + return (StructuredOutputOptions.JSON, json_str) + elif params.json_object: + return (StructuredOutputOptions.JSON_OBJECT, "") + elif params.regex is not None: + return (StructuredOutputOptions.REGEX, params.regex) + elif params.choice is not None: + if not isinstance(params.choice, str): + json_str = json.dumps(params.choice) + else: + json_str = params.choice + return (StructuredOutputOptions.CHOICE, json_str) + elif params.grammar is not None: + return (StructuredOutputOptions.GRAMMAR, params.grammar) + else: + raise ValueError("No valid structured output parameter found") diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py new file mode 100644 index 000000000000..7b1adb834e74 --- /dev/null +++ b/vllm/v1/structured_output/utils.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import re +from typing import TYPE_CHECKING, Any + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + import xgrammar as xgr +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + + +def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict[str, Any]) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for enum restrictions + if "enum" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj + for key in ("minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf")): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any( + key in obj + for key in ("uniqueItems", "contains", "minContains", + "maxContains", "minItems", "maxItems")): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ("minLength", "maxLength", "format")): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any( + key in obj for key in ("minProperties", "maxProperties", + "propertyNames", "patternProperties")): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. + + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for EBNF rule definition + if '::=' in line: + return False + + return True + + +def convert_lark_to_ebnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to EBNF format. 
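The two helpers above are pure functions and easy to sanity-check in isolation. A short usage sketch, assuming the new `vllm.v1.structured_output.utils` module from this diff is importable:

```python
from vllm.v1.structured_output.utils import (
    grammar_is_likely_lark, has_xgrammar_unsupported_json_features)

# Schemas with pattern/enum/length-style constraints are flagged as unsupported.
assert has_xgrammar_unsupported_json_features(
    {"type": "string", "pattern": "^[a-z]+$"})
assert not has_xgrammar_unsupported_json_features(
    {"type": "object", "properties": {"x": {"type": "integer"}}})

# Lark rules use ':', EBNF rules use '::=' -- that is the whole heuristic.
assert grammar_is_likely_lark('start: value\nvalue: "a" | "b"')
assert not grammar_is_likely_lark('root ::= "a" | "b"')
```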
+ + EBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in EBNF format + + Examples: + >>> print(convert_lark_to_ebnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) + + +def choice_as_grammar(choice: list[str]) -> str: + + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + escaped_choices = (escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + + +def validate_structured_output_request( + sampling_params: SamplingParams) -> None: + """Validate that the request is supported by structured output. + + Raises ValueError if the request is not supported. 
+ """ + if sampling_params.guided_decoding is None: + return + + gd_params = sampling_params.guided_decoding + + if gd_params.regex: + raise ValueError("Regex structured output is not supported.") + + if gd_params.choice: + choice_grammar = choice_as_grammar(gd_params.choice) + try: + xgr.Grammar.from_ebnf(choice_grammar) + except Exception as err: + raise ValueError("Failed to transform choices into a grammar: " + f"{err}") from err + gd_params.choice = None + gd_params.grammar = choice_grammar + return + + if gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + + if has_xgrammar_unsupported_json_features(schema): + raise ValueError("The provided JSON schema contains features not " + "supported by xgrammar.") + return + + if gd_params.grammar: + if grammar_is_likely_lark(gd_params.grammar): + # xgrammar supports EBNF grammars only + try: + gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to EBNF. ") from e + + # Test parsing EBNF grammar, possibly already converted from Lark + try: + # parse the grammar, but we aren't compiling it. + xgr.Grammar.from_ebnf(gd_params.grammar) + except Exception as e: + raise ValueError("Invalid grammar specification.") from e diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 519f38cb0b72..2484f0799b82 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -25,7 +25,8 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available) + LayerBlockType, LazyLoader, cdiv, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.engine.mm_input_cache import MMInputCacheClient @@ -40,7 +41,11 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: + import xgrammar as xgr + from vllm.v1.core.scheduler_output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") logger = init_logger(__name__) @@ -860,6 +865,53 @@ def _gather_encoder_outputs( def get_model(self) -> nn.Module: return self.model + def apply_grammar_bitmask( + self, + scheduler_output: "SchedulerOutput", + logits: torch.Tensor, + ): + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is None: + return + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the gpu runner is + # ordering the requests in the batch. We need to sort the bitmask to + # match the order of the requests used here.
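A quick usage sketch of `convert_lark_to_ebnf` and `choice_as_grammar` from the utils module above; the expected outputs in the comments follow from the code as written but should be treated as illustrative:

```python
from vllm.v1.structured_output.utils import (choice_as_grammar,
                                              convert_lark_to_ebnf)

print(convert_lark_to_ebnf("start: greeting\ngreeting: 'hello' | 'hi'"))
# root ::= start
# start ::= greeting
# greeting ::= "hello" | "hi"

print(choice_as_grammar(["yes", "no", "maybe"]))
# root ::= "yes" | "no" | "maybe"
```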
+ struct_out_req_batch_indices: dict[str, int] = {} + indices_match = True + for req_id in self.input_batch.req_ids: + mask_index = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_index is None: + # not a structured output request + continue + batch_index = self.input_batch.req_id_to_index[req_id] + if batch_index != mask_index: + indices_match = False + struct_out_req_batch_indices[req_id] = batch_index + + if not indices_match: + # Sort the bitmask to match the order of the requests + sorted_bitmask = np.zeros_like(grammar_bitmask) + for req_id, batch_index in struct_out_req_batch_indices.items(): + orig_index = scheduler_output.structured_output_request_ids[ + req_id] + sorted_bitmask[batch_index] = grammar_bitmask[orig_index] + grammar_bitmask = sorted_bitmask + + grammar_bitmask = torch.from_numpy(grammar_bitmask) + + # TODO: compatibility with spec decode + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask.to(self.device, non_blocking=True), + indices=list(struct_out_req_batch_indices.values()), + ) + @torch.inference_mode() def execute_model( self, @@ -945,6 +997,10 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if not self.use_spec_decode:
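Conceptually, `xgr.apply_token_bitmask_inplace` forces the logits of grammar-forbidden tokens to `-inf` for the masked rows before sampling. The torch sketch below shows that effect on an unpacked boolean mask for clarity; it is not xgrammar's kernel, which operates directly on the packed `int32` words, typically on the GPU:

```python
import torch

def apply_bool_mask(logits: torch.Tensor, allowed: torch.Tensor,
                    rows: list[int]) -> None:
    # logits: [batch, vocab]; allowed: [len(rows), vocab] boolean.
    # Only the batch rows listed in `rows` are grammar-constrained.
    for mask_row, batch_row in enumerate(rows):
        logits[batch_row] = torch.where(
            allowed[mask_row], logits[batch_row],
            torch.tensor(float("-inf")))

logits = torch.zeros(2, 5)
allowed = torch.tensor([[True, False, True, False, False]])
apply_bool_mask(logits, allowed, rows=[1])  # only batch row 1 is constrained
assert torch.isinf(logits[1, 1]) and logits[1, 0] == 0 and logits[0].eq(0).all()
```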