6
6
7
7
import vllm .envs as envs
8
8
from vllm .attention import get_attn_backend
9
- from vllm .config import (CacheConfig , CompilationConfig , DeviceConfig ,
10
- ModelConfig , ParallelConfig , VllmConfig )
9
+ from vllm .config import (CacheConfig , DeviceConfig , ModelConfig ,
10
+ ParallelConfig , VllmConfig )
11
11
from vllm .distributed import (ensure_model_parallel_initialized ,
12
12
init_distributed_environment )
13
13
from vllm .logger import init_logger
@@ -33,8 +33,8 @@ class CPUCacheEngine:
33
33
"""
34
34
35
35
def __init__ (self , cache_config : CacheConfig , model_config : ModelConfig ,
36
- parallel_config : ParallelConfig , device_config : DeviceConfig ,
37
- compilation_config : CompilationConfig ) -> None :
36
+ parallel_config : ParallelConfig ,
37
+ device_config : DeviceConfig ) -> None :
38
38
assert device_config .device_type == "cpu"
39
39
self .cache_config = cache_config
40
40
self .model_config = model_config
@@ -66,8 +66,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
66
66
67
67
# Initialize the cache.
68
68
self .cpu_cache = self ._allocate_kv_cache (self .num_cpu_blocks )
69
- bind_kv_cache (compilation_config .static_forward_context ,
70
- self .cpu_cache )
71
69
72
70
def _allocate_kv_cache (
73
71
self ,
@@ -292,13 +290,15 @@ def _init_cache_engine(self) -> None:
292
290
self .model_config ,
293
291
self .parallel_config ,
294
292
self .device_config ,
295
- self .compilation_config ,
296
293
) for _ in range (self .parallel_config .pipeline_parallel_size )
297
294
]
298
295
self .cpu_cache = [
299
296
self .cache_engine [ve ].cpu_cache
300
297
for ve in range (self .parallel_config .pipeline_parallel_size )
301
298
]
299
+ for ve in range (self .parallel_config .pipeline_parallel_size ):
300
+ bind_kv_cache (self .compilation_config .static_forward_context ,
301
+ self .cpu_cache [ve ], ve )
302
302
self .model_runner .block_size = self .cache_engine [0 ].block_size
303
303
304
304
assert all (
0 commit comments