
Commit 4e658a5

DarkLight1337 authored and epwalsh committed
[Core] Use key-only cache for BaseMultiModalProcessor (vllm-project#23018)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 126249b commit 4e658a5

29 files changed, +954 −394 lines


docs/configuration/conserving_memory.md

Lines changed: 1 addition & 1 deletion

@@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
 
 If you run out of CPU RAM, try the following options:
 
-- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
+- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB).
 - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
 
 ## Multi-modal input limits
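
For illustration, a minimal sketch of the first option above (the model name is taken from the multi-modal examples elsewhere in these docs, not from this diff; the value is arbitrary):

```python
from vllm import LLM

# Shrink the multi-modal cache from the default 4 GiB to 2 GiB;
# passing 0 disables it entirely (see docs/configuration/optimization.md).
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    mm_processor_cache_gb=2,
)
```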

docs/configuration/optimization.md

Lines changed: 35 additions & 9 deletions

@@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
 to avoid CPU resource exhaustion.
 
 !!! note
-    [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
-    because it requires a one-to-one correspondence between API and engine core processes.
+    API server scale-out disables [multi-modal IPC caching](#ipc-caching)
+    because it requires a one-to-one correspondance between API and engine core processes.
 
-## Multi-Modal Caching
+    This does not impact [multi-modal processor caching](#processor-caching).
 
-### Processor Cache
+## Multi-Modal Caching
 
-By default, the multi-modal processor cache is enabled to avoid repeatedly processing
-the same multi-modal inputs via Hugging Face `AutoProcessor`,
+Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
 which commonly occurs in multi-turn conversations.
 
-You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
-(default 4 GiB per API process + 4 GiB per engine core process).
-If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.
+### Processor Caching
+
+Multi-modal processor caching is automatically enabled
+to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`.
+
+### IPC Caching
+
+Multi-modal IPC caching is automatically enabled when
+there is a one-to-one correspondance between API (`P0`) and engine core (`P1`) processes,
+to avoid repeatedly transferring the same multi-modal inputs between them.
+
+### Configuration
+
+You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
+
+If you do not benefit much from the cache, you can disable both IPC
+and processor caching completely via `mm_processor_cache_gb=0`.
 
 Examples:
 
@@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
 llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
           mm_processor_cache_gb=0)
 ```
+
+### Cache Placement
+
+Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows:
+
+| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
+|-------------------|-------------|------------|------------|-------------|
+| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
+| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
+| ❌ | ❌ | N/A | N/A | `0` |
+
+K: Stores the hashes of multi-modal items
+V: Stores the processed tensor data of multi-modal items
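
As a rough sketch of what the new table implies for sizing (the helper name and structure below are ours, not part of vLLM):

```python
def worst_case_mm_cache_gb(
    mm_processor_cache_gb: float,
    *,
    data_parallel_size: int = 1,
    api_server_count: int = 1,
    ipc_caching: bool = True,
) -> float:
    """Upper bound on multi-modal cache memory, following the table above."""
    if mm_processor_cache_gb == 0:
        return 0.0  # processor and IPC caching are both disabled
    if ipc_caching:
        # P0 keeps keys only; each engine core (P1) keeps keys + values.
        return mm_processor_cache_gb * data_parallel_size
    # Processor caching only: each API process (P0) keeps keys + values.
    return mm_processor_cache_gb * api_server_count


# e.g. the default 4 GiB cache with 2 engine cores -> up to 8 GiB of cached data
assert worst_case_mm_cache_gb(4, data_parallel_size=2) == 8
```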

tests/models/multimodal/processing/test_common.py

Lines changed: 5 additions & 3 deletions

@@ -14,8 +14,9 @@
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import MultiModalInputs
-from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.multimodal.processing import BaseMultiModalProcessor
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                cached_tokenizer_from_config,
                                                encode_tokens)
@@ -63,6 +64,8 @@ def _test_processing_correctness(
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
+        # Ensure that the cache can fit all of the data
+        mm_processor_cache_gb=2048,
     )
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
@@ -71,8 +74,7 @@ def _test_processing_correctness(
         model_config,
         tokenizer=cached_tokenizer_from_config(model_config),
     )
-    # Ensure that it can fit all of the data
-    cache = ProcessingCache(capacity_gb=2048)
+    cache = MultiModalProcessorOnlyCache(model_config)
 
     processing_info = factories.info(ctx)
     supported_mm_limits = processing_info.get_supported_mm_limits()

tests/multimodal/test_cache.py

Lines changed: 174 additions & 8 deletions

@@ -1,32 +1,64 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import numpy as np
 import pytest
 import torch
 
-from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
+from vllm.config import ModelConfig, ParallelConfig, VllmConfig
+from vllm.multimodal.cache import (MultiModalCache,
+                                   MultiModalProcessorCacheItem,
+                                   MultiModalProcessorCacheItemMetadata,
+                                   processor_cache_from_config,
+                                   receiver_cache_from_config)
+from vllm.multimodal.hasher import MultiModalHasher
 from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
                                     MultiModalKwargsItems,
                                     MultiModalSharedField)
+from vllm.multimodal.processing import PromptInsertion
+from vllm.multimodal.registry import MultiModalRegistry
+
 
+def _dummy_elem(
+    modality: str,
+    key: str,
+    size: int,
+    *,
+    rng: Optional[np.random.RandomState] = None,
+):
+    if rng is None:
+        data = torch.empty((size, ), dtype=torch.int8)
+    else:
+        data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
 
-def _dummy_elem(modality: str, key: str, size: int):
     return MultiModalFieldElem(
         modality=modality,
         key=key,
-        data=torch.empty((size, ), dtype=torch.int8),
+        data=data,
         field=MultiModalSharedField(1),
     )
 
 
-def _dummy_item(modality: str, size_by_key: dict[str, int]):
+def _dummy_item(
+    modality: str,
+    size_by_key: dict[str, int],
+    *,
+    rng: Optional[np.random.RandomState] = None,
+):
     return MultiModalKwargsItem.from_elems([
-        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
+        _dummy_elem(modality, key, size, rng=rng)
+        for key, size in size_by_key.items()
     ])
 
 
-def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
+def _dummy_items(
+    size_by_key_modality: dict[str, dict[str, int]],
+    *,
+    rng: Optional[np.random.RandomState] = None,
+):
     return MultiModalKwargsItems.from_seq([
-        _dummy_item(modality, size_by_key)
+        _dummy_item(modality, size_by_key, rng=rng)
        for modality, size_by_key in size_by_key_modality.items()
     ])
 
@@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size):
     cache[""] = item
     assert cache.currsize == expected_size
 
-    cache[""] = MultiModalCacheItemMetadata.wraps(item)
+    prompt_update = PromptInsertion("dummy", "target", "insertion") \
+        .resolve(0)
+
+    cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
+    assert cache.currsize == expected_size
+
+    cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
     assert cache.currsize == expected_size
+
+
+def _create_vllm_config(
+    *,
+    mm_processor_cache_gb: float,
+    enable_ipc: bool,
+):
+    return VllmConfig(
+        model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb),
+        parallel_config=ParallelConfig(
+            data_parallel_size=1 if enable_ipc else 2),
+    )
+
+
+def _compare_caches(
+    config_0: VllmConfig,
+    config_1: VllmConfig,
+    *,
+    item_capacity: int = 8,
+    hit_rate: float = 0.5,
+    max_items_per_iter: int = 3,
+    is_cached_calls_per_iter: int,
+    n_iter: int = 100,
+    seed: int = 0,
+):
+    mm_registry = MultiModalRegistry()
+    cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
+    cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
+    cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
+    cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
+
+    cache_size_gb = max(
+        config_0.model_config.mm_processor_cache_gb,
+        config_1.model_config.mm_processor_cache_gb,
+    )
+    item_size_gb = int(cache_size_gb / item_capacity)
+
+    rng = np.random.RandomState(seed)
+    all_items = [
+        _dummy_item("item", {"key": item_size_gb}, rng=rng)
+        for _ in range(int(item_capacity / hit_rate))
+    ]
+    all_hashes = [
+        MultiModalHasher.hash_kwargs(item=item.get_data())
+        for item in all_items
+    ]
+
+    # Should not be used since there is nothing to convert to text
+    prompt_update = PromptInsertion("dummy", "target", "insertion")
+
+    for it in range(n_iter):
+        num_items_to_select = rng.randint(0, max_items_per_iter)
+        item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)
+
+        selected_items = [all_items[idx] for idx in item_idxs_to_select]
+        selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]
+
+        if cache_0_p0 is None:
+            cache_0_p0_out = selected_items
+        else:
+            for _ in range(is_cached_calls_per_iter):
+                cache_0_p0.is_cached(selected_hashes)
+            cache_0_p0_out = [
+                item for item, _ in cache_0_p0.get_and_update(
+                    [(item, prompt_update.content) for item in selected_items],
+                    selected_hashes,
+                )
+            ]
+
+        if cache_1_p0 is None:
+            cache_1_p0_out = selected_items
+        else:
+            for _ in range(is_cached_calls_per_iter):
+                cache_1_p0.is_cached(selected_hashes)
+            cache_1_p0_out = [
+                item for item, _ in cache_1_p0.get_and_update(
+                    [(item, prompt_update.content) for item in selected_items],
+                    selected_hashes,
+                )
+            ]
+
+        if cache_0_p1 is None:
+            cache_0_p1_out = cache_0_p0_out
+        else:
+            cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out,
+                                                       selected_hashes)
+
+        if cache_1_p1 is None:
+            cache_1_p1_out = cache_1_p0_out
+        else:
+            cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out,
+                                                       selected_hashes)
+
+        assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
+
+
+@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
+def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
+    cache_size_gb = 1 / (1 << 20)
+
+    vllm_config_ipc_enabled = _create_vllm_config(
+        mm_processor_cache_gb=cache_size_gb,
+        enable_ipc=True,
+    )
+    vllm_config_ipc_disabled = _create_vllm_config(
+        mm_processor_cache_gb=cache_size_gb,
+        enable_ipc=False,
+    )
+    vllm_config_cache_disabled = _create_vllm_config(
+        mm_processor_cache_gb=0,
+        enable_ipc=True,
+    )
+
+    _compare_caches(
+        vllm_config_ipc_enabled,
+        vllm_config_ipc_disabled,
+        is_cached_calls_per_iter=is_cached_calls_per_iter,
+    )
+    _compare_caches(
+        vllm_config_ipc_disabled,
+        vllm_config_cache_disabled,
+        is_cached_calls_per_iter=is_cached_calls_per_iter,
+    )
+    _compare_caches(
+        vllm_config_cache_disabled,
+        vllm_config_ipc_enabled,
+        is_cached_calls_per_iter=is_cached_calls_per_iter,
+    )
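
Read alongside the new test, a rough usage sketch of the two factory functions it exercises (behaviour inferred from the test itself, not from separate documentation):

```python
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (processor_cache_from_config,
                                   receiver_cache_from_config)
from vllm.multimodal.registry import MultiModalRegistry

config = VllmConfig(
    model_config=ModelConfig(mm_processor_cache_gb=4),
    # data_parallel_size=1 keeps the 1:1 API/engine layout, so IPC caching stays on
    parallel_config=ParallelConfig(data_parallel_size=1),
)
registry = MultiModalRegistry()

p0_cache = processor_cache_from_config(config, registry)  # API-process (P0) side
p1_cache = receiver_cache_from_config(config, registry)   # engine-core (P1) side

# With mm_processor_cache_gb=0 both factories return None, and the test simply
# passes items through uncached.
```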

vllm/config/__init__.py

Lines changed: 2 additions & 24 deletions

@@ -437,7 +437,7 @@ class ModelConfig:
     from `AutoProcessor.from_pretrained`. The available overrides depend on the
     model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
     """
-    mm_processor_cache_gb: int = 4
+    mm_processor_cache_gb: float = 4
     """The size (in GiB) of the multi-modal processor cache, which is used to
     avoid re-processing past multi-modal inputs.
 
@@ -884,12 +884,6 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
 
         return None
 
-    def set_mm_processor_cache_gb(self, value: int) -> None:
-        mm_config = self.get_multimodal_config()
-
-        self.mm_processor_cache_gb = value
-        mm_config.mm_processor_cache_gb = value
-
     def _get_encoder_config(self):
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)
@@ -1697,22 +1691,6 @@ def uses_mrope(self) -> bool:
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
 
-    @property
-    def enable_mm_processor_cache(self) -> bool:
-        """Whether the multi-modal processor cache should be enabled."""
-        mm_config = self.multimodal_config
-        if mm_config is None:
-            return False
-
-        return mm_config.mm_processor_cache_gb > 0
-
-    def get_mm_input_cache_gb(self) -> int:
-        mm_config = self.multimodal_config
-        if mm_config is None:
-            return 0
-
-        return envs.VLLM_MM_INPUT_CACHE_GIB
-
     @property
     def is_cross_encoder(self) -> bool:
         return (self._model_info.supports_cross_encoding
@@ -2561,7 +2539,7 @@ class MultiModalConfig:
     `{"num_crops": 4}`.
     """
 
-    mm_processor_cache_gb: int = 4
+    mm_processor_cache_gb: float = 4
     """
     The size (in GiB) of the multi-modal processor cache, which is used to
     avoid re-processing past multi-modal inputs.
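
Because the field is now a `float`, fractional cache sizes become expressible; a small sketch (values are illustrative):

```python
from vllm.config import ModelConfig

# A 0.5 GiB (512 MiB) multi-modal cache; the previous `int` annotation
# only admitted whole-GiB values.
small_cache = ModelConfig(mm_processor_cache_gb=0.5)

# 0 still disables both processor and IPC caching.
no_cache = ModelConfig(mm_processor_cache_gb=0)
```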

vllm/engine/arg_utils.py

Lines changed: 1 addition & 13 deletions

@@ -351,7 +351,7 @@ class EngineArgs:
     mm_processor_kwargs: Optional[Dict[str, Any]] = \
         MultiModalConfig.mm_processor_kwargs
     disable_mm_preprocessor_cache: bool = False  # DEPRECATED
-    mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
+    mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
     mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     # LoRA fields
@@ -1293,18 +1293,6 @@ def create_engine_config(
             worker_extension_cls=self.worker_extension_cls,
         )
 
-        if model_config.is_multimodal_model:
-            dp_supports_mm_processor_cache = (self.data_parallel_size == 1
-                                              or data_parallel_external_lb)
-            if (not dp_supports_mm_processor_cache
-                    and model_config.mm_processor_cache_gb > 0):
-                logger.warning(
-                    "Multi-modal processor cache is disabled because "
-                    "it is not compatible with data parallelism when "
-                    "there does not exist a one-to-one correspondance "
-                    "between API and engine core processes.")
-                model_config.set_mm_processor_cache_gb(0)
-
         speculative_config = self.create_speculative_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
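
With the removed guard gone, data parallelism no longer forces the processor cache off; a hedged sketch using the engine args shown in this file (the model choice and values are illustrative, and calling `create_engine_config()` with defaults is assumed to be sufficient here):

```python
from vllm.engine.arg_utils import EngineArgs

# DP > 1 now keeps processor caching; only IPC caching depends on a
# one-to-one API/engine core layout (see docs/configuration/optimization.md).
args = EngineArgs(
    model="Qwen/Qwen2.5-VL-3B-Instruct",  # illustrative model choice
    data_parallel_size=2,
    mm_processor_cache_gb=4,
)
vllm_config = args.create_engine_config()  # no longer emits the DP cache warning
```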
