
Conversation

russellb
Member

When testing with V1 structured output + Llama-3.1-8B-Instruct, the
changes made in vllm-project#14630 broke for me. I get the error:

```
ERROR 03-14 15:44:42 [core.py:337]   File "/home/rbryant/vllm/vllm/v1/structured_output/__init__.py", line 77, in _delayed_init
ERROR 03-14 15:44:42 [core.py:337]     tokenizer_info = xgr.TokenizerInfo.from_huggingface(
ERROR 03-14 15:44:42 [core.py:337]   File "/home/rbryant/vllm/venv/lib/python3.10/site-packages/xgrammar/tokenizer_info.py", line 184, in from_huggingface
ERROR 03-14 15:44:42 [core.py:337]     raise ValueError(msg)
ERROR 03-14 15:44:42 [core.py:337] ValueError: Input vocab_size less than minimum viable vocab size for tokenizer <class 'vllm.transformers_utils.tokenizer.get_cached_tokenizer.<locals>.CachedTokenizer'>.
ERROR 03-14 15:44:42 [core.py:337]
```

The vocab size was off by one: the max token ID is not equal to the vocab
size, since 0 is also a valid token ID. The correct vocab size is the max token ID + 1.

Signed-off-by: Russell Bryant <[email protected]>
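
A minimal sketch of the reasoning behind the fix (not the actual vLLM patch; the model name and call shape here are only illustrative):

```python
# Token IDs are 0-based, so the vocab size xgrammar needs is
# max_token_id + 1, not max_token_id.
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# get_vocab() maps token string -> token ID, including added special tokens.
max_token_id = max(tokenizer.get_vocab().values())
vocab_size = max_token_id + 1  # off-by-one fix: IDs start at 0

tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=vocab_size
)
```
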
@russellb requested a review from mgoin as a code owner March 14, 2025 15:52

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 label Mar 14, 2025
@russellb added this to the v0.8.0 milestone Mar 14, 2025
@russellb requested a review from njhill March 14, 2025 15:54
@DarkLight1337
Member

Is this currently covered by any test? If not, we can force-merge this.

@russellb
Member Author

> Is this currently covered by any test? If not, we can force-merge this.

it should be, but they're off right now -- #14619

@DarkLight1337
Member

OK let's just merge then

vllm-bot merged commit 1140991 into vllm-project:main Mar 14, 2025
9 of 11 checks passed
@comaniac
Collaborator

comaniac commented Mar 14, 2025

qq: Why do we need to use the max token ID to infer the vocab_size instead of just tokenizer.vocab_size, when even for CachedTokenizer we have get_vocab()?

@russellb
Member Author

> qq: Why do we need to use the max token ID to infer the vocab_size instead of just tokenizer.vocab_size, when even for CachedTokenizer we have get_vocab_size()?

See discussion on #14630 that led to that change
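
As a rough illustration of why the max token ID is used rather than tokenizer.vocab_size (assuming a Llama-3.1-style tokenizer; the exact numbers vary by model):

```python
# tokenizer.vocab_size can report only the base vocabulary and omit added
# special tokens, so it may undercount the token IDs the model can emit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

base_size = tok.vocab_size                     # base vocab only
full_size = max(tok.get_vocab().values()) + 1  # covers every real token ID

print(base_size, full_size)  # these can differ when special tokens are added
```
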

@comaniac
Collaborator

> > qq: Why do we need to use the max token ID to infer the vocab_size instead of just tokenizer.vocab_size, when even for CachedTokenizer we have get_vocab_size()?
>
> See discussion on #14630 that led to that change

I see. It would actually be better to add a comment in the code to reduce future confusion.

richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Mar 14, 2025
@russellb
Member Author

> > > qq: Why do we need to use the max token ID to infer the vocab_size instead of just tokenizer.vocab_size, when even for CachedTokenizer we have get_vocab_size()?
> >
> > See discussion on #14630 that led to that change
>
> I see. It would actually be better to add a comment in the code to reduce future confusion.

Agreed ... and this still wasn't right. I'm trying to get tests to pass over in #14832.

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025