
Conversation

ywang96 (Member) commented Mar 9, 2025

Reopened from reverted #13776

Co-authored by @varun-sundar-rabindranath for the LoRA dummy-run fix.

github-actions bot commented Mar 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot commented Mar 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ywang96.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot removed the needs-rebase label Mar 9, 2025
Signed-off-by: Roger Wang <[email protected]>
ywang96 (Member, Author) commented Mar 9, 2025

@varun-sundar-rabindranath @jeejeelee Please help take a look at why this is breaking the LoRA tests on V1 - thank you very much! 🙏

varun-sundar-rabindranath (Contributor)

I have a PR #14514 that should fix this issue 🤞 @ywang96 please cherry-pick the latest commit to this branch. Thanks 🙏🏻

ywang96 and others added 2 commits March 9, 2025 17:57
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
ywang96 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Mar 10, 2025
ywang96 (Member, Author) commented Mar 10, 2025

I have a PR #14514 that should fix this issue 🤞 @ywang96 please cherry-pick the latest commit to this branch. Thanks 🙏🏻

Tested with some of your changes locally and now the LoRA test indeed passes. Thanks for the help!

ywang96 added 3 commits March 10, 2025 08:30
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Comment on lines +146 to +147
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be released from PyTorch due to a known bug,
Collaborator

Could you please elaborate on this?

ywang96 (Member, Author) Mar 10, 2025

See the discussion here https://vllm-dev.slack.com/archives/C087WBWC5AQ/p1741398800083509?thread_ts=1741386694.452939&cid=C087WBWC5AQ - TLDR is that empty_cache cannot be called when we turn on sleep mode.

Collaborator

Hmm... Why do we need empty_cache?

ywang96 (Member, Author)

The difference here is that we never warmed up the sampler (in both V0 and V1), so the memory fragmentation issue has always been there - it was just less pronounced in V0 (since the default batch size is 256).

Now we're adding the sampler warmup in V1, but when we call sleep(), the memory buffer for logits can't be cleared from the PyTorch caching allocator (the bug mentioned in this comment), so memory usage will be a lot higher.
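To make the discussion concrete, here is a minimal sketch of what such a sampler warmup amounts to. The function name, dtype, and the argmax stand-in are illustrative assumptions, not vLLM's actual V1 implementation; only the (max_num_reqs x vocab_size) buffer shape comes from this thread.

```python
import torch

def dummy_sampler_warmup(max_num_reqs: int, vocab_size: int,
                         device: str = "cuda") -> None:
    # Allocate the worst-case logits buffer (max_num_reqs x vocab_size) so the
    # PyTorch caching allocator reserves a block large enough for real decode
    # steps, instead of fragmenting memory when the first big batch arrives.
    logits = torch.empty(max_num_reqs, vocab_size,
                         dtype=torch.float32, device=device)
    # Touch the buffer with a representative sampling op (greedy argmax as a
    # stand-in) so the allocation pattern resembles a real sampler pass.
    _ = logits.argmax(dim=-1)
    # Deliberately no torch.cuda.empty_cache() afterwards: keeping the cached
    # block resident is the point, and (per the thread) it cannot be released
    # anyway once sleep mode is enabled.
```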

Collaborator

@ywang96 Thanks for the explanation. Just want to double check: we don't want to call empty_cache anyway, because we intentionally reserve the (max_num_reqs x vocab_size)-sized tensor in the PyTorch allocator, right?

ywang96 (Member, Author) Mar 10, 2025

That is correct, though I do think there should be a better and cleaner fix for this to work with sleep mode in the long term. We should probably free the memory when sleep() is called, then warm up the sampler again within wakeup(), but this is currently blocked since we can't free the memory anyway.
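A rough sketch of that longer-term idea is shown below. The class and method names are hypothetical, not vLLM's API, and it assumes the PyTorch caching-allocator bug is fixed so the block can actually be released on sleep().

```python
from typing import Optional

import torch

class SamplerWarmupManager:
    # Hypothetical sketch: release the logits buffer on sleep() and re-warm
    # the sampler on wakeup(). Assumes the allocator bug is fixed so the
    # cached block can actually be given back to the driver.

    def __init__(self, max_num_reqs: int, vocab_size: int,
                 device: str = "cuda") -> None:
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.device = device
        self.logits: Optional[torch.Tensor] = None

    def warmup(self) -> None:
        # Reserve the worst-case logits buffer up front.
        self.logits = torch.empty(self.max_num_reqs, self.vocab_size,
                                  device=self.device)

    def sleep(self) -> None:
        # Drop our reference and return the cached block to the driver.
        self.logits = None
        torch.cuda.empty_cache()

    def wakeup(self) -> None:
        # Re-warm so the buffer is reserved again before serving resumes.
        self.warmup()
```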

Collaborator

Hmm... How is the logits tensor different from other intermediate activation tensors?

Collaborator

I don't understand why this specific tensor becomes a problem.

ywang96 (Member, Author)

Because dummy_run doesn't include/activate the sampler tensors - that's why we made dummy_sampler_run in the first place.
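In other words, the profiling path needs two dummy passes, as sketched below. The runner object and the exact signatures are simplified assumptions; only the dummy_run / dummy_sampler_run split comes from this thread.

```python
import torch

def profile_run(runner, max_num_tokens: int) -> None:
    # The plain dummy forward pass exercises the model weights and
    # activations, but it never allocates the sampler's logits buffer...
    hidden_states = runner.dummy_run(max_num_tokens)
    # ...so a separate dummy sampler pass is needed for memory profiling to
    # see (and reserve) the (max_num_reqs x vocab_size) logits allocation.
    runner.dummy_sampler_run(hidden_states)
    torch.cuda.synchronize()
```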

ywang96 enabled auto-merge (squash) March 11, 2025 02:16
ywang96 merged commit 1fc973c into vllm-project:main Mar 11, 2025
35 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
Labels
ready (ONLY add when PR is ready to merge/full CI is needed), v1