
Conversation

russellb
Member

This does a couple of things:

1. Defer initializing the grammar bitmask until the first time it is needed
   instead of at engine creation time. For environments where structured
   output is not used, this will prevent xgrammar from ever being imported.

2. Cleanly reject structured output requests for TPU, since that is not
   expected to work right now.

Signed-off-by: Russell Bryant <[email protected]>
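
For illustration, here is a minimal sketch of the two changes. The class, method, and parameter names are hypothetical stand-ins rather than the PR's actual code; only xgrammar's allocate_token_bitmask helper is taken from the discussion below.

```python
# Illustrative sketch only: hypothetical names, not the PR's actual code.
from typing import Optional

import torch


class StructuredOutputState:
    """Holds the grammar bitmask, allocating it lazily on first use."""

    def __init__(self, max_num_seqs: int, vocab_size: int):
        self.max_num_seqs = max_num_seqs
        self.vocab_size = vocab_size
        self._grammar_bitmask: Optional[torch.Tensor] = None

    def get_grammar_bitmask(self) -> torch.Tensor:
        # Change (1): xgrammar is imported here, on first use, instead of at
        # engine-creation time, so deployments that never use structured
        # output never import it.
        if self._grammar_bitmask is None:
            import xgrammar as xgr
            self._grammar_bitmask = xgr.allocate_token_bitmask(
                self.max_num_seqs, self.vocab_size)
        return self._grammar_bitmask


def validate_structured_output_request(params, is_tpu: bool) -> None:
    # Change (2): reject structured output cleanly on TPU instead of failing
    # later inside xgrammar. `params.guided_decoding` is assumed here to mark
    # a structured-output request.
    if is_tpu and getattr(params, "guided_decoding", None) is not None:
        raise ValueError("Structured output is not supported on TPU yet.")
```

Keeping the xgrammar import inside the accessor is what makes change (1) work: nothing in the engine's startup path touches the library.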


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mmoskal
Contributor

mmoskal commented Mar 10, 2025

You could simply allocate a torch tensor for the bitmask. llguidance uses the same format (as did AICI before it).

@Ubospica

Ubospica commented Mar 10, 2025

LGTM. We can provide TPU support soon. The allocate_token_bitmask just provides a convenient way to compute the size of the bitmask; it's indeed a torch tensor.
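
For reference, the format being discussed (one bit per vocabulary token, packed into int32 words) can be allocated as a plain torch tensor without importing xgrammar. This is a minimal sketch, not the library's actual helper; the initial fill value is an illustrative choice.

```python
import math

import torch


def make_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
    # One bit per vocab token, packed into 32-bit words. A set bit means the
    # token is allowed; filling with -1 (all bits set) starts from
    # "everything allowed" until a grammar matcher clears bits for
    # disallowed tokens.
    return torch.full(
        (batch_size, math.ceil(vocab_size / 32)),
        -1,
        dtype=torch.int32,
    )
```

For example, make_token_bitmask(8, 151_936) gives an int32 tensor of shape (8, 4748) for a Qwen2.5-sized vocabulary.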

@russellb
Member Author

> You could simply allocate a torch tensor for the bitmask. llguidance uses the same format (as did AICI before it).

True!

Either way, it seems reasonable to not allocate it until we know it's actually needed.

Collaborator

@WoosukKwon left a comment


LGTM!

@WoosukKwon added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 10, 2025
Collaborator

@aarnphm left a comment


From the discussion on Slack, it seems like we all agree on installing triton for TPU for now?

@mgoin enabled auto-merge (squash) March 10, 2025 22:37
@mgoin merged commit 04421df into vllm-project:main Mar 10, 2025
45 checks passed
@NickLucche
Collaborator

ok then we should at least update the PR title, as it doesn't really solve the TPU issue.

@russellb
Member Author

> ok then we should at least update the PR title, as it doesn't really solve the TPU issue.

I assume you were responding to the "seems like we all agree on installing triton for TPU" comment? That's not what was merged, though. Let me know if you still have trouble after this change. I know structured output won't work (yet), but it at least shouldn't get in the way of anything else now.

@NickLucche
Collaborator

NickLucche commented Mar 11, 2025

Yeah I assumed that was the conclusion of some other discussion I had no context about, because after testing current main I am still experiencing the same issue :/

Let me share that. Were you able to have it work locally on your side?

(vllm) ➜  vllm git:(1477ffc38) ✗ VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct --enforce-eager                                                 
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 03-11 13:31:24 [__init__.py:256] Automatically detected platform tpu.
INFO 03-11 13:31:25 [api_server.py:912] vLLM API server version 0.7.3.dev434+gf35f8e22.d20250303
INFO 03-11 13:31:25 [api_server.py:913] args: Namespace(subparser='serve', model_tag='Qwen/Qwen2.5-1.5B-Instruct', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='Qwen/Qwen2.5-1.5B-Instruct', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=None, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=True, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', 
override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, dispatch_function=<function ServeSubcommand.cmd at 0x7ad958b847c0>)
WARNING 03-11 13:31:25 [arg_utils.py:1473] Setting max_num_batched_tokens to 2048 for OPENAI_API_SERVER usage context.
INFO 03-11 13:31:31 [config.py:576] This model supports multiple tasks: {'score', 'embed', 'classify', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 03-11 13:31:31 [config.py:1666] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 03-11 13:31:31 [tpu.py:76] [TPU] Forcing DYNAMO_ONCE compilation level
WARNING 03-11 13:31:31 [tpu.py:108] [V1][TPU] Disable prefix caching
INFO 03-11 13:31:36 [__init__.py:256] Automatically detected platform tpu.
INFO 03-11 13:31:37 [core.py:51] Initializing a V1 LLM engine (v0.7.3.dev434+gf35f8e22.d20250303) with config: model='Qwen/Qwen2.5-1.5B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-1.5B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=None, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Qwen/Qwen2.5-1.5B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":2,"backend":"openxla","splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}
INFO 03-11 13:31:37 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
WARNING 03-11 13:32:03 [tpu.py:116] Pin memory is not supported on TPU.
INFO 03-11 13:32:03 [tpu.py:39] Cannot use None backend on TPU.
INFO 03-11 13:32:03 [tpu.py:42] Using Pallas V1 backend.
WARNING 03-11 13:32:03 [topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 03-11 13:32:03 [weight_utils.py:257] Using model weights format ['*.safetensors']
INFO 03-11 13:32:03 [weight_utils.py:307] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.39s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.39s/it]

INFO 03-11 13:32:05 [loader.py:429] Loading weights took 1.49 seconds
INFO 03-11 13:32:30 [kv_cache_utils.py:537] GPU KV cache size: 923,648 tokens
INFO 03-11 13:32:30 [kv_cache_utils.py:540] Maximum concurrency for 32,768 tokens per request: 28.19x
INFO 03-11 13:32:30 [core.py:120] init engine (profile, create kv cache, warmup model) took 24.73 seconds
ERROR 03-11 13:32:30 [core.py:319] EngineCore hit an exception: Traceback (most recent call last):
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/v1/engine/core.py", line 311, in run_engine_core
ERROR 03-11 13:32:30 [core.py:319]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 03-11 13:32:30 [core.py:319]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/v1/engine/core.py", line 266, in __init__
ERROR 03-11 13:32:30 [core.py:319]     super().__init__(vllm_config, executor_class, log_stats)
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/v1/engine/core.py", line 65, in __init__
ERROR 03-11 13:32:30 [core.py:319]     self.structured_output_manager = StructuredOutputManager(vllm_config)
ERROR 03-11 13:32:30 [core.py:319]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/v1/structured_output/__init__.py", line 43, in __init__
ERROR 03-11 13:32:30 [core.py:319]     tokenizer_info = xgr.TokenizerInfo.from_huggingface(
ERROR 03-11 13:32:30 [core.py:319]                      ^^^^^^^^^^^^^^^^^
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/utils.py", line 2357, in __getattr__
ERROR 03-11 13:32:30 [core.py:319]     self._module = self._load()
ERROR 03-11 13:32:30 [core.py:319]                    ^^^^^^^^^^^^
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/utils.py", line 2347, in _load
ERROR 03-11 13:32:30 [core.py:319]     raise err from None
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/vllm/utils.py", line 2341, in _load
ERROR 03-11 13:32:30 [core.py:319]     module = importlib.import_module(self.__name__)
ERROR 03-11 13:32:30 [core.py:319]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/importlib/__init__.py", line 126, in import_module
ERROR 03-11 13:32:30 [core.py:319]     return _bootstrap._gcd_import(name[level:], package, level)
ERROR 03-11 13:32:30 [core.py:319]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-11 13:32:30 [core.py:319]   File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
ERROR 03-11 13:32:30 [core.py:319]   File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
ERROR 03-11 13:32:30 [core.py:319]   File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
ERROR 03-11 13:32:30 [core.py:319]   File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
ERROR 03-11 13:32:30 [core.py:319]   File "<frozen importlib._bootstrap_external>", line 940, in exec_module
ERROR 03-11 13:32:30 [core.py:319]   File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/xgrammar/__init__.py", line 1, in <module>
ERROR 03-11 13:32:30 [core.py:319]     from . import testing
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/xgrammar/testing.py", line 11, in <module>
ERROR 03-11 13:32:30 [core.py:319]     from .matcher import GrammarMatcher, bitmask_dtype, get_bitmask_shape
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/xgrammar/matcher.py", line 11, in <module>
ERROR 03-11 13:32:30 [core.py:319]     from .kernels import apply_token_bitmask_inplace_cpu, apply_token_bitmask_inplace_triton
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/xgrammar/kernels/__init__.py", line 4, in <module>
ERROR 03-11 13:32:30 [core.py:319]     from .apply_token_bitmask_inplace_triton import apply_token_bitmask_inplace_triton
ERROR 03-11 13:32:30 [core.py:319]   File "/home/nick/vllm/.venv/lib/python3.11/site-packages/xgrammar/kernels/apply_token_bitmask_inplace_triton.py", line 4, in <module>
ERROR 03-11 13:32:30 [core.py:319]     import triton
ERROR 03-11 13:32:30 [core.py:319] ModuleNotFoundError: No module named 'triton'
ERROR 03-11 13:32:30 [core.py:319] 
CRITICAL 03-11 13:32:30 [core_client.py:260] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
[1]    902018 killed     VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct --enforce-eager

@russellb
Member Author

> Yeah I assumed that was the conclusion of some other discussion I had no context about, because after testing current main I am still experiencing the same issue :/
>
> Let me share that. Were you able to have it work locally on your side?

sorry, no, I didn't have a TPU environment to test on. I just made sure the feature was still working.

I see that my change wasn't sufficient. A little more needs to be moved around. I'll submit a follow-up PR in a few minutes.

@russellb
Member Author

@NickLucche can you take a look at this and see if it helps? #14616

russellb added a commit to russellb/vllm that referenced this pull request Mar 11, 2025
PR vllm-project#14575 delayed initialization of the grammar bitmask until it was
needed to try to fix a problem encountered on TPU systems.
Unfortunately, that change was not sufficient.

We need to delay usage of ALL xgrammar APIs, not just the grammar
initialization. This change implements that. More initialization is now
deferred until the first time a structured output request is received.

Signed-off-by: Russell Bryant <[email protected]>
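
As a rough sketch of the approach described in this commit message, everything xgrammar-related can be built lazily on the first structured-output request. The class and method names below are hypothetical; TokenizerInfo.from_huggingface is taken from the traceback earlier in the thread and GrammarCompiler from xgrammar's public API.

```python
# Hypothetical sketch: nothing from xgrammar is touched in __init__, so
# platforms without triton (e.g. TPU) can still start the engine.
class LazyXgrammarBackend:
    def __init__(self, tokenizer, vocab_size: int):
        self.tokenizer = tokenizer
        self.vocab_size = vocab_size
        self._compiler = None  # built on the first structured-output request

    def _get_compiler(self):
        if self._compiler is None:
            # The first (and only) point at which xgrammar, and transitively
            # triton, gets imported.
            import xgrammar as xgr
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer, vocab_size=self.vocab_size)
            self._compiler = xgr.GrammarCompiler(tokenizer_info)
        return self._compiler

    def compile_json_schema(self, schema: str):
        return self._get_compiler().compile_json_schema(schema)
```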
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025

Labels

ready, v1

7 participants