
Conversation

DarkLight1337
Member

@DarkLight1337 DarkLight1337 commented Mar 23, 2025

  • Specifically ask users to set --max-num-seqs to avoid OOM in V1.
  • Mention additional options (VLLM_MM_INPUT_CACHE_GIB and VLLM_CPU_KVCACHE_SPACE) to reduce CPU memory consumption (see the sketch after this list).
  • Reduce the VLLM_MM_INPUT_CACHE_GIB default to 4 (previously 8), as users with 32 GB RAM may otherwise run out of memory.
  • Misc: Update the Engine Arguments page to point back to the offline inference and online serving pages for easy reference.
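
A minimal sketch of how these settings might be applied together (the env-var names are the ones documented in this PR; the model name and values are placeholders, not recommendations):

    import os

    # Set environment variables before the engine starts so vLLM picks them up.
    os.environ["VLLM_MM_INPUT_CACHE_GIB"] = "4"  # cap the multimodal input cache (GiB)
    os.environ["VLLM_CPU_KVCACHE_SPACE"] = "4"   # CPU KV-cache space in GiB (CPU backend)

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2-VL-2B-Instruct",  # placeholder multimodal model
        max_num_seqs=256,                   # lower than the V1 default of 1024 to reduce peak memory
    )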

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 23, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 23, 2025
@DarkLight1337
Member Author

DarkLight1337 commented Mar 23, 2025

cc @robertgshaw2-redhat should we reduce the default max_num_seqs for V1, or adjust it automatically somehow? There have been numerous reports of OOM from people using lower-end GPUs like RTX3060.

@robertgshaw2-redhat
Collaborator

cc @robertgshaw2-redhat should we reduce the default max_num_seqs for V1, or adjust it automatically somehow? There have been numerous reports of OOM from people using lower-end GPUs like RTX3060.

Why does --max-num-seqs result in OOM?

@ywang96
Member

ywang96 commented Mar 24, 2025

cc @robertgshaw2-redhat should we reduce the default max_num_seqs for V1, or adjust it automatically somehow? There have been numerous reports of OOM from people using lower-end GPUs like RTX3060.

Why does --max-num-seqs result in OOM?

@robertgshaw2-redhat This is mostly related to two changes we made in V1:

  1. max-num-seqs was raised from 256 to 1024 by default.
  2. On V1, two PRs were added late in development ([Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler #13594, [V1][Core] Fix memory issue with logits & sampling #14508) to add a sampler dummy run to profile_run and compile_or_warm_up_model (the latter mitigates a memory fragmentation issue and was previously missing on V0). This means there will always be a dummy sampler run with the maximum number of possible decoding sequences (which is very often --max-num-seqs), and this can sometimes result in OOM because of the new default. We also added an explicit error about it here, but users can sometimes miss it.
    raise RuntimeError(
        "CUDA out of memory occurred when warming up sampler with "
        f"{num_reqs} dummy requests. Please try lowering "
        "`max_num_seqs` or `gpu_memory_utilization` when "
        "initializing the engine.") from e

@robertgshaw2-redhat
Collaborator

cc @robertgshaw2-redhat should we reduce the default max_num_seqs for V1, or adjust it automatically somehow? There have been numerous reports of OOM from people using lower-end GPUs like RTX3060.

Why does --max-num-seqs result in OOM?

@robertgshaw2-redhat This is mostly related to two changes we made in V1:

  1. max-num-seqs was raised from 256 to 1024 by default.
  2. On V1, two PRs were added late in development ([Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler #13594, [V1][Core] Fix memory issue with logits & sampling #14508) to add a sampler dummy run to profile_run and compile_or_warm_up_model (the latter mitigates a memory fragmentation issue and was previously missing on V0). This means there will always be a dummy sampler run with the maximum number of possible decoding sequences (which is very often --max-num-seqs), and this can sometimes result in OOM because of the new default. We also added an explicit error about it here, but users can sometimes miss it.
    raise RuntimeError(
        "CUDA out of memory occurred when warming up sampler with "
        f"{num_reqs} dummy requests. Please try lowering "
        "`max_num_seqs` or `gpu_memory_utilization` when "
        "initializing the engine.") from e

Okay, we can set --max-num-seqs 1024 just for H100 and A100 then. WDYT @WoosukKwon
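
For readers who hit the sampler warm-up error quoted above, a minimal sketch of the mitigation it suggests; the model name and values are placeholders chosen for illustration:

    from vllm import LLM

    # Lowering max_num_seqs shrinks the dummy sampler run performed during warm-up;
    # lowering gpu_memory_utilization leaves more headroom outside vLLM's allocation.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        max_num_seqs=256,             # the old V0 default, well below the V1 default of 1024
        gpu_memory_utilization=0.85,  # slightly below the 0.9 default
    )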

Collaborator

@houseroad houseroad left a comment


I feel it's safer to leave the default unchanged if it's already used in production.

Comment on lines 101 to 102
The default `max_num_seqs` has been raised from `256` in V0 to `1024` in V1.
If you encounter OOM only when using V1 engine, try setting a lower value of `max_num_seqs`.
Member


Suggested change
The default `max_num_seqs` has been raised from `256` in V0 to `1024` in V1.
If you encounter OOM only when using V1 engine, try setting a lower value of `max_num_seqs`.
The default `max_num_seqs` has been raised from `256` in V0 to `1024` in V1.
If you encounter CUDA OOM only when using V1 engine, try setting a lower value of `max_num_seqs` or `gpu_memory_utilization`.

Member Author


I am worried this might cause some confusion, since lowering gpu_memory_utilization may lead to a related error: "No available memory for the cache blocks".

Member


Let’s still clarify that this is related to CUDA OOM.

Member Author


Updated in 4c27e09

@ywang96
Member

ywang96 commented Mar 24, 2025

@simon-mo What are your thoughts on this one?

Comment on lines 100 to 104
:::{important}
The default `max_num_seqs` has been raised from `256` in V0 to `1024` in V1.
If you encounter CUDA OOM only when using V1 engine, try setting a lower value of `max_num_seqs` or `gpu_memory_utilization`.
On the other hand, if you get an error about insufficient memory for the cache blocks, you should increase `gpu_memory_utilization` as this indicates that your GPU has sufficient memory but you're not allocating enough of it to vLLM.
:::
Member


Now that I think about it, maybe we should move this to the V1 User Guide page?

Member


I moved this under FAQ for V1
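
To make the distinction in the note above concrete, an illustrative sketch (placeholder model; the values are arbitrary examples):

    from vllm import LLM

    # CUDA OOM at start-up (e.g. the sampler warm-up error): lower max_num_seqs
    # and/or gpu_memory_utilization.
    llm = LLM(model="facebook/opt-1.3b", max_num_seqs=256, gpu_memory_utilization=0.8)

    # "No available memory for the cache blocks": the GPU has spare memory but too
    # little of it is allocated to vLLM, so raise gpu_memory_utilization instead, e.g.:
    # llm = LLM(model="facebook/opt-1.3b", gpu_memory_utilization=0.95)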

@simon-mo simon-mo enabled auto-merge (squash) March 24, 2025 16:58
@simon-mo simon-mo modified the milestones: v0.8.0, v0.8.2 Mar 24, 2025
@simon-mo simon-mo disabled auto-merge March 24, 2025 21:29
@simon-mo simon-mo merged commit 6dd55af into vllm-project:main Mar 24, 2025
26 of 33 checks passed
@DarkLight1337 DarkLight1337 deleted the docs-oom branch March 25, 2025 03:58
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Wes Medford <[email protected]>
@GohioAC

GohioAC commented Mar 27, 2025

The OOM issue is not limited to the dummy sampler run with the maximum number of possible decoding sequences.

I'm running some 1B multimodal models with max-num-seqs set to 32, and the CPU RAM usage increases after every batch. I have tried deleting previous batch inputs and garbage collecting, to no avail. Even setting disable_mm_preprocessor_cache to True does not help.

Let me know if I should create a dedicated issue with more details. This issue is specific to the V1 engine.
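
For context, a hedged sketch of the configuration described above (the model name is a placeholder, and disable_mm_preprocessor_cache is assumed to be accepted as an engine argument in the version in use):

    from vllm import LLM

    llm = LLM(
        model="HuggingFaceTB/SmolVLM-Instruct",  # placeholder small multimodal model
        max_num_seqs=32,
        disable_mm_preprocessor_cache=True,      # bypass the multimodal preprocessor cache
    )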

@DarkLight1337
Member Author

Which version of vLLM are you using? Both v0.8.1 and v0.8.2 fixed some memory leaks.
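
A quick way to check the installed version (only the package's standard __version__ attribute is assumed):

    import vllm
    print(vllm.__version__)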

@GohioAC

GohioAC commented Mar 27, 2025

Just saw #15294. Looks like the same issue.
My bad for not searching thoroughly. I'll give v0.8.2 a whirl.

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
