
Conversation

Contributor

@JC1DA JC1DA commented Nov 11, 2024

This pull request extends vLLM's guided decoding capabilities by adding guidance as a new backend.

The guidance backend supports regex, choice, JSON, and grammar constraints.

Relevant: #5245

Usage

  • JSON Generation
from pydantic import BaseModel, ConfigDict

from vllm import LLM, SamplingParams
from vllm.entrypoints.chat_utils import CustomChatCompletionMessageParam
from vllm.sampling_params import GuidedDecodingParams

model = "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4"
llm = LLM(model=model)

class UserProfile(BaseModel):
    name: str
    age: int
    email: str

    model_config = ConfigDict(extra="forbid")

sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(
        json=UserProfile,
        backend="guidance",
    ),
)

outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="system", content="You are a helpful assistant."
            ),
            CustomChatCompletionMessageParam(
                role="user",
                content="Tell me something about yourself (name, age, email) in JSON format.\n",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)
  • Choices Generation
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(
        choice=["3","4","5","6"],
        backend="guidance",
    ),
)

outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="system", content="You are a 5 years-old helpful assistant."
            ),
            CustomChatCompletionMessageParam(
                role="user",
                content="How old are you?",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)
  • Regex Generation via OpenAI Client
model = "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4"
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="NOKEY",
)

completion = client.chat.completions.create(
    model=model,
    messages=[
        {
            "role": "user",
            "content": "You are a 5 years-old helpful assistant.",
        },
        {
            "role": "user",
            "content": """How old are you?""",
        },
    ],
    extra_body={"guided_regex": "\\d+", "guided_decoding_backend": "guidance"}
)
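
  • Grammar Generation (not in the original PR description; the sketch below is illustrative only, reuses the llm object from the JSON example, and assumes the guidance backend accepts a Lark-style EBNF grammar passed via GuidedDecodingParams(grammar=...))
# A made-up arithmetic grammar, purely for illustration.
arithmetic_grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "(" expression ")"
%import common.NUMBER
"""

sampling_params = SamplingParams(
    temperature=0.0,
    top_p=0.95,
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(
        grammar=arithmetic_grammar,
        backend="guidance",
    ),
)

outputs = llm.chat(
    messages=[
        [
            CustomChatCompletionMessageParam(
                role="user",
                content="Write an arithmetic expression that evaluates to 10.",
            ),
        ],
    ],
    sampling_params=[sampling_params],
)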


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, starting with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Nov 11, 2024
Member

@njhill njhill left a comment


Thanks @JC1DA for the great contribution!

A few other questions:

  • Presumably the parallelization speedup is due to the fact that the PyTorch ops involved release the GIL?
  • Were your outlines measurements also using the threadpool?
  • It would be good to also try the latest outlines 0.1.x if possible, which is apparently much faster than < 0.1. We would want to upgrade to that too in any case.

        tokenizer) -> Optional[LogitsProcessor]:
    # CFG grammar not supported by LMFE, so we use outlines instead
-   if guided_params.backend == 'outlines' or guided_params.grammar:
+   if guided_params.backend == 'outlines':
Member


LMFE doesn't support grammar; we should retain the existing behaviour of falling back to a different backend in this case (perhaps it could now be guidance rather than outlines).

Contributor Author


Should we add a check at the beginning to use outlines by default?

    if guided_params.grammar and guided_params.backend not in [
            'outlines', 'guidance'
    ]:
        guided_params.backend = 'outlines'

Comment on lines 170 to 172
mask = torch.tensor(mask,
                    dtype=logits.dtype,
                    device=logits.device)
Member


Can the allocated mask tensor be reused between calls?

Contributor Author

@JC1DA JC1DA Nov 14, 2024


@njhill I have updated the code to reuse the logits variable. Since we no longer add a thread pool in this PR, it should work well with in-place ops.
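
A minimal sketch of the in-place approach described above (illustrative only; apply_mask_inplace and allowed_token_ids are hypothetical names, not this PR's actual code):

import torch

def apply_mask_inplace(logits: torch.Tensor,
                       allowed_token_ids: list[int]) -> torch.Tensor:
    # Mark the tokens the grammar currently allows, then suppress the rest
    # by writing -inf into the existing logits tensor in place, instead of
    # allocating a fresh float mask tensor on every call.
    allowed = torch.zeros_like(logits, dtype=torch.bool)
    allowed[allowed_token_ids] = True
    logits.masked_fill_(~allowed, float("-inf"))
    return logits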

tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.6
outlines >= 0.0.43, < 0.1
guidance>=0.2rc
Member


What are the requirements of guidance? Does it have compiled binaries for specific Python versions or CPU architectures?
Maybe this could be an optional dependency to start with, like we do for many quantization backends.


Hey @mgoin, guidance does have a fair number of dependencies, but we're mostly depending on the lower-level guidance layer here, llguidance. llguidance is compiled for Python 3.9+ on manylinux/Mac OS/Windows. My understanding is that vLLM only supports Linux on Python 3.9+ too, so I think we should be good there.

We can change this PR in the near future to just use llguidance (which has no other dependencies: https://github.com/microsoft/llguidance/blob/b5ca97b2562b720c1ff3f567bfa45956338a1864/pyproject.toml#L8). We just need to port one last function down from the Python guidance library into the Rust layer first :).

Contributor Author


@mgoin we replaced guidance with llguidance, which has no extra dependencies. Hope it is good enough to merge :)

Contributor Author

JC1DA commented Nov 14, 2024

Thanks @njhill for your quick review. Really appreciate it.

  • Presumably the parallelization speedup is due to the fact that the PyTorch ops involved release the GIL?

That's one reason; another is that the parser (llguidance) used in guidance is implemented in Rust and automatically releases the GIL when called, so running guidance in parallel is more efficient (see the sketch at the end of this comment).

  • Were your outlines measurements also using the threadpool?

Yes, the experiments were done using the thread pool.

  • It would be good to also try the latest outlines 0.1.x if possible, which is apparently much faster than < 0.1. We would want to upgrade to that too in any case.

I haven't tested outlines 0.1.x yet, just the current version in vLLM. However, I am not focusing too much on benchmarks for this PR. The focus is to make guidance available as another guided decoding backend for vLLM's community so people can choose what's best for them. :)
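
As referenced in the first answer above, here is a rough sketch of the kind of parallel mask computation the GIL release enables (hypothetical function and variable names, not this PR's code; it assumes each per-sequence logits processor is called as processor(token_ids, logits), the convention vLLM logits processors follow):

from concurrent.futures import ThreadPoolExecutor

def run_processors_in_parallel(processors, token_id_lists, logits_rows):
    # processors: one guided-decoding logits processor per sequence
    # token_id_lists: the token ids generated so far for each sequence
    # logits_rows: one 1-D logits tensor per sequence
    # Because llguidance releases the GIL while computing its token mask,
    # these calls can overlap instead of serializing on Python bytecode.
    with ThreadPoolExecutor(max_workers=8) as pool:
        futures = [
            pool.submit(proc, ids, row)
            for proc, ids, row in zip(processors, token_id_lists, logits_rows)
        ]
        return [f.result() for f in futures]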

Contributor Author

JC1DA commented Nov 14, 2024

I also figured out that lm-format-enforcer is not thread-safe. It failed some tests when the number of threads is larger than 1.
@njhill any suggestions for this?

@Harsha-Nori Harsha-Nori mentioned this pull request Nov 15, 2024
Contributor Author

JC1DA commented Nov 25, 2024

I also figured out that lm-format-enforcer is not thread-safe. It failed some tests when the number of threads is larger than 1. @njhill any suggestions for this?

Decided to roll back to the single-threaded version so as not to break lm-format-enforcer. The PR now comes with minimal changes to add llguidance as a new logits processor.
Hope the current code is good for merging :) @njhill @mgoin

@JC1DA JC1DA closed this Nov 25, 2024
@JC1DA JC1DA reopened this Nov 25, 2024
@mergify mergify bot removed the needs-rebase label Dec 3, 2024
Contributor Author

JC1DA commented Dec 5, 2024

Resolved the conflict with the newly merged xgrammar.

@njhill @mgoin


mergify bot commented Dec 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JC1DA.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 17, 2024
@mergify mergify bot removed the needs-rebase label Jan 3, 2025
Contributor

mmoskal commented Jan 22, 2025

We have just released a large JSON Schema benchmark and a paper. Of particular interest might be the isolated mask-generation benchmarks, comparing LLGuidance, Outlines, XGrammar, and llama.cpp grammars.

[figure: mask-generation benchmark chart from the llguidance JSON Schema benchmark release]


mergify bot commented Feb 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JC1DA.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 12, 2025
russellb added a commit to russellb/vllm that referenced this pull request Mar 6, 2025
This commit is based on the PR vllm-project#10217. I started to rebase it, but found
it easier to just re-apply the changes on top of latest main.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Member

russellb commented Mar 6, 2025

Hello! I wanted to try this out, so I re-applied the changes on top of main, adjusting as necessary to get it to fit the current state of things.

https://github.com/vllm-project/vllm/compare/main...russellb:vllm:llguidance-v0-integration?expand=1

It's failing on _initialize() in the logits processor right now. Perhaps someone could take a look with me? #forum-structured-output on the vLLM Slack would be a good place to find me outside of GitHub if you'd like to chat.

$  pytest -sv tests/model_executor/test_guided_processors.py::test_guided_logits_processor_black_box[True-guidance]

...

tests/model_executor/test_guided_processors.py:93: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
vllm/model_executor/guided_decoding/guidance_logits_processors.py:116: in __call__
    self._initialize()
vllm/model_executor/guided_decoding/guidance_logits_processors.py:86: in _initialize
    TransformersTokenizer( \
vllm/model_executor/guided_decoding/guidance_utils.py:183: in __init__
    byte_tokens = self._byte_tokens(transformers_tokenizer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <vllm.model_executor.guided_decoding.guidance_utils.TransformersTokenizer object at 0x7f52a7c9a840>
transformers_tokenizer = LlamaTokenizerFast(name_or_path='HuggingFaceH4/zephyr-7b-beta', vocab_size=32000, model_max_length=1000000000000000019...ecial=True),
        2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

    def _byte_tokens(
        self,
        transformers_tokenizer: Union[
            "transformers_package.PreTrainedTokenizer",
            "transformers_package.PreTrainedTokenizerFast",
        ],
    ) -> list[bytes]:
    
        if hasattr(transformers_tokenizer, "byte_decoder"):
            try:
                self._check_byte_decoder(transformers_tokenizer.byte_decoder,
                                         transformers_tokenizer)
            except ByteDecoderError as e:
                error_log = f"Tokenizer has a byte_decoder, \
                    but it can't be used to construct byte_tokens: {e}"
    
                logging.warning(error_log)
                pass
            else:
                return self._byte_tokens_from_byte_decoder(
                    transformers_tokenizer.byte_decoder,
                    transformers_tokenizer)
    
        if hasattr(transformers_tokenizer, "sp_model"):
            return self._byte_tokens_from_sp_model(transformers_tokenizer)
    
        try:
            return self._byte_tokens_by_encoding_token_strings(
                transformers_tokenizer)
        except ValueError as e:
            error_log = f"Could not build byte tokens from the \
                            tokenizer by encoding token strings: {e}"
    
            logging.warning(error_log)
            pass
    
        fallback_byte_decoder = self._fallback_byte_decoder()
        try:
            self._check_byte_decoder(fallback_byte_decoder,
                                     transformers_tokenizer)
        except ByteDecoderError as e:
            # Should be the only exception that is raised in _byte_tokens
>           raise ByteTokensError(
                "Could not build byte tokens from the tokenizer, \
                    and falling back to a standard gpt2 byte_decoder failed"
            ) from e
E           vllm.model_executor.guided_decoding.guidance_utils.ByteTokensError: Could not build byte tokens from the tokenizer,                     and falling back to a standard gpt2 byte_decoder failed

vllm/model_executor/guided_decoding/guidance_utils.py:288: ByteTokensError
----------------------------------------------------------------- Captured log call ------------------------------------------------------------------
WARNING  root:guidance_utils.py:279 Could not build byte tokens from the                             tokenizer by encoding token strings: Round-trip encoding of tokens                                     [<0x00>] failed! Got [1, 28705, 3]
============================================================== short test summary info ===============================================================
FAILED tests/model_executor/test_guided_processors.py::test_guided_logits_processor_black_box[True-guidance] - vllm.model_executor.guided_decoding.guidance_utils.ByteTokensError: Could not build byte tokens from the tokenizer,                     and falli...

@Harsha-Nori

Hey @russellb! We've been tracking the discussion on #12388. Our plan is to re-do this PR once that gets merged. llguidance exposes a similar API to xgrammar, so it'll be quite a bit easier to just drop our code in at that point.

Happy to get started on it whenever you recommend. Thanks for the pointer on the slack, we'll join and chat there too :)

@lochuynh1412 @mmoskal

Member

russellb commented Mar 7, 2025

Hey @russellb! We've been tracking the discussion on #12388. Our plan is to re-do this PR once that gets merged. llguidance exposes a similar API to xgrammar, so it'll be quite a bit easier to just drop our code in at that point.

Happy to get started on it whenever you recommend. Thanks for the pointer on the slack, we'll join and chat there too :)

@lochuynh1412 @mmoskal

That sounds great. I want to get multiple backends going for the V1 engine after that PR merges.

I might also have a use case for this in the V0 engine for an existing user, which is what brought me over to this PR. I figured I might be able to help get this updated and working so I can test whether it works for them.

russellb added a commit to russellb/vllm that referenced this pull request Mar 11, 2025
This commit is based on the PR vllm-project#10217. It is updated to be compatible
with `main`.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Co-authored-by: Michal Moskal <[email protected]>
Collaborator

aarnphm commented Mar 16, 2025

Superseded by #14589

@aarnphm aarnphm closed this Mar 16, 2025
russellb added a commit to russellb/vllm that referenced this pull request Mar 18, 2025
This commit is based on the PR vllm-project#10217. It is updated to be compatible
with `main`.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Co-authored-by: Michal Moskal <[email protected]>
russellb added a commit to russellb/vllm that referenced this pull request Mar 19, 2025
This commit is based on the PR vllm-project#10217. It is updated to be compatible
with `main`. It has also been significantly updated and simplified to
take advantage of changes to llguidance since the original PR was
written.

Many thanks to the llguidance team for the help on this.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Co-authored-by: Michal Moskal <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Mar 21, 2025
This commit is based on the PR vllm-project#10217. It is updated to be compatible
with `main`. It has also been significantly updated and simplified to
take advantage of changes to llguidance since the original PR was
written.

Many thanks to the llguidance team for the help on this.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Co-authored-by: Michal Moskal <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>