Commit b3fd081

Merge branch 'main' into disable_prefix_caching_per_request
2 parents: cffd20c + 11fd69d

File tree

58 files changed: +1543 −307 lines


.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml

Lines changed: 0 additions & 12 deletions
This file was deleted.

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 0 deletions
```diff
@@ -546,8 +546,11 @@ steps:
 
 - label: Model Executor Test # 23min
   timeout_in_minutes: 35
+  torch_nightly: true
   mirror_hardwares: [amdexperimental]
   source_file_dependencies:
+  - vllm/engine/arg_utils.py
+  - vllm/config/model.py
   - vllm/model_executor
   - tests/model_executor
   - tests/entrypoints/openai/test_tensorizer_entrypoint.py
```

.github/CODEOWNERS

Lines changed: 5 additions & 0 deletions
```diff
@@ -127,3 +127,8 @@ mkdocs.yaml @hmellor
 /vllm/config/pooler.py @noooop
 /vllm/pooling_params.py @noooop
 /vllm/model_executor/layers/pooler.py @noooop
+
+# Security guide and policies
+/docs/usage/security.md @russellb
+/SECURITY.md @russellb
+/docs/contributing/vulnerability_management.md @russellb
```

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -16,8 +16,8 @@
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = [
-    "nm-testing/Mixtral-8x7B-Instruct-v0.1",
-    "nm-testing/deepseekv2-lite",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "deepseek-ai/DeepSeek-V2-Lite",
     "ibm-granite/granite-3.0-1b-a400m",
     "ibm-granite/granite-3.0-3b-a800m",
 ]
```

benchmarks/kernels/benchmark_shapes.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -78,11 +78,11 @@
 }
 
 WEIGHT_SHAPES_MOE = {
-    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": [
         [8, 2, 4096, 28672],
         [8, 2, 14336, 4096],
     ],
-    "nm-testing/deepseekv2-lite": [
+    "deepseek-ai/DeepSeek-V2-Lite": [
         [64, 6, 2048, 1408],
     ],
     "ibm-granite/granite-3.0-1b-a400m": [
```

csrc/attention/merge_attn_states.cu

Lines changed: 26 additions & 0 deletions
```diff
@@ -46,6 +46,32 @@ __global__ void merge_attn_states_kernel(
   s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
 
   const float max_lse = fmaxf(p_lse, s_lse);
+
+  /* In certain edge cases, MLA can produce p_lse = s_lse = -inf; continuing
+     the pipeline then yields NaN. Root cause: with chunked prefill, a batch
+     may be split into two chunks; if a request in that batch has no prefix
+     hit, every LSE entry for that request's position is -inf, yet the
+     cross-attention results are merged first. For now we simply pass
+     prefix_output (expected to be all zeros) and prefix_lse (-inf) through
+     unchanged to avoid the NaN.
+   */
+  if (std::isinf(max_lse)) {
+    if (pack_offset < head_size) {
+      // 128-bit packed load from the prefix output
+      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+          prefix_head_ptr)[pack_offset / pack_size];
+
+      // 128-bit packed store to the merged output
+      reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
+          p_out_pack;
+    }
+    // We only need to write to output_lse once per head.
+    if (output_lse != nullptr && pack_idx == 0) {
+      output_lse[head_idx * num_tokens + token_idx] = max_lse;
+    }
+    return;
+  }
+
   p_lse = p_lse - max_lse;
   s_lse = s_lse - max_lse;
   const float p_se = expf(p_lse);
```
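For intuition, here is a minimal NumPy sketch of the log-sum-exp merge this kernel performs (an illustrative helper, not vLLM's actual API); it shows why `p_lse = s_lse = -inf` produces NaN without the guard added above:

```python
import numpy as np

def merge_attn_states(p_out, p_lse, s_out, s_lse):
    """Merge two partial attention outputs using their log-sum-exp (LSE)."""
    max_lse = max(p_lse, s_lse)
    if np.isinf(max_lse):
        # The guard from this patch: both chunks have -inf LSE, so pass the
        # prefix output (expected to be all zeros) and -inf straight through.
        return p_out, max_lse
    p_se = np.exp(p_lse - max_lse)  # without the guard, -inf - (-inf) = nan here
    s_se = np.exp(s_lse - max_lse)
    out = (p_out * p_se + s_out * s_se) / (p_se + s_se)
    return out, max_lse + np.log(p_se + s_se)
```

Calling this with `p_lse = s_lse = float("-inf")` and the guard removed reproduces the NaN described in the comment.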

docs/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -56,7 +56,7 @@ vLLM is flexible and easy to use with:
 - Tensor, pipeline, data and expert parallelism support for distributed inference
 - Streaming outputs
 - OpenAI-compatible API server
-- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
+- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
 - Prefix caching support
 - Multi-LoRA support
```

docs/deployment/frameworks/helm.md

Lines changed: 44 additions & 5 deletions
````diff
@@ -13,7 +13,7 @@ Before you begin, ensure that you have the following:
 - A running Kubernetes cluster
 - NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at [https://github.com/NVIDIA/k8s-device-plugin](https://github.com/NVIDIA/k8s-device-plugin)
 - Available GPU resources in your cluster
-- An S3 with the model which will be deployed
+- (Optional) An S3 bucket or other storage with the model weights, if using automatic model download
 
 ## Installing the chart
 
@@ -61,10 +61,16 @@ The following table describes configurable parameters of the chart in `values.ya
 | deploymentStrategy | object | {} | Deployment strategy configuration |
 | externalConfigs | list | [] | External configuration |
 | extraContainers | list | [] | Additional containers configuration |
-| extraInit | object | {"pvcStorage":"1Gi","s3modelpath":"relative_s3_model_path/opt-125m", "awsEc2MetadataDisabled": true} | Additional configuration for the init container |
-| extraInit.pvcStorage | string | "1Gi" | Storage size of the s3 |
-| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | Path of the model on the s3 which hosts model weights and config files |
-| extraInit.awsEc2MetadataDisabled | boolean | true | Disables the use of the Amazon EC2 instance metadata service |
+| extraInit | object | {"modelDownload":{"enabled":true},"initContainers":[],"pvcStorage":"1Gi"} | Additional configuration for init containers |
+| extraInit.modelDownload | object | {"enabled":true} | Model download functionality configuration |
+| extraInit.modelDownload.enabled | bool | true | Enable automatic model download job and wait container |
+| extraInit.modelDownload.image | object | {"repository":"amazon/aws-cli","tag":"2.6.4","pullPolicy":"IfNotPresent"} | Image for model download operations |
+| extraInit.modelDownload.waitContainer | object | {} | Wait container configuration (command, args, env) |
+| extraInit.modelDownload.downloadJob | object | {} | Download job configuration (command, args, env) |
+| extraInit.initContainers | list | [] | Custom init containers (appended after model download if enabled) |
+| extraInit.pvcStorage | string | "1Gi" | Storage size for the PVC |
+| extraInit.s3modelpath | string | "relative_s3_model_path/opt-125m" | (Optional) Path of the model on S3 |
+| extraInit.awsEc2MetadataDisabled | bool | true | (Optional) Disable AWS EC2 metadata service |
 | extraPorts | list | [] | Additional ports configuration |
 | gpuModels | list | ["TYPE_GPU_USED"] | Type of gpu used |
 | image | object | {"command":["vllm","serve","/data/","--served-model-name","opt-125m","--host","0.0.0.0","--port","8000"],"repository":"vllm/vllm-openai","tag":"latest"} | Image configuration |
@@ -98,3 +104,36 @@ The following table describes configurable parameters of the chart in `values.ya
 | serviceName | string | "" | Service name |
 | servicePort | int | 80 | Service port |
 | labels.environment | string | test | Environment name |
+
+## Configuration Examples
+
+### Using S3 Model Download (Default)
+
+```yaml
+extraInit:
+  modelDownload:
+    enabled: true
+  pvcStorage: "10Gi"
+  s3modelpath: "models/llama-7b"
+```
+
+### Using Custom Init Containers Only
+
+For use cases like llm-d where you need custom sidecars without model download:
+
+```yaml
+extraInit:
+  modelDownload:
+    enabled: false
+  initContainers:
+    - name: llm-d-routing-proxy
+      image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
+      imagePullPolicy: IfNotPresent
+      ports:
+        - containerPort: 8080
+          name: proxy
+      securityContext:
+        runAsUser: 1000
+      restartPolicy: Always
+  pvcStorage: "10Gi"
+```
````

docs/getting_started/installation/cpu.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -94,7 +94,7 @@ Currently, there are no pre-built CPU wheels.
 ## Related runtime environment variables
 
 - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
-- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively.
+- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
 - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
 - `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence.
 - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
```
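As a quick illustration of how these variables combine, here is a hedged sketch of running the CPU backend with the new `nobind` mode (the specific values and the `facebook/opt-125m` model are illustrative only, not recommendations):

```python
# Sketch: CPU-backend env vars must be set before vLLM spawns its workers.
import os

os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"         # 40 GiB for the KV cache
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "nobind"  # skip per-core binding
os.environ["OMP_NUM_THREADS"] = "16"                # honored because of "nobind"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```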

examples/online_serving/chart-helm/README.md

Lines changed: 12 additions & 0 deletions
````diff
@@ -19,3 +19,15 @@ This directory contains a Helm chart for deploying the vllm application. The cha
 - templates/pvc.yaml: Template for Persistent Volume Claims.
 - templates/secrets.yaml: Template for Kubernetes Secrets.
 - templates/service.yaml: Template for creating Services.
+
+## Running Tests
+
+This chart includes unit tests using [helm-unittest](https://github.com/helm-unittest/helm-unittest). Install the plugin and run tests:
+
+```bash
+# Install plugin
+helm plugin install https://github.com/helm-unittest/helm-unittest
+
+# Run tests
+helm unittest .
+```
````
