 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import numpy as np
 import pytest
 import torch
 
-from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
+from vllm.config import ModelConfig, ParallelConfig, VllmConfig
+from vllm.multimodal.cache import (MultiModalCache,
+                                   MultiModalProcessorCacheItem,
+                                   MultiModalProcessorCacheItemMetadata,
+                                   processor_cache_from_config,
+                                   receiver_cache_from_config)
+from vllm.multimodal.hasher import MultiModalHasher
 from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
                                     MultiModalKwargsItems,
                                     MultiModalSharedField)
+from vllm.multimodal.processing import PromptInsertion
+from vllm.multimodal.registry import MultiModalRegistry
+

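+# Helpers that build dummy multimodal kwargs items. When an RNG is supplied,
+# the tensor data is randomized so that distinct items produce distinct hashes.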
+def _dummy_elem(
+    modality: str,
+    key: str,
+    size: int,
+    *,
+    rng: Optional[np.random.RandomState] = None,
+):
+    if rng is None:
+        data = torch.empty((size, ), dtype=torch.int8)
+    else:
+        data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
 
-def _dummy_elem(modality: str, key: str, size: int):
     return MultiModalFieldElem(
         modality=modality,
         key=key,
-        data=torch.empty((size, ), dtype=torch.int8),
+        data=data,
         field=MultiModalSharedField(1),
     )
 
 
-def _dummy_item(modality: str, size_by_key: dict[str, int]):
+def _dummy_item(
+    modality: str,
+    size_by_key: dict[str, int],
+    *,
+    rng: Optional[np.random.RandomState] = None,
+):
     return MultiModalKwargsItem.from_elems([
-        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
+        _dummy_elem(modality, key, size, rng=rng)
+        for key, size in size_by_key.items()
     ])
 
 
-def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
+def _dummy_items(
+    size_by_key_modality: dict[str, dict[str, int]],
+    *,
+    rng: Optional[np.random.RandomState] = None,
+):
     return MultiModalKwargsItems.from_seq([
-        _dummy_item(modality, size_by_key)
+        _dummy_item(modality, size_by_key, rng=rng)
         for modality, size_by_key in size_by_key_modality.items()
     ])
 
@@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size):
     cache[""] = item
     assert cache.currsize == expected_size
 
-    cache[""] = MultiModalCacheItemMetadata.wraps(item)
+    prompt_update = PromptInsertion("dummy", "target", "insertion") \
+        .resolve(0)
+
+    cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
+    assert cache.currsize == expected_size
+
+    cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
     assert cache.currsize == expected_size
+
+
+def _create_vllm_config(
+    *,
+    mm_processor_cache_gb: float,
+    enable_ipc: bool,
+):
+    return VllmConfig(
+        model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb),
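+        # `enable_ipc` is emulated via data parallelism: a single DP rank keeps
+        # the IPC cache path, while dp > 1 turns it off for this test.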
+        parallel_config=ParallelConfig(
+            data_parallel_size=1 if enable_ipc else 2),
+    )
+
+
+def _compare_caches(
+    config_0: VllmConfig,
+    config_1: VllmConfig,
+    *,
+    item_capacity: int = 8,
+    hit_rate: float = 0.5,
+    max_items_per_iter: int = 3,
+    is_cached_calls_per_iter: int,
+    n_iter: int = 100,
+    seed: int = 0,
+):
+    mm_registry = MultiModalRegistry()
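+    # For each config, build the processor-side (p0) cache and the
+    # receiver-side (p1) cache; either may be None when that config
+    # disables the corresponding cache.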
+    cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
+    cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
+    cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
+    cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
+
+    cache_size_gb = max(
+        config_0.model_config.mm_processor_cache_gb,
+        config_1.model_config.mm_processor_cache_gb,
+    )
+    item_size_gb = int(cache_size_gb / item_capacity)
+
+    rng = np.random.RandomState(seed)
+    all_items = [
+        _dummy_item("item", {"key": item_size_gb}, rng=rng)
+        for _ in range(int(item_capacity / hit_rate))
+    ]
+    all_hashes = [
+        MultiModalHasher.hash_kwargs(item=item.get_data())
+        for item in all_items
+    ]
+
+    # Should not be used since there is nothing to convert to text
+    prompt_update = PromptInsertion("dummy", "target", "insertion")
+
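+    # Feed random subsets of the items through both cache pipelines; after the
+    # receiver-side stage, both must return identical items regardless of how
+    # (or whether) caching is configured.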
+    for it in range(n_iter):
+        num_items_to_select = rng.randint(0, max_items_per_iter)
+        item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)
+
+        selected_items = [all_items[idx] for idx in item_idxs_to_select]
+        selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]
+
+        if cache_0_p0 is None:
+            cache_0_p0_out = selected_items
+        else:
+            for _ in range(is_cached_calls_per_iter):
+                cache_0_p0.is_cached(selected_hashes)
+            cache_0_p0_out = [
+                item for item, _ in cache_0_p0.get_and_update(
+                    [(item, prompt_update.content) for item in selected_items],
+                    selected_hashes,
+                )
+            ]
+
+        if cache_1_p0 is None:
+            cache_1_p0_out = selected_items
+        else:
+            for _ in range(is_cached_calls_per_iter):
+                cache_1_p0.is_cached(selected_hashes)
+            cache_1_p0_out = [
+                item for item, _ in cache_1_p0.get_and_update(
+                    [(item, prompt_update.content) for item in selected_items],
+                    selected_hashes,
+                )
+            ]
+
+        if cache_0_p1 is None:
+            cache_0_p1_out = cache_0_p0_out
+        else:
+            cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out,
+                                                       selected_hashes)
+
+        if cache_1_p1 is None:
+            cache_1_p1_out = cache_1_p0_out
+        else:
+            cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out,
+                                                       selected_hashes)
+
+        assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
+
+
+@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
+def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
+    cache_size_gb = 1 / (1 << 20)
+
+    vllm_config_ipc_enabled = _create_vllm_config(
+        mm_processor_cache_gb=cache_size_gb,
+        enable_ipc=True,
+    )
+    vllm_config_ipc_disabled = _create_vllm_config(
+        mm_processor_cache_gb=cache_size_gb,
+        enable_ipc=False,
+    )
+    vllm_config_cache_disabled = _create_vllm_config(
+        mm_processor_cache_gb=0,
+        enable_ipc=True,
+    )
+
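+    # Every pairing of the three configurations must produce identical outputs.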
+    _compare_caches(
+        vllm_config_ipc_enabled,
+        vllm_config_ipc_disabled,
+        is_cached_calls_per_iter=is_cached_calls_per_iter,
+    )
+    _compare_caches(
+        vllm_config_ipc_disabled,
+        vllm_config_cache_disabled,
+        is_cached_calls_per_iter=is_cached_calls_per_iter,
+    )
+    _compare_caches(
+        vllm_config_cache_disabled,
+        vllm_config_ipc_enabled,
+        is_cached_calls_per_iter=is_cached_calls_per_iter,
+    )