
Conversation

@lsy323 (Collaborator) commented Mar 31, 2025

Make @support_torch_compile work for the XLA backend. With the custom dispatcher, the overhead of dynamo guard evaluation is eliminated.

For the TPU backend, each model has 2 FX graphs/dynamo bytecodes:

  1. During the profiling run - no KV cache
  2. After the profiling run - with KV cache

This breaks the assumption in the current @support_torch_compile implementation that each model has one FX graph/cached bytecode. Since the profiling graph won't be invoked after the profiling run, we clear the cached bytecode once profiling completes so that the assumption remains true.
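A minimal sketch of the bytecode-clearing idea, assuming the compiled backbone produced by @support_torch_compile lives at `model_runner.model.model` and keeps its cached bytecode in a `compiled_codes` list (both attribute names assumed here, not confirmed by this PR text):

```python
import torch


def reset_dynamo_cache(model_runner) -> None:
    """Sketch: drop dynamo state and the bytecode cached during profiling."""
    compiled_model = model_runner.model.model  # assumed attribute path
    # Clear dynamo's global caches and forget the bytecode compiled during the
    # profiling run (no KV cache), so the post-profiling compilation becomes
    # the single cached bytecode again.
    torch._dynamo.reset()
    compiled_model.compiled_codes.clear()  # attribute name assumed
```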

Other changes:

  • Remove ModelWrapperV1, which was used to wrap the model code and torch.compile the wrapped model. It's no longer needed since we reuse the @support_torch_compile decorator.
  • Since ModelWrapperV1 is removed, the sampler logic is moved to a separate function.

Credit to @WoosukKwon for the idea of clearing the bytecode cache, and to @youkaichao for all the help with torch dynamo-related questions!

cc @youkaichao @alexm-redhat @miladm @NickLucche @WoosukKwon @yaochengji @robertgshaw2-redhat

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the v1 and tpu (Related to Google TPUs) labels Mar 31, 2025
@lsy323 (Collaborator Author) commented Mar 31, 2025

Slightly improved throughput, 6.09 -> 6.14 req/s. The benchmarking commands are:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
 --disable-log-requests \
 --port 8004 \
 --gpu-memory-utilization 0.95 \
 --max-num-seqs 512 \
 --max-num-batched-tokens 512 \
 --tensor-parallel-size 1 \
 --max-model-len 2048

python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model meta-llama/Llama-3.1-8B-Instruct  \
    --dataset-name random \
    --random-input-len 1800 \
    --random-output-len 128 \
    --random-prefix-len 0 \
    --port 8004

self._hidden_states_dtype = out.dtype
self.model(input_ids=input_ids,
           positions=position_ids,
           inputs_embeds=inputs_embeds)
Member

Please keep the _hidden_states_dtype assignment

Collaborator

yep that's still needed

@lsy323 (Collaborator Author) Mar 31, 2025

IMO updating _hidden_states_dtype is not needed, since _hidden_states_dtype is already initialized with the model dtype. Also, it seems we don't need _hidden_states_dtype at all, since it should be the same as the model dtype.

Collaborator

I think this is out of scope, but yes, in principle you'd only need self.dtype. The dtype issue you linked has proven that nothing crashes at runtime should the output of an op not match self.dtype, though.
Hence, if the same bug were to reappear, we would only notice the server recompiling and we'd have to debug it again the way I did, painfully.

Collaborator Author

Thanks! Updated

Comment on lines 791 to 787
xm.wait_device_ops()
xm.wait_device_ops()
Member

Is there a reason to pull the device sync inside the loop? IIRC we pulled it out since it made parallel compilation slightly quicker.
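A small sketch of the pattern this comment describes, with hypothetical helper and variable names (`_dummy_run`, `warmup_sizes`): issue all dummy runs first, then do a single device sync so the per-size graphs can compile in parallel.

```python
import torch_xla.core.xla_model as xm


def warm_up(model_runner, warmup_sizes) -> None:
    # Hypothetical names; only the placement of the sync matters here.
    for num_tokens in warmup_sizes:
        model_runner._dummy_run(num_tokens)  # queue compilation for each size
    # One device sync after the loop (not inside it), so the per-size graph
    # compilations can proceed in parallel.
    xm.wait_device_ops()
```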

Collaborator Author

Thanks! Updated.

Comment on lines -944 to -984
def get_multimodal_embeddings(self, *args, **kwargs):
    return self.model.get_multimodal_embeddings(*args, **kwargs)

def get_input_embeddings(self, *args, **kwargs):
    return self.model.get_input_embeddings(*args, **kwargs)
Member

Have you tested that multimodal inference still works and these are called correctly?

Collaborator Author

I only ran the TPU CI to test the change; multimodal is not tested. Can you provide a script to test multimodal? I can test it on this PR.

Collaborator

VLLM_USE_V1=1 vllm serve llava-hf/llava-1.5-7b-hf --max-model-len 4096 --max-num-seqs 8 --max-num-batched-tokens 512 --chat-template examples/template_llava.jinja

then python examples/online_serving/openai_chat_completion_client_for_multimodal.py

Collaborator

I am adding tests in another PR btw

Collaborator Author

sg I'll rebase to your PR after :)

Collaborator Author

Hi @NickLucche, thank you for providing the testing cmd for multimodal. I made this PR work for llava-hf/llava-1.5-7b-hf; the server cmd runs fine.

However, the client script python examples/online_serving/openai_chat_completion_client_for_multimodal.py fails at HEAD.

Collaborator

Main is working for me, let me try your PR

Collaborator

Still really slow but it works the same on this PR on my side, thanks!

Collaborator Author

@NickLucche Thank you so much for trying! I set up a new conda env on my end and tried as well. I also found it's really slow, was about to ask you lol.

Collaborator
@NickLucche left a comment

Thanks for the work!
Left a comment about MM, I think that's about the only thing to clarify on my side.


    return hidden_states

def reset_dynamo_cache(self):
    # TODO(lsy323): Support multimodal models, the backbone language model
Collaborator

do you mean

Suggested change
# TODO(lsy323): Support multimodal models, the backbone language model
compiled_model = self.model.language_model if self.is_multimodal_model else self.model.model

Collaborator Author

I haven't tested multimodal yet; I plan to do it in another PR. Right now I've only run the tests in the TPU CI.

Collaborator

> I haven't tested multimodal yet, plan to do it in another PR

I think it's best if we double-check here, otherwise we may inadvertently break MM.

Collaborator Author

I'll rebase onto the MM PR after it's merged to ensure this PR doesn't break MM.


class ModelWrapperV1(nn.Module):

    def __init__(self, model: nn.Module):
Collaborator

I didn't mind the wrapper too much; I think it was grouping a few related functions nicely. Still, if it benefits performance I am OK with that.

Collaborator Author

The wrapper was added in V0 for torch.compile. Now the @support_torch_compile decorator already wraps the model with torch.compile, so we don't need the wrapper anymore.
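To illustrate the idea (a conceptual sketch only, not vLLM's actual decorator): a class decorator can compile forward at construction time, which is why a separate wrapper module is no longer required.

```python
import torch
import torch.nn as nn


def compile_forward(cls):
    """Conceptual stand-in for @support_torch_compile (names invented):
    wrap forward with torch.compile when the module is constructed."""
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # On TPU the backend would be "openxla"; the default backend is used
        # here so the sketch runs anywhere.
        self.forward = torch.compile(self.forward)

    cls.__init__ = __init__
    return cls


@compile_forward
class TinyModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))
```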

sample_hidden_states = \
    hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
logits = self.model.compute_logits(sample_hidden_states, None)
Collaborator

Can we re-add the pruning comment here? Just in case it slips through in the future.

Collaborator Author

Sorry, may I ask what the 'pruning comment' is?

Collaborator

# SamplingMetadata here for pruning output in LogitsProcessor, disabled
or something along these lines to indicate why the 2nd argument is None. That's because passing SamplingMetadata could enable logits pruning, which is bad for XLA.
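A short sketch of where such a comment could sit; the surrounding lines are taken from the excerpt above, and only the comment is new:

```python
sample_hidden_states = \
    hidden_states[sampling_metadata.indices_do_sample]
# Pass None instead of SamplingMetadata: with metadata, LogitsProcessor may
# prune the logits, which introduces dynamic shapes and hurts XLA.
logits = self.model.compute_logits(sample_hidden_states, None)
```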

Collaborator Author

Thanks! Updated


def reset_dynamo_cache(self):
    # TODO(lsy323): Support multimodal models, the backbone language model
    # is stored in a different member.
    compiled_model = self.model.model
Collaborator

I wonder why we need to do model.model. Could you add a comment?

Collaborator

Is it possible self.model doesn't have an attr model? E.g. if it's not annotated by support_torch_compile?

Collaborator Author

Yes, model.model is the torch-compiled module.

Comment on lines 874 to 877
if self.is_multimodal_model:
    compiled_model = self.model.language_model.model
else:
    compiled_model = self.model.model
Member

AFAIK "language_model" is not a stable attribute to reference; it is based on the HF model definition. Maybe @ywang96 @DarkLight1337 would know a stable interface to access the language model backbone?

Collaborator Author

I think a potential solution is to add a query function get_backbone_lm for multimodal models, similar to the existing get_input_embeddings

Member

Yeah, we can add get_language_model to the SupportsMultiModal interface.
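A rough sketch of what such an interface addition could look like; the method name comes from the comment above, while the protocol body here is abbreviated and assumed rather than vLLM's actual definition:

```python
from typing import Protocol, runtime_checkable

import torch.nn as nn


@runtime_checkable
class SupportsMultiModal(Protocol):
    """Abbreviated sketch; only the proposed method is shown."""

    def get_language_model(self) -> nn.Module:
        """Return the backbone language model module, so callers don't have
        to reach into model-specific attributes like `language_model`."""
        ...
```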

Collaborator

Something like #16007?

Member

Please leave an assert here for now as we will address later @lsy323
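A small sketch of the requested guard, built on the excerpt above; the assert message is an assumption:

```python
if self.is_multimodal_model:
    # Guard against models whose backbone is not exposed as `language_model`;
    # to be replaced once a stable get_language_model() interface exists.
    assert hasattr(self.model, "language_model"), (
        "expected the multimodal model to expose a `language_model` backbone")
    compiled_model = self.model.language_model.model
else:
    compiled_model = self.model.model
```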

Collaborator
@yaochengji left a comment

LGTM, thanks!

Member
@mgoin left a comment

For now I think we should aim to land this just working with language_model since there are other blockers/developments there. Afterwards we can migrate to use the get_language_model interface proposed by Nicolo.


@mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 8, 2025
@DarkLight1337 merged commit 87918e4 into vllm-project:main Apr 8, 2025
56 checks passed
@lsy323 deleted the lsiyuan/try-disable-dynamo-guard-3 branch April 8, 2025 18:05
@lsy323 restored the lsiyuan/try-disable-dynamo-guard-3 branch April 8, 2025 18:06
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025