
Conversation

ywang96
Member

@ywang96 ywang96 commented Feb 24, 2025

Reopened from #13721, which was accidentally merged.

This PR is a follow-up to #13594 (comment), which describes the memory issue during online serving that persists even after sampler profiling was added to profile_run. After some investigation, the root cause turns out to be memory fragmentation from logits and other sampling-related tensors, since we do not preallocate buffers for them beforehand.

This memory issue can be reproduced by changing the temperature in the file below to a non-zero value (e.g., 1.0) to trigger the logits sampling code path.

"temperature": 0.0,

Server command:
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B --disable-log-requests --no-enable-prefix-caching

Client command:

python3 benchmarks/benchmark_serving.py \
        --model meta-llama/Llama-3.1-8B \
        --dataset-name sharegpt \
        --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
        --ignore-eos \
        --num-prompts 7200 \
        --request-rate 60

The graphs below track the memory usage of the server process on 1xH100. The timestamp starts when the server is ready to receive traffic, and the server starts receiving traffic at around T=25.

On main, since logits.shape[0] grows incrementally due to the nature of online traffic patterns, memory usage keeps rising as a result of fragmentation from the intermediate tensors derived from logits in Sampler. This issue does not crash the server, since PyTorch itself garbage-collects these cached buffers to prevent OOM (as seen in the dips in the graph), but it should be fixed and handled by vLLM for a few obvious reasons (e.g., releasing this memory requires a synchronization).
[graph: gpu0_memory_usage]
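
To make the failure mode concrete, here is a minimal standalone sketch (not vLLM code; the vocabulary size and batch-size ramp are assumed) of the allocation pattern: intermediate tensors whose first dimension grows with the number of in-flight requests keep asking the caching allocator for new, larger blocks, and those blocks are only returned by an explicit, synchronizing torch.cuda.empty_cache().

import torch

VOCAB_SIZE = 128 * 1024  # assumed; roughly Llama-3.1's vocabulary size

def fake_sampling_step(num_reqs: int) -> None:
    # Stand-ins for logits and the intermediate tensors the sampler
    # derives from them; the first dimension varies with the number of
    # in-flight requests, so each new size requests fresh blocks from
    # the caching allocator.
    logits = torch.empty(num_reqs, VOCAB_SIZE, device="cuda",
                         dtype=torch.bfloat16)
    probs = torch.softmax(logits.float(), dim=-1)  # transient fp32 copy
    del logits, probs

for num_reqs in range(32, 1025, 32):  # online traffic ramps up gradually
    fake_sampling_step(num_reqs)
    reserved_mib = torch.cuda.memory_reserved() / 2**20
    print(f"num_reqs={num_reqs:4d} reserved={reserved_mib:8.1f} MiB")

# Returning the cached blocks to the driver requires a synchronization,
# which is why vLLM should not rely on PyTorch doing this for us.
torch.cuda.empty_cache()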

The root cause of this issue is that the sampler was not included in compile_or_warm_up_model, while capture_model implicitly calls torch.cuda.empty_cache(). Therefore, even though the sampler's memory usage was captured in profile_run, its memory buffers were cleared by that method.

This PR addresses the issue by adding _dummy_sampler_run and calling it after the model forward pass itself has been warmed up and captured. We do not fold both into _dummy_run, since that method is needed elsewhere for other purposes.
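
For illustration, here is a runnable toy sketch of the ordering this PR establishes; capture_model and dummy_sampler_run below are simplified stand-ins for the real worker methods, and the sizes are assumed:

import torch

VOCAB_SIZE = 128 * 1024   # assumed vocabulary size
MAX_NUM_REQS = 1024       # V1 default max-num-seqs

def capture_model() -> None:
    # Stand-in for CUDA graph capture, which implicitly frees all
    # cached allocator blocks, just like the real capture_model.
    torch.cuda.empty_cache()

def dummy_sampler_run() -> None:
    # Allocate the sampler's tensors once at their worst-case shapes so
    # the caching allocator retains blocks large enough for any batch.
    logits = torch.empty(MAX_NUM_REQS, VOCAB_SIZE,
                         device="cuda", dtype=torch.bfloat16)
    torch.softmax(logits.float(), dim=-1)

# Pre-PR ordering: sampler buffers allocated during profiling were wiped
# by the empty_cache() inside capture_model(). Post-PR ordering (below):
# capture first, then warm up the sampler, so its buffers survive.
capture_model()
dummy_sampler_run()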
[graph: gpu0_memory_usage (1)]
With this PR, memory usage is rather stable. The initial server memory usage increases from ~68K MiB to ~73K MiB (this accounts for all the sampling-related buffers that were included during profiling but previously cleared by warmup), and it stays stable during actual inference. We do observe a very small bump when the server starts receiving traffic, but IMO it is small enough to leave for later investigation.

In addition, with the default gpu_memory_utilization=0.9, one might expect the initial server footprint to be roughly 81559 * 0.9 ≈ 73403 MiB plus whatever the CUDA graphs require (although, technically speaking, this is not how it actually works), so the memory usage with this fix (~73714 MiB) should be acceptable. Since PyTorch no longer needs to garbage-collect the cached buffers, we also observe a tiny performance improvement.

Main
============ Serving Benchmark Result ============
Successful requests:                     7200      
Benchmark duration (s):                  138.31    
Total input tokens:                      1582725   
Total generated tokens:                  1442778   
Request throughput (req/s):              52.06     
Output token throughput (tok/s):         10431.65  
Total Token throughput (tok/s):          21875.15  
---------------Time to First Token----------------
Mean TTFT (ms):                          132.21    
Median TTFT (ms):                        111.04    
P99 TTFT (ms):                           392.95    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.49     
Median TPOT (ms):                        33.55     
P99 TPOT (ms):                           58.06     
---------------Inter-token Latency----------------
Mean ITL (ms):                           34.90     
Median ITL (ms):                         33.16     
P99 ITL (ms):                            94.75     
==================================================

This PR
============ Serving Benchmark Result ============
Successful requests:                     7200      
Benchmark duration (s):                  137.63    
Total input tokens:                      1582725   
Total generated tokens:                  1442778   
Request throughput (req/s):              52.32     
Output token throughput (tok/s):         10483.35  
Total Token throughput (tok/s):          21983.57  
---------------Time to First Token----------------
Mean TTFT (ms):                          111.15    
Median TTFT (ms):                        101.92    
P99 TTFT (ms):                           283.02    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.12     
Median TPOT (ms):                        33.42     
P99 TPOT (ms):                           57.61     
---------------Inter-token Latency----------------
Mean ITL (ms):                           34.56     
Median ITL (ms):                         32.93     
P99 ITL (ms):                            87.28     
==================================================

An alternative fix would be to keep a persistent buffer for the logits, but that may run into some practical issues, so we will leave it for future investigation as well.
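
For reference, a minimal sketch of what that persistent-buffer alternative could look like (names and sizes here are illustrative, not vLLM's actual API; the view semantics hint at the kind of practical issues alluded to above):

import torch

class PersistentLogitsBuffer:
    """Preallocate logits storage once and hand out per-step views."""

    def __init__(self, max_num_reqs: int, vocab_size: int,
                 dtype: torch.dtype = torch.bfloat16) -> None:
        # One worst-case allocation up front, so the caching allocator
        # never sees varying logits sizes at all.
        self._buf = torch.empty(max_num_reqs, vocab_size,
                                device="cuda", dtype=dtype)

    def get(self, num_reqs: int) -> torch.Tensor:
        # A view into the persistent storage: no new allocation, but
        # every consumer must now tolerate operating on a slice of a
        # larger tensor (e.g., in-place ops touch the shared storage).
        return self._buf[:num_reqs]

buf = PersistentLogitsBuffer(max_num_reqs=1024, vocab_size=128 * 1024)
logits = buf.get(37)  # 37 in-flight requests this step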


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Feb 24, 2025
@youkaichao
Member

Is it compatible with sleep mode now?

@WoosukKwon
Collaborator

Any updates?

@JaheimLee

Is it related to this issue?

@ywang96
Member Author

ywang96 commented Feb 26, 2025

@WoosukKwon @youkaichao Sorry, but I haven't had a chance to work on this (I came down with the flu over the weekend); I'll try to investigate more by the end of Friday.


mergify bot commented Mar 1, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @ywang96.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 1, 2025
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 4, 2025
@mergify mergify bot removed the needs-rebase label Mar 4, 2025
@ywang96
Member Author

ywang96 commented Mar 8, 2025

Note: we also have this issue on V0, but it is less pronounced there because the default max-num-seqs is 256 (instead of 1024 on V1).

@ywang96 ywang96 marked this pull request as ready for review March 8, 2025 08:49
@ywang96
Member Author

ywang96 commented Mar 8, 2025

Discussed with @youkaichao offline: for now we will "bypass" the cumem tests for V1 and properly fix sleep mode for V1 later.

Signed-off-by: Roger Wang <[email protected]>
@ywang96 ywang96 added ready ONLY add when PR is ready to merge/full CI is needed and removed ready ONLY add when PR is ready to merge/full CI is needed labels Mar 8, 2025
Member

@youkaichao youkaichao left a comment


LGTM, will try to fix sleep mode compatibility later.

# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
self.model_runner._dummy_sampler_run(
Member


Since this memory is outside of gpu_memory_utilization, users might get OOM here, and they will then try to increase gpu_memory_utilization.

Can we add a try-except here so that, when OOM occurs, we tell users to set a smaller gpu_memory_utilization?

Member Author


That's a good point - will add!

Member Author


added in 1e017e4
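
For context, a guard along these lines might look like the following (a hedged sketch with placeholder names; the actual change in 1e017e4 may differ):

import torch

def warm_up_sampler(model_runner, hidden_states) -> None:
    # Sketch of the suggested guard; `model_runner` and `hidden_states`
    # are placeholders, and the real wording/structure may differ.
    try:
        model_runner._dummy_sampler_run(hidden_states)
    except torch.cuda.OutOfMemoryError as e:
        # This warmup allocates on top of what gpu_memory_utilization
        # reserves, so the actionable advice is to lower that value.
        raise RuntimeError(
            "CUDA out of memory during sampler warmup. Consider setting "
            "a smaller gpu_memory_utilization.") from e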

Signed-off-by: Roger Wang <[email protected]>
@vllm-bot vllm-bot merged commit 8d5aa46 into vllm-project:main Mar 8, 2025
30 of 31 checks passed
robertgshaw2-redhat added a commit that referenced this pull request Mar 8, 2025
robertgshaw2-redhat pushed a commit that referenced this pull request Mar 8, 2025

mergify bot commented Mar 9, 2025

⚠️ The sha of the head commit of this PR conflicts with #14508. Mergify cannot evaluate rules on this PR. ⚠️

Alexei-V-Ivanov-AMD added a commit to ROCm/vllm that referenced this pull request Mar 11, 2025
* Fix `head_dim` not existing in all model configs (Transformers backend) (vllm-project#14141)

Signed-off-by: Harry Mellor <[email protected]>

* [V0][Metrics] Remove unimplemented `vllm:tokens_total` (vllm-project#14134)

Signed-off-by: Mark McLoughlin <[email protected]>

* [V0][Metrics] Deprecate some KV/prefix cache metrics (vllm-project#14136)

Signed-off-by: Mark McLoughlin <[email protected]>

* [V1] Simplify stats logging (vllm-project#14082)

Signed-off-by: Nick Hill <[email protected]>

* [WIP][V1][Metrics] Implement max_num_generation_tokens, request_params_n, and request_params_max_tokens metrics (vllm-project#14055)

Signed-off-by: Mark McLoughlin <[email protected]>

* [Bugfix] Allow shared_experts skip quantization for DeepSeekV2/V3 (vllm-project#14100)

Signed-off-by: mgoin <[email protected]>

* [Kernel] Optimize moe intermediate_cache usage (vllm-project#13625)

Signed-off-by: mgoin <[email protected]>

* [Docs] Add GPTQModel (vllm-project#14056)

Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>

* [v1] Add comments to the new ragged paged attention Pallas kernel (vllm-project#14155)

Signed-off-by: Xiongfei Wei <[email protected]>
Co-authored-by: Michael Goin <[email protected]>

* [Model] Add support for GraniteMoeShared models (vllm-project#13313)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [core] moe fp8 block quant tuning support (vllm-project#14068)

Signed-off-by: Divakar Verma <[email protected]>

* [Misc] Remove lru_cache in NvmlCudaPlatform (vllm-project#14156)

Signed-off-by: Cody Yu <[email protected]>

* [core] Pass all driver env vars to ray workers unless excluded (vllm-project#14099)

Signed-off-by: Rui Qiao <[email protected]>

* Use math.prod instead of np.prod for trivial ops (vllm-project#14142)

* Fix benchmark_moe.py tuning for CUDA devices (vllm-project#14164)

* [platform] add debug logging during inferring the device type (vllm-project#14195)

Signed-off-by: youkaichao <[email protected]>

* [sleep mode] error out with expandable_segments (vllm-project#14189)

Signed-off-by: youkaichao <[email protected]>

* [doc] add "Failed to infer device type" to faq (vllm-project#14200)

Signed-off-by: youkaichao <[email protected]>

* [Bugfix] Restrict MacOS CPU detection (vllm-project#14210)

Signed-off-by: mgoin <[email protected]>

* [V1][BugFix] Fix remaining sync engine client shutdown errors/hangs (vllm-project#13869)

Signed-off-by: Nick Hill <[email protected]>

* [V0][Metrics] Deprecate some questionable request time metrics (vllm-project#14135)

Signed-off-by: Mark McLoughlin <[email protected]>

* [V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (vllm-project#14161)

* add cutlass support for blackwell fp8 gemm (vllm-project#13798)

* [TPU][Profiler] Support start_profile/stop_profile in TPU worker (vllm-project#13988)

Signed-off-by: Siyuan Liu <[email protected]>
Co-authored-by: mgoin <[email protected]>

* Fix performance when `--generation-config` is not `None` (vllm-project#14223)

Signed-off-by: Harry Mellor <[email protected]>

* [Frontend] Do `prompt_logprobs` clamping for chat as well as completions (vllm-project#14225)

Signed-off-by: Harry Mellor <[email protected]>

* [Docs] Update Dockerfile dependency image (vllm-project#14215)

Signed-off-by: mgoin <[email protected]>

* [v1][Metrics] Add design doc (vllm-project#12745)

Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Cody Yu <[email protected]>

* [Security] Serialize using safetensors instead of pickle in Mooncake Pipe (vllm-project#14228)

Signed-off-by: KuntaiDu <[email protected]>

* Clean up unused padding_idx variables across many model definitions (vllm-project#13240)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [ROCm] Disable a few more kernel tests that are broken on ROCm (vllm-project#14145)

Signed-off-by: Sage Moore <[email protected]>

* [V1][TPU] TPU multimodal model support for ragged attention (vllm-project#14158)

Signed-off-by: Michael Goin <[email protected]>

* [misc] announce china meetup (vllm-project#14248)

Signed-off-by: youkaichao <[email protected]>

* Moved numba from common requirements to cuda/rocm specific requirements (vllm-project#14199)

Signed-off-by: Nishidha Panpaliya <[email protected]>

* Disable GPTQ AllSpark kernels for CUDA Compiler < 12.0 (vllm-project#14157)

Signed-off-by: mgoin <[email protected]>

* [Bugfix] Fix gptq_marlin for deepseek-v3 (vllm-project#13750)

Signed-off-by: dangshunya <[email protected]>
Co-authored-by: dangshunya <[email protected]>

* [V1][Bugfix] Do not reset prefix caching metrics (vllm-project#14235)

* [Model] New model support for Phi-4-multimodal-instruct (vllm-project#14119)

* [V1] EP/TP MoE + DP Attention (vllm-project#13931)

* [platforms] improve rocm debugging info (vllm-project#14257)

* Temporarily disable test_awq_gemm_opcheck (vllm-project#14251)

Signed-off-by: mgoin <[email protected]>

* [Frontend] Allow return_tokens_as_token_ids to be passed as a request param (vllm-project#14066)

Signed-off-by: Benjamin Chislett <[email protected]>

* [Misc][V1] Avoid using `envs.VLLM_USE_V1` in mm processing (vllm-project#14256)

Signed-off-by: Roger Wang <[email protected]>

* [Bugfix][V1] Fix allowed_token_ids for v1 Sampler (vllm-project#14169)

Signed-off-by: Lu Fang <[email protected]>

* [Doc] Update nginx guide: remove privileged from vllm container run and add target GPU ID (vllm-project#14217)

Signed-off-by: Iacopo Poli <[email protected]>

* [Doc] [3/N] Refer code examples for common cases in dev multimodal processor (vllm-project#14278)

Signed-off-by: DarkLight1337 <[email protected]>

* Small update for external_launcher backend docs (vllm-project#14288)

* [V1][Frontend] Add Testing For V1 Runtime Parameters (vllm-project#14159)

Signed-off-by: [email protected] <[email protected]>

* [LoRA] Remove linear hack outside transformers backend (vllm-project#14177)

Signed-off-by: Isotr0py <[email protected]>

* [Misc] Add Qwen2MoeForCausalLM moe tuning support  (vllm-project#14276)

Signed-off-by: Jee Jee Li <[email protected]>

* prefix_caching.md: Fixed typo (vllm-project#14293)

Signed-off-by: Daivid Savernin-Frenk <[email protected]>

* [Bugfix] Fix broken vision language example (vllm-project#14292)

Signed-off-by: Isotr0py <[email protected]>

* [Docs] Add Meta Slides (vllm-project#14297)

Signed-off-by: simon-mo <[email protected]>

* [V1][Minor] Remove obsolete FIXME comment (vllm-project#14304)

Signed-off-by: Nick Hill <[email protected]>

* Deprecate `best_of` Sampling Parameter in anticipation for vLLM V1 (vllm-project#13997)

Signed-off-by: vincent-4 <[email protected]>
Signed-off-by: Brayden Zhong <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Brayden Zhong <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>

* [V1][BugFix] Fix for mixed top_k batch (vllm-project#14301)

Signed-off-by: Nick Hill <[email protected]>


Co-authored-by: Ye Cao <[email protected]>

* [misc] Add FlashMLA as a new option of VLLM_ATTENTION_BACKEND env (vllm-project#14267)

* [V1][Easy] Add empty allowed_token_ids in the v1 sampler test (vllm-project#14308)

Signed-off-by: Lu Fang <[email protected]>

* init

Signed-off-by: Sage Moore <[email protected]>

* [Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (vllm-project#14237)

Signed-off-by: pyc96 <[email protected]>

* [Bugfix] Remove num_tokens_across_dp (vllm-project#14302)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [BugFix] Fix prefix caching V0 MLA (vllm-project#14255)

Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Ying Zhong <[email protected]>

* [CI/Build] Use spawn multiprocessing mode for V1 test pipeline (vllm-project#14243)

Signed-off-by: Russell Bryant <[email protected]>

* Add benchmark for DeepGEMM and vLLM Block FP8 Dense GEMM (vllm-project#13917)

Signed-off-by: mgoin <[email protected]>

* [Build] Add UV_HTTP_TIMEOUT to avoid timeout during installation (vllm-project#13850)

Signed-off-by: Yuan Tang <[email protected]>

* [BugFix] MLA + V1, illegal memory access and accuracy issues (vllm-project#14253)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [misc] Mention `ray list nodes` command to troubleshoot ray issues (vllm-project#14318)

Signed-off-by: Rui Qiao <[email protected]>

* [Bugfix][Structured Output] Support outlines engine with reasoning outputs for DeepSeek R1 (vllm-project#14114)

* [V1] LoRA - Enable more V1 tests (vllm-project#14315)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>

* [Bugfix][CI] ALiBi test case in xformers multi_query_kv_attention (vllm-project#11301)

* [Hardware] Update the flash attn tag to support Blackwell (vllm-project#14244)

* [Model] Update Paligemma multimodal processing with PromptUpdate  (vllm-project#14015)

Signed-off-by: Kyle Huang <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [V1][VLM][Pixtral-HF] Support Pixtral-HF on V1 (vllm-project#14275)

Signed-off-by: Linkun Chen <[email protected]>

* [Core] Optimizing cross-attention `QKVParallelLinear` computation (vllm-project#12325)

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: NickLucche <[email protected]>

* [Frontend][Docs] Transcription API streaming (vllm-project#13301)

Signed-off-by: NickLucche <[email protected]>

* [Doc] Update reasoning with stream example to use OpenAI library (vllm-project#14077)

Signed-off-by: liuyanyi <[email protected]>

* [Doc] Correct beam_search using in generative_models.md (vllm-project#14363)

* [Kernel] [V1] Improved performance for V1 Triton (ROCm) backend  (vllm-project#14152)

* [Bugfix][Core] fix abort_seq_group and memory leak when n>1 (vllm-project#14326)

Signed-off-by: courage17340 <[email protected]>

* [Core] Don't use cache during multi-modal profiling (vllm-project#14336)

* [Doc] Fix date typo in README.md (vllm-project#14366)

Signed-off-by: Jitse Klomp <[email protected]>

* [RLHF] use worker_extension_cls for compatibility with V0 and V1 (vllm-project#14185)

Signed-off-by: youkaichao <[email protected]>

* Reinstate `best_of` for V0 (vllm-project#14356)

Signed-off-by: Harry Mellor <[email protected]>

* Adding cpu inference with VXE ISA for s390x architecture (vllm-project#12613)

Signed-off-by: Dilip Gowda Bhagavan <[email protected]>
Signed-off-by: Rishika Kedia <[email protected]>
Co-authored-by: Rishika Kedia <[email protected]>

* Add authors to license header. (vllm-project#14371)

Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Burkhard Ringlein <[email protected]>
Co-authored-by: Jan van Lunteren <[email protected]>

* Fix mla prefill context performance (vllm-project#13897)

Signed-off-by: ZhongYingMatrix <[email protected]>

* [V1] Do not detokenize if sampling param detokenize is False (vllm-project#14224)

Signed-off-by: Himanshu Jaju <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* [Distributed] Add enable_expert_parallel arg (vllm-project#14305)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [CI/Build] Use uv python for docker rather than ppa:deadsnakes/ppa (vllm-project#13569)

Signed-off-by: mgoin <[email protected]>

* [CI] Disable spawn when running V1 Test (vllm-project#14345)

Signed-off-by: Thomas Parnell <[email protected]>

* [Kernel] Add needs_fixed_stride_order tag to most GEMMs (vllm-project#14306)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [Bugfix] Fix use_direct_call condition in FusedMoE layer for  (vllm-project#14382)

Signed-off-by: Tyler Michael Smith <[email protected]>

* [Bug] Fix Attention when ignored in by quant_method (vllm-project#14313)

Signed-off-by: mgoin <[email protected]>

* [V1][Bugfix] Standardize quantized kv cache rejection for attention backends (vllm-project#14221)

Signed-off-by: mgoin <[email protected]>

* [Docs] Add nsight guide to profiling docs (vllm-project#14298)

Signed-off-by: mgoin <[email protected]>

* cleanup boolean logic

Signed-off-by: Sage Moore <[email protected]>

* [Hardware][TPU]Enable ragged paged attention kernel and resolve recompilation issue (vllm-project#14310)

Signed-off-by: Chengji Yao <[email protected]>

* [Doc] Fix a typo (vllm-project#14385)

* [Bugfix] Correctly call `cudaProfilerStop` in benchmarks script (vllm-project#14183)

Signed-off-by: Brayden Zhong <[email protected]>

* [Perf] Reduce MLA CPU overheads in V1 (vllm-project#14384)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>

* [FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (vllm-project#14390)

Signed-off-by: luka <[email protected]>

* [BugFix] Illegal Memory Access in the blockwise cutlass fp8 GEMMs (vllm-project#14396)

* [Bugfix] Fix JambaForCausalLM LoRA  (vllm-project#14370)

Signed-off-by: Jee Jee Li <[email protected]>

* [Build] Add nightly wheel fallback when latest commit wheel unavailable (vllm-project#14358)

Signed-off-by: Isotr0py <[email protected]>

* OpenVINO: added CPU-like conditions (vllm-project#14338)

Signed-off-by: Ilya Lavrenov <[email protected]>

* [GH] Auto-apply multi-modality label to relevant PRs (vllm-project#14402)

Signed-off-by: DarkLight1337 <[email protected]>

* correct wrong markdown syntax (vllm-project#14414)

Signed-off-by: vincent-pli <[email protected]>

* [Bugfix] Further clean up LoRA test (vllm-project#14422)

Signed-off-by: Jee Jee Li <[email protected]>

* [Bugfix] Clean up multi-modal processors (vllm-project#14417)

Signed-off-by: DarkLight1337 <[email protected]>

* [Misc] Set default value of seed to None (vllm-project#14274)

Signed-off-by: மனோஜ்குமார் பழனிச்சாமி <[email protected]>

* [BUGFIX] Skip tokenization support for throughput benchmark (vllm-project#12712)

Signed-off-by: root <[email protected]>
Signed-off-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>

* Fix missing `kv_caches` and `attn_metadata` in `OpenVINOCausalLM` (vllm-project#14271)

Signed-off-by: Harry Mellor <[email protected]>

* Use the optimized block sizes after tuning the kernel. (vllm-project#14329)

* [V1][Core] Support for Structured Outputs (vllm-project#12388)

Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* [Doc] Update prefix_caching.md to match the example image (vllm-project#14420)

* [Benchmarks] Make detokenization optional in benchmark scripts (vllm-project#11697)

Signed-off-by: Jeremy Arnold <[email protected]>

* comments

Signed-off-by: Sage Moore <[email protected]>

* [Kernel] optimize performance of gptq marlin kernel when n is small (vllm-project#14138)

Signed-off-by: Jinzhen Lin <[email protected]>

* [Misc] Add Phi4-MM example (vllm-project#14343)

Signed-off-by: Jee Jee Li <[email protected]>

* [v1] torch.compile integration explanation (vllm-project#14437)

Signed-off-by: youkaichao <[email protected]>

* [V1] Eagerly remove finished requests from the batch (vllm-project#14388)

Signed-off-by: Nick Hill <[email protected]>

* [V1][Metrics] Fix traceback with preemptions+LoRA (vllm-project#14220)

Signed-off-by: Mark McLoughlin <[email protected]>

* [Bugfix] Fix torch_xla which can't handle None seed introduced in vllm-project#14274 (vllm-project#14459)

Signed-off-by: Yarong Mu <[email protected]>

* [V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (vllm-project#13949)

* [Bugfix][V1] Handle MLA in kv_cache_interface (vllm-project#14462)

Signed-off-by: Tyler Michael Smith <[email protected]>

* Revert "[Perf] Reduce MLA CPU overheads in V1 (vllm-project#14384)" (vllm-project#14471)

* [Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (vllm-project#14369)

Signed-off-by: Mathis Felardos <[email protected]>

* [MISC][V1] Register process killing handler only in the main thread (vllm-project#14380)

Signed-off-by: Cody Yu <[email protected]>

* [core] add `extra_args` to `SamplingParams` (vllm-project#13300)

Signed-off-by: Aviv Keshet <[email protected]>

* [CI/Build] refactor: set timezone of container to UTC (vllm-project#12888)

Signed-off-by: Roger Meier <[email protected]>

* Default to `generation_config` from model (vllm-project#12622)

Signed-off-by: Harry Mellor <[email protected]>

* [Doc]add doc for Qwen models tool calling (vllm-project#14478)

Signed-off-by: WangErXiao <[email protected]>

* [Doc] Added QwQ-32B to the supported models list in the reasoning out… (vllm-project#14479)

Signed-off-by: WangErXiao <[email protected]>

* [Bugfix] Make the deviceprofiler include LoRA memory. (vllm-project#14469)

Signed-off-by: Jee Jee Li <[email protected]>

* Add training doc signposting to TRL (vllm-project#14439)

Signed-off-by: Harry Mellor <[email protected]>

* [Build/BugFix] Fix hopper 12.8 build (vllm-project#14354)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* Add RLHF document (vllm-project#14482)

Signed-off-by: Harry Mellor <[email protected]>

* [CI/Build] Use a fixed seed to avoid flaky tests (vllm-project#14480)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1] TPU - Add tensor parallel support via Ray (vllm-project#13618)

Signed-off-by: Alexander Matveev <[email protected]>

* [VLM] Add TP support for Phi-4-MM (vllm-project#14453)

Signed-off-by: Isotr0py <[email protected]>

* [Misc] add `use_tqdm_on_load` to reduce logs (vllm-project#14407)

Signed-off-by: Aaron Pham <[email protected]>

* [V1][Core] Fix memory issue with logits & sampling (vllm-project#13776)

Signed-off-by: Roger Wang <[email protected]>

* [benchmarks] Add option to use unique jsonschema for each request (vllm-project#14457)

Signed-off-by: Russell Bryant <[email protected]>

* [Misc] Don't run ruff at all on 3rd party libs (vllm-project#14493)

Signed-off-by: DarkLight1337 <[email protected]>

* Move requirements into their own directory (vllm-project#12547)

Signed-off-by: Harry Mellor <[email protected]>

* [Bugfix] DeepSeek Accuracy (vllm-project#14476)

Signed-off-by: Lucas Wilkinson <[email protected]>

* [Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (vllm-project#14361)

Signed-off-by: Isotr0py <[email protected]>

* Update CODEOWNERS for structured output (vllm-project#14496)

Signed-off-by: Russell Bryant <[email protected]>

* [Misc] Upgrade to Python 3.9 typing for additional directories (vllm-project#14492)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1] Support bad_words in sampler (vllm-project#13376)

Signed-off-by: 22quinn <[email protected]>
Co-authored-by: Nick Hill <[email protected]>

* Revert "[V1][Core] Fix memory issue with logits & sampling" (vllm-project#14504)

Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

* [Attention] Default to FlashMLA backend for MLA (vllm-project#14451)

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

* [V1][TPU] Remove unnecessary padding for running on TPU. (vllm-project#14467)

* [Feat] Support chunked prefill for LMCache connector (vllm-project#14505)

Signed-off-by: YaoJiayi <[email protected]>

* [Bugfix] Fix tqdm progress bar when SamplingParams.n > 1 (vllm-project#12428)

Signed-off-by: Yuchen Yan <[email protected]>

* [Bugfix] Revert QKVCrossParallelLinear usage in Mllama to keep BNB quantization work (vllm-project#14498)

Signed-off-by: Isotr0py <[email protected]>

* [Hardware][TPU] Fix the recompiling issue in logits processor after warmup (vllm-project#14510)

Signed-off-by: Chengji Yao <[email protected]>

* [Misc] Ensure out-of-tree quantization method recognize by cli args (vllm-project#14328)

Signed-off-by: liuyanyi <[email protected]>

* [Bugfix] Wrong requirements path - rocm (vllm-project#14527)

Signed-off-by: Martin Hoyer <[email protected]>

* [Feature] Consolidate performance benchmark datasets (vllm-project#14036)

Signed-off-by: Jennifer Zhao <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Jennifer Zhao <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

* [Misc] Add log information for handle_process_request. (vllm-project#14130)

Signed-off-by: chaunceyjiang <[email protected]>

* [Docs] Mention `model_impl` arg when explaining Transformers fallback (vllm-project#14552)

Signed-off-by: Harry Mellor <[email protected]>

* [Frontend] support image embeds (vllm-project#13955)

Signed-off-by: chaunceyjiang <[email protected]>

* [Kernel] Add more dtype support for GGUF kernels (vllm-project#14043)

Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>

* [Doc] Update PaliGemma note to a warning (vllm-project#14565)

Signed-off-by: DarkLight1337 <[email protected]>

* V1 rocm support (#469)

* Initial commit for V1 successful compilation

* Small improvement for linear

* Small improvement for linear

* making use of forward_cuda for all except ROPE in llama

---------

Co-authored-by: maleksan85 <[email protected]>

* nightly_fixed_aiter_integration_final_20250305 README update (#470)

* nightly_fixed_aiter_integration_final_20250305 README update (perf results only)

* Update Docker Manifest git hash

* Update Docker Manifest and added nightly_fixed_aiter_integration_final_20250305

* some more updates

* Update AITER section with example

* Updated AITER command with larger batch size and model name

* Fixing typo

* Removed --max-model-len in AITER command

* Updating AITER instructions

* typo

* Another typo

* Whitespace

* modifying whats new section

* Another typo

---------

Co-authored-by: arakowsk-amd <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>

captainzmc pushed a commit to captainzmc/vllm that referenced this pull request Mar 12, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025