diff --git a/csrc/ops.h b/csrc/ops.h
index 1dfd2e067e85..f471dfd80cc6 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -308,3 +308,5 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
     int64_t size);
 int64_t open_mem_handle(torch::Tensor& mem_handle);
 void free_shared_buffer(int64_t buffer);
+void store_tensor(torch::Tensor device_tensor, torch::Tensor host_tensor);
+void load_tensor(torch::Tensor host_tensor, torch::Tensor device_tensor);
diff --git a/csrc/tensor_store_load_mem.cu b/csrc/tensor_store_load_mem.cu
new file mode 100644
index 000000000000..7cdfd54cb3d1
--- /dev/null
+++ b/csrc/tensor_store_load_mem.cu
@@ -0,0 +1,100 @@
+#include <torch/all.h>
+#include <cuda_runtime.h>
+
+// Templated CUDA kernel: Copy from device memory to pinned host memory
+template <typename scalar_t>
+__global__ void store_kernel(const scalar_t* device_ptr, scalar_t* host_ptr, size_t num_elements) {
+  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < num_elements) {
+    host_ptr[idx] = device_ptr[idx];
+  }
+}
+
+// Templated CUDA kernel: Copy from pinned host memory to device memory
+template <typename scalar_t>
+__global__ void load_kernel(const scalar_t* host_ptr, scalar_t* device_ptr, size_t num_elements) {
+  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < num_elements) {
+    device_ptr[idx] = host_ptr[idx];
+  }
+}
+
+// Templated wrapper function: Store Tensor to pinned memory
+template <typename scalar_t>
+void store_tensor_impl(torch::Tensor device_tensor, torch::Tensor host_tensor) {
+  const auto num_elements = device_tensor.numel();
+  const int threads = 256;
+  const int blocks = (num_elements + threads - 1) / threads;
+
+  auto device_ptr = device_tensor.data_ptr<scalar_t>();
+  auto host_ptr = host_tensor.data_ptr<scalar_t>();
+
+  store_kernel<scalar_t><<<blocks, threads>>>(
+      device_ptr, host_ptr, num_elements);
+}
+
+// Templated wrapper function: Load Tensor from pinned memory
+template <typename scalar_t>
+void load_tensor_impl(torch::Tensor host_tensor, torch::Tensor device_tensor) {
+  const auto num_elements = host_tensor.numel();
+  const int threads = 256;
+  const int blocks = (num_elements + threads - 1) / threads;
+
+  auto host_ptr = host_tensor.data_ptr<scalar_t>();
+  auto device_ptr = device_tensor.data_ptr<scalar_t>();
+
+  load_kernel<scalar_t><<<blocks, threads>>>(
+      host_ptr, device_ptr, num_elements);
+}
+
+// Type-dispatched wrapper function
+void store_tensor(torch::Tensor device_tensor, torch::Tensor host_tensor) {
+  // Validate arguments
+  AT_ASSERT(device_tensor.is_cuda(), "Input tensor must be a CUDA tensor");
+  AT_ASSERT(host_tensor.is_pinned(), "Output tensor must be pinned memory");
+  AT_ASSERT(device_tensor.numel() == host_tensor.numel(), "Tensors must have same number of elements");
+  AT_ASSERT(device_tensor.dtype() == host_tensor.dtype(), "Tensors must have same dtype");
+
+  // Type-based dispatch to different implementations
+  switch (device_tensor.scalar_type()) {
+    case torch::kFloat:
+      store_tensor_impl<float>(device_tensor, host_tensor);
+      break;
+    case torch::kHalf:
+      store_tensor_impl<at::Half>(device_tensor, host_tensor);
+      break;
+    case torch::kBFloat16:
+      store_tensor_impl<at::BFloat16>(device_tensor, host_tensor);
+      break;
+    default:
+      AT_ERROR("Unsupported data type: ", device_tensor.scalar_type());
+  }
+}
+
+void load_tensor(torch::Tensor host_tensor, torch::Tensor device_tensor) {
+  // Validate arguments
+  AT_ASSERT(device_tensor.is_cuda(), "Output tensor must be a CUDA tensor");
+  AT_ASSERT(host_tensor.is_pinned(), "Input tensor must be pinned memory");
+  AT_ASSERT(device_tensor.numel() == host_tensor.numel(), "Tensors must have same number of elements");
+  AT_ASSERT(device_tensor.dtype() == host_tensor.dtype(), "Tensors must have same dtype");
+
+  // Type-based dispatch to different implementations
+  switch (host_tensor.scalar_type()) {
+    case torch::kFloat:
+      load_tensor_impl<float>(host_tensor, device_tensor);
+      break;
+    case torch::kHalf:
+      load_tensor_impl<at::Half>(host_tensor, device_tensor);
+      break;
+    case torch::kBFloat16:
+      load_tensor_impl<at::BFloat16>(host_tensor, device_tensor);
+      break;
+    default:
+      AT_ERROR("Unsupported data type: ", host_tensor.scalar_type());
+  }
+}
+
+// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+//   m.def("store_tensor", &store_tensor, "Store CUDA tensor to pinned memory (supports float32, float16, bfloat16)");
+//   m.def("load_tensor", &load_tensor, "Load CUDA tensor from pinned memory (supports float32, float16, bfloat16)");
+// }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 7ca40a5e7827..47d8b9752062 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -665,4 +665,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   custom_ar.def("free_shared_buffer", &free_shared_buffer);
 }
 
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _mem_pool), mem_pool) {
+  mem_pool.def("store_tensor", &store_tensor);
+  mem_pool.def("load_tensor", &load_tensor);
+}
+
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py b/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py
new file mode 100644
index 000000000000..448ba6064625
--- /dev/null
+++ b/examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import random
+import socket
+import threading
+import uuid
+
+import aiohttp
+import msgpack
+import zmq
+from quart import Quart, make_response, request
+
+count = 0
+prefill_instances: dict[str, str] = {}  # http_address: zmq_address
+decode_instances: dict[str, str] = {}  # http_address: zmq_address
+
+prefill_cv = threading.Condition()
+decode_cv = threading.Condition()
+
+
+def _listen_for_register(poller, router_socket):
+    while True:
+        socks = dict(poller.poll())
+        if router_socket in socks:
+            remote_address, message = router_socket.recv_multipart()
+            # data: {"type": "P", "http_address": "ip:port",
+            #        "zmq_address": "ip:port"}
+            data = msgpack.loads(message)
+            # print("Received message from %s, data: %s",
+            #       remote_address.decode(), data)
+            if data["type"] == "P":
+                global prefill_instances
+                global prefill_cv
+                with prefill_cv:
+                    prefill_instances[
+                        data["http_address"]] = data["zmq_address"]
+            elif data["type"] == "D":
+                global decode_instances
+                global decode_cv
+                with decode_cv:
+                    decode_instances[
+                        data["http_address"]] = data["zmq_address"]
+            else:
+                print("Unexpected, Received message from %s, data: %s",
+                      remote_address, data)
+
+
+def start_service_discovery(hostname, port):
+    if not hostname:
+        hostname = socket.gethostname()
+    if port == 0:
+        raise ValueError("Port cannot be 0")
+
+    context = zmq.Context()
+    router_socket = context.socket(zmq.ROUTER)
+    router_socket.bind(f"tcp://{hostname}:{port}")
+
+    poller = zmq.Poller()
+    poller.register(router_socket, zmq.POLLIN)
+
+    _listener_thread = threading.Thread(target=_listen_for_register,
+                                        args=[poller, router_socket],
+                                        daemon=True)
+    _listener_thread.start()
+    return _listener_thread
+
+
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+app = Quart(__name__)
+
+
+def random_uuid() -> str:
+    return str(uuid.uuid4().hex)
+
+
+async def forward_request(url, data, request_id):
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + global count + global prefill_instances + global prefill_cv + with prefill_cv: + # prefill_addr, prefill_zmq_addr = random.choice( + # list(prefill_instances.items())) + prefill_list = list(prefill_instances.items()) + prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] + + global decode_instances + global decode_cv + with decode_cv: + # decode_addr, decode_zmq_addr = random.choice( + # list(decode_instances.items())) + decode_list = list(decode_instances.items()) + decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] + + print(f"handle_request count: {count}, [HTTP:{prefill_addr}, " + f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, " + f"ZMQ:{decode_zmq_addr}]") + count += 1 + + request_id = ( + f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}" + ) + + # finish prefill + async for _ in forward_request(f'http://{prefill_addr}/v1/completions', + prefill_request, request_id): + continue + + # return decode + generator = forward_request(f'http://{decode_addr}/v1/completions', + original_request_data, request_id) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == '__main__': + t = start_service_discovery("0.0.0.0", 30001) + app.run(host='0.0.0.0', port=10001) + t.join() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0206d4552c8b..60aadcc1bd48 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1493,6 +1493,14 @@ def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) +def store_tensor(device_tensor: torch.Tensor, host_tensor: torch.Tensor): + torch.ops._C_mem_pool.store_tensor(device_tensor, host_tensor) + + +def load_tensor(host_tensor: torch.Tensor, device_tensor: torch.Tensor): + torch.ops._C_mem_pool.load_tensor(host_tensor, device_tensor) + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 4f04899e92e6..83619c27f22f 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -271,6 +271,27 @@ def ncclGetUniqueId(self) -> ncclUniqueId: ctypes.byref(unique_id))) return unique_id + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: + """ + Reconstructs an `ncclUniqueId` object from bytes data. + + Args: + data: Must be a 128-byte data block (matching NCCL's unique_id). + + Returns: + ncclUniqueId: The reconstructed NCCL Unique ID object. 
+ + Raises: + ValueError: If the input data length is not 128 bytes. + """ + if len(data) != 128: + raise ValueError( + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + + unique_id = ncclUniqueId() + ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) + return unique_id + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t: comm = ncclComm_t() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6532c101a4f6..33e80af7a7ce 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -49,11 +49,11 @@ def create_connector_v0(cls, rank: int, local_rank: int, return connector_cls(rank, local_rank, config) @classmethod - def create_connector_v1( - cls, - config: "VllmConfig", - role: KVConnectorRole, - ) -> KVConnectorBase_V1: + def create_connector_v1(cls, + config: "VllmConfig", + role: KVConnectorRole, + rank: int = 0, + local_rank: int = 0) -> KVConnectorBase_V1: if not envs.VLLM_USE_V1: raise ValueError("Attempting to initialize a V1 Connector, " f"but found {envs.VLLM_USE_V1=}") @@ -70,12 +70,13 @@ def create_connector_v1( # - Co-locate with worker process # - Should only be used inside the forward context & attention layer # We build separately to enforce strict separation - return connector_cls(config, role) + return connector_cls(config, role, rank, local_rank) # Register various connectors here. # The registration should not be done in each individual file, as we want to # only load the files corresponding to the current connector. + KVConnectorFactory.register_connector( "PyNcclConnector", "vllm.distributed.kv_transfer.kv_connector.simple_connector", @@ -96,11 +97,20 @@ def create_connector_v1( "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", "MooncakeStoreConnector") +KVConnectorFactory.register_connector( + "P2pConnector", "vllm.distributed.kv_transfer.kv_connector.p2p_connector", + "P2pConnector") + KVConnectorFactory.register_connector( "SharedStorageConnector", "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", "SharedStorageConnector") +KVConnectorFactory.register_connector( + "P2pNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.p2p_nccl_connector", + "P2pNcclConnector") + KVConnectorFactory.register_connector( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", diff --git a/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py b/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py new file mode 100644 index 000000000000..e59d3227f8ea --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) +from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class P2pConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, 
+ ): + self.rank = rank + self.config = config.kv_transfer_config + self.kv_helper = kv_helper(config) + + assert self.config.kv_connector == "P2pConnector" + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.p2p_nccl_pipe = P2pNcclPipe( + local_rank=local_rank, + config=self.config, + hostname="", + port_offset=rank, + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + # input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + num_heads, head_size = self.kv_helper.get_model_args(model_executable) + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + # current_tokens = input_tokens_tensor[start_pos:end_pos] + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + kvcache = torch.stack((keys, values), dim=0) + + request_id = request_ids[idx] + ip, port = self.parse_request_id(request_id, True) + remote_address = ip + ":" + str(port + self.rank) + + self.p2p_nccl_pipe.send_tensor(request_id + "kv", kvcache, + remote_address) + self.p2p_nccl_pipe.send_tensor( + request_id + "hidden", + hidden_or_intermediate_states[start_pos:end_pos], + remote_address) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + bypass_model_exec = True + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + hidden_or_intermediate_states_for_one_req = [] + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. 
+ logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + request_id = request_ids[idx] + ip, port = self.parse_request_id(request_id, False) + remote_address = ip + ":" + str(port + self.rank) + + kvcache = self.p2p_nccl_pipe.recv_tensor(request_id + "kv", + remote_address) + hidden = self.p2p_nccl_pipe.recv_tensor(request_id + "hidden", + remote_address) + + if kvcache is None or hidden is None: + # didn't find any match. + bypass_model_exec = False + continue + + num_computed_tokens = current_tokens.shape[0] + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # call self.kv_store to get kv layer by layer + for layer_id in range(start_layer, end_layer): + layer = model_executable.model.layers[layer_id] + # get kvcache object + kv_cache = kv_caches[layer_id - start_layer] + + # get remote kvcache + remote_k, remote_v = kvcache[0][layer_id], kvcache[1][layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> Tuple[str, int]: + logger.debug("parse_request_id, request_id: %s, is_prefill: %s", + request_id, is_prefill) + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + logger.debug("parse_request_id, request_id: %s, ip: %s, port: %s", + request_id, ip, str(port)) + return ip, port + raise ValueError( + f"Request id {request_id} does not contain hostname and port") + + def close(self) -> None: + self.p2p_nccl_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca919..1835a1bf1078 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -54,13 +54,19 @@ class KVConnectorMetadata: class KVConnectorBase_V1(ABC): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__(self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + rank: int = 0, + local_rank: int = 0): logger.warning( "Initializing KVConnectorBase_V1. 
This API is experimental and " "subject to change in the future as we iterate the design.") self._connector_metadata = KVConnectorMetadata() self._vllm_config = vllm_config self._role = role + self._rank = rank + self._local_rank = local_rank @property def role(self) -> KVConnectorRole: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p_nccl_connector.py new file mode 100644 index 000000000000..7fc7fd372122 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p_nccl_connector.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from dataclasses import dataclass +from typing import TYPE_CHECKING, Tuple + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request Id + request_id: str + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + + @staticmethod + def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], + block_size: int) -> "ReqMeta": + valid_num_tokens = len(token_ids) + token_ids_tensor = torch.tensor(token_ids) + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + request_id=request_id, + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + ) + + +@dataclass +class P2pNcclConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + request_id: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + ) -> None: + self.requests.append( + ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)) + + +class P2pNcclConnector(KVConnectorBase_V1): + + def __init__(self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + rank: int = 0, + local_rank: int = 0): + super().__init__(vllm_config=vllm_config, + role=role, + rank=rank, + local_rank=local_rank) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + self.config = vllm_config.kv_transfer_config + self.rank = rank + self.is_producer = self.config.is_kv_producer + + self.p2p_nccl_pipe = P2pNcclPipe( + local_rank=local_rank, + config=self.config, + hostname="", + port_offset=rank, + ) if role == KVConnectorRole.WORKER else None + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. 
+ **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + assert self.p2p_nccl_pipe is not None + + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + num_token = src_kv_cache.shape[0] + if len(slot_mapping) == num_token: + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + else: + dst_kv_cache_layer[slot_mapping[:num_token], ...] = src_kv_cache + logger.warning("🚧src_kv_cache does not match, num_slot:%d, num_token:%d", len(slot_mapping), num_token) + + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + num_token = src_kv_cache.shape[1] + if len(slot_mapping) == num_token: + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + else: + dst_kv_cache_layer[:, slot_mapping[:num_token],...] = src_kv_cache + logger.warning("🚧src_kv_cache does not match, num_slot:%d, num_token:%d", len(slot_mapping), num_token) + + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = \ + self._get_connector_metadata() + assert isinstance(metadata, P2pNcclConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if self.is_producer: + continue + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache_layer = attn_layer.kv_cache[ \ + forward_context.virtual_engine] + + kv_cache = self.p2p_nccl_pipe.recv_tensor(request.request_id + + "-" + layer_name) + + if kv_cache is None: + logger.warning("🚧src_kv_cache is None, %s", + request.request_id) + continue + + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping) + + logger.info("Inject KV cache of %d tokens to the paged memory, %s", + len(request.slot_mapping), request.request_id) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. 
+ + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + assert self.p2p_nccl_pipe is not None + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, P2pNcclConnectorMetadata) + for request in connector_metadata.requests: + if self.is_producer: + request_id = request.request_id + ip, port = self.parse_request_id(request_id, True) + remote_address = ip + ":" + str(port + self._rank) + kv_cache = extract_kv_from_layer(kv_layer, + request.slot_mapping) + self.p2p_nccl_pipe.send_tensor(request_id + "-" + layer_name, + kv_cache, remote_address) + + def wait_for_save(self): + if self.is_producer: + assert self.p2p_nccl_pipe is not None + self.p2p_nccl_pipe.wait_for_sent() + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + if self.is_producer: + return 0 + + num_external_tokens = (len(request.prompt_token_ids) - 1 - + num_computed_tokens) + logger.info( + "🍒num_external_tokens:%d, num_prompt_tokens:%d, " + "num_computed_tokens:%d", num_external_tokens, + len(request.prompt_token_ids), num_computed_tokens) + + return num_external_tokens + + def update_state_after_alloc(self, request: "Request", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if not self.is_producer and num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + meta = P2pNcclConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request(request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size) + total_need_load += 1 + else: + if self.is_producer: + meta.add_request(request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids, + block_size=self._block_size) + + for cached_req in scheduler_output.scheduled_cached_reqs: + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not cached_req.resumed_from_preemption: + break + if cached_req.req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[cached_req.req_id] + total_tokens = (len(cached_req.new_token_ids) + + cached_req.num_computed_tokens) + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = cached_req.new_block_ids + + meta.add_request(request_id=cached_req.req_id, + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> Tuple[str, int]: + logger.debug("parse_request_id, request_id: %s, is_prefill: %s", + request_id, is_prefill) + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + logger.debug("parse_request_id, request_id: %s, ip: %s, port: %s", + request_id, ip, str(port)) + return ip, port + raise ValueError( + f"Request id {request_id} does not contain hostname and port") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index f91ffbc720e7..6d3f284fee72 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -73,7 +73,11 @@ class SharedStorageConnector(KVConnectorBase_V1): # It does extra work which will overwrite the existing prefix-cache in GPU # - to remove the overhead, need to add some "mask" in the ReqMeta class - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__(self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + rank: int = 0, + local_rank: int = 0): super().__init__(vllm_config=vllm_config, role=role) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Request] = {} diff --git a/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py new file mode 100644 index 000000000000..86d94dbf43e0 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py @@ -0,0 +1,470 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import threading 
+import time +import typing +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +import msgpack +import torch +import zmq + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) +from vllm.distributed.kv_transfer.tensor_memory_pool import ( + TensorMemoryPool) +from vllm.utils import current_stream, get_ip + +logger = logging.getLogger(__name__) + + +class P2pNcclPipe: + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None) -> None: + self.config = config + self.rank = port_offset + self.local_rank = local_rank + self.device = torch.device(f"cuda:{self.local_rank}") + self.nccl = NCCLLibrary(library_path) + + if not hostname: + hostname = get_ip() + port = self.config.kv_port + port_offset + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + + # Each card corresponds to a ZMQ address. + self.zmq_address = f"{self._hostname}:{self._port}" + + # The `http_port` must be consistent with the port of OpenAI. + self.http_address = ( + f"{self._hostname}:" + f"{self.config.kv_connector_extra_config['http_port']}") + + # If `proxy_ip` or `proxy_port` is `""`, + # then the ping thread will not be enabled. + proxy_ip = self.config.get_from_extra_config("proxy_ip", "") + proxy_port = self.config.get_from_extra_config("proxy_port", "") + if proxy_ip == "" or proxy_port == "": + self.proxy_address = "" + else: + self.proxy_address = proxy_ip + ":" + proxy_port + + self.context = zmq.Context() + self.router_socket = self.context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://{self.zmq_address}") + + self.poller = zmq.Poller() + self.poller.register(self.router_socket, zmq.POLLIN) + + self.send_store_cv = threading.Condition() + self.send_queue_cv = threading.Condition() + self.recv_store_cv = threading.Condition() + + self.send_stream = torch.cuda.Stream() + self.recv_stream = torch.cuda.Stream() + + mem_pool_size = self.config.get_from_extra_config("mem_pool_size", 128) + self.pool = TensorMemoryPool(max_block_size=mem_pool_size * 1024**3) # GB + + # The sending type includes tree mutually exclusive options: + # PUT, GET, PUT_ASYNC. 
+ self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + self.send_store: Dict[str, + torch.Tensor] = {} # tensor_id: torch.Tensor + else: + # PUT or PUT_ASYNC + self.send_queue: Deque[ + List[Any]] = deque() # tensor_id: torch.Tensor + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + # tensor_id: torch.Tensor/(addr, dtype, shape) + self.recv_store: Dict[str, Any] = {} + self.socks: Dict[str, Any] = {} # remote_address: client socket + self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = self.config.kv_buffer_size + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.nccl.ncclGetUniqueId() + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.info( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + 
remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + self.recv_store[tensor_id] = None + while len(self.recv_store) > 10000: + self.recv_store.pop(next(iter(self.recv_store))) + + if tensor is not None: + if isinstance(tensor, tuple): + addr, dtype, shape = tensor + tensor = self.pool.load_tensor(addr, dtype, shape, + self.device) + else: + addr = 0 + self.buffer_size -= (tensor.element_size() * + tensor.numel()) + duration = time.time() - start_time + logger.info( + "🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, " + "duration:%.3fms, size:%.3fGB, addr:%d, rank:%d", + remote_address, tensor_id, tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, + addr, self.rank) + else: + duration = time.time() - start_time + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + start_time = time.time() + self._recv(comm, tensor, rank ^ 1, self.recv_stream) + duration = time.time() - start_time + logger.info( + "🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, " + "size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape, + duration * 1000, + tensor.element_size() * tensor.numel() / 1024**3, self.rank) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + logger.debug("Received message from %s, data:%s", + remote_address.decode(), data) + if data["cmd"] == "NEW": + unique_id = self.nccl.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1, self.recv_stream) + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + # Store Tensor in memory pool + addr = self.pool.store_tensor(tensor) + tensor = (addr, tensor.dtype, tensor.shape) + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s, addr:%d", self.zmq_address, + remote_address.decode(), data, addr) + else: + 
self.buffer_size += tensor_size + logger.info( + "🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, " + "data:%s, shape:%s", self.zmq_address, + remote_address.decode(), rank, data, + tensor.shape) + + except torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + comm, rank = self.comms[remote_address.decode()] + self._send(comm, tensor.to(self.device), rank ^ 1, + self.send_stream) + + logger.info( + "🔵[GET]Send Tensor, %s👉%s, " + "MyRank:%s, data:%s", self.zmq_address, + remote_address.decode(), rank, data) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.info( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + # with self.send_queue_cv: + # self.send_queue.append([tensor_id, remote_address, tensor]) + # self.send_queue_cv.notify() + logger.warning( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) + logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s", + self.zmq_address, remote_address, rank, data, tensor.shape) + return True + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + 
sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + start_time = time.time() + with torch.cuda.stream(stream): + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + comm, cudaStream_t(stream.cuda_stream)) + stream.synchronize() + + duration = time.time() - start_time + logger.info( + "🕐Nccl Send Tensor, shape:%s, duration:%.3fms, size:%.3fGB, " + "rank:%d", tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024 ** 3, dst) + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + start_time = time.time() + with torch.cuda.stream(stream): + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + comm, cudaStream_t(stream.cuda_stream)) + stream.synchronize() + duration = time.time() - start_time + logger.info( + "🕐Nccl Recv Tensor, shape:%s, duration:%.3fms, size:%.3fGB, " + "rank:%d", tensor.shape, duration * 1000, + tensor.element_size() * tensor.numel() / 1024 ** 3, src) + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 25d2f2cf5c6e..96fb4bcf814b 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -61,7 +61,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: and _KV_CONNECTOR_AGENT is None): if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( - config=vllm_config, role=KVConnectorRole.WORKER) + config=vllm_config, + role=KVConnectorRole.WORKER, + rank=get_world_group().rank, + local_rank=get_world_group().local_rank) else: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( rank=get_world_group().rank, diff --git a/vllm/distributed/kv_transfer/tensor_memory_pool.py b/vllm/distributed/kv_transfer/tensor_memory_pool.py new file mode 100644 index 000000000000..685d36f19e99 --- /dev/null +++ b/vllm/distributed/kv_transfer/tensor_memory_pool.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 + +import atexit +import ctypes +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from vllm import _custom_ops as ops + + +@dataclass +class MemoryBlock: + size: int + addr: int + + +class TensorMemoryPool: + def __init__(self, max_block_size: int, min_block_size: int = 512): + if max_block_size <= 0 or min_block_size <= 0: + raise ValueError("Block sizes must be positive") + if max_block_size < min_block_size: + raise ValueError("Max block size must be greater than min block size") + + self.max_block_size = self._round_to_power_of_two(max_block_size) + self.min_block_size = self._round_to_power_of_two(min_block_size) + + self.free_lists: Dict[int, Dict[int, MemoryBlock]] = {} + self.allocated_blocks: Dict[int, 
MemoryBlock] = {} + + self._initialize_free_lists() + self._allocate_pinned_memory() + + atexit.register(self.cleanup) + + self.store_stream = torch.cuda.Stream() + self.load_stream = torch.cuda.Stream() + + def _round_to_power_of_two(self, size: int) -> int: + return 1 << (size - 1).bit_length() + + def _initialize_free_lists(self): + size = self.max_block_size + while size >= self.min_block_size: + self.free_lists[size] = {} + size //= 2 + + def _allocate_pinned_memory(self): + self.base_tensor = torch.empty(self.max_block_size // 4, dtype=torch.float32, pin_memory=True) + self.base_address = self.base_tensor.data_ptr() + initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address) + self.free_lists[self.max_block_size][initial_block.addr] = initial_block + print("TensorMemoryPool, base_address:", self.base_address, self.base_address % self.max_block_size) + + def allocate(self, size: int) -> int: + if size <= 0: + raise ValueError("Allocation size must be positive") + + required_size = self._round_to_power_of_two(max(size, self.min_block_size)) + if required_size > self.max_block_size: + raise MemoryError("Requested size exceeds maximum block size") + + current_size = required_size + while current_size <= self.max_block_size: + if self.free_lists[current_size]: + _, block = self.free_lists[current_size].popitem() + self._split_block(block, required_size) + self.allocated_blocks[block.addr] = block + return block.addr + current_size *= 2 + + raise MemoryError("Insufficient memory") + + def _split_block(self, block: MemoryBlock, required_size: int): + while block.size > required_size and block.size // 2 >= self.min_block_size: + buddy_size = block.size // 2 + buddy_addr = block.addr + buddy_size + + buddy = MemoryBlock(size=buddy_size, addr=buddy_addr) + block.size = buddy_size + + self.free_lists[buddy_size][buddy.addr] = buddy + + def free(self, addr: int): + if addr not in self.allocated_blocks: + raise ValueError("Invalid address to free") + + block = self.allocated_blocks.pop(addr) + self._merge_buddies(block) + + def _merge_buddies(self, block: MemoryBlock): + MAX_MERGE_DEPTH = 30 + depth = 0 + + while depth < MAX_MERGE_DEPTH: + buddy_offset = block.size if (block.addr - self.base_address) % (2 * block.size) == 0 else -block.size + buddy_addr = block.addr + buddy_offset + buddy = self.free_lists[block.size].get(buddy_addr) + if buddy: + del self.free_lists[buddy.size][buddy.addr] + merged_addr = min(block.addr, buddy.addr) + merged_size = block.size * 2 + block = MemoryBlock(size=merged_size, addr=merged_addr) + depth += 1 + else: + break + self.free_lists[block.size][block.addr] = block + + def store_tensor(self, tensor: torch.Tensor) -> int: + if not tensor.is_cuda: + raise ValueError("Only CUDA tensors can be stored") + + size = tensor.element_size() * tensor.numel() + addr = self.allocate(size) + block = self.allocated_blocks[addr] + + if block.size < size: + self.free(addr) + raise MemoryError(f"Allocated block size {block.size} is smaller than required size {size}") + + try: + buffer = (ctypes.c_byte * block.size).from_address(block.addr) + cpu_tensor = torch.frombuffer(buffer, dtype=tensor.dtype, count=tensor.numel()) + except ValueError as e: + self.free(addr) + raise MemoryError(f"Failed to create tensor view: {e}") + + with torch.cuda.stream(self.store_stream): + ops.store_tensor(tensor, cpu_tensor) + self.store_stream.synchronize() + + return addr + + def load_tensor(self, addr: int, dtype: torch.dtype, shape: Tuple[int, ...], device) -> torch.Tensor: + 
if addr not in self.allocated_blocks: + raise ValueError("Invalid address to load") + + block = self.allocated_blocks[addr] + num_elements = math.prod(shape) + dtype_size = torch.tensor([], dtype=dtype).element_size() + required_size = num_elements * dtype_size + + if required_size > block.size: + raise ValueError("Requested tensor size exceeds block size") + + buffer = (ctypes.c_byte * block.size).from_address(block.addr) + cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements) + + cuda_tensor = torch.empty(shape, dtype=dtype, device=device) + + with torch.cuda.stream(self.load_stream): + ops.load_tensor(cpu_tensor, cuda_tensor) + self.load_stream.synchronize() + + self.free(addr) + + return cuda_tensor + + def cleanup(self): + self.free_lists.clear() + self.allocated_blocks.clear() + if hasattr(self, 'base_tensor'): + del self.base_tensor + + def __del__(self): + self.cleanup()
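A minimal usage sketch of the TensorMemoryPool introduced in this patch, assuming the store_tensor/load_tensor CUDA ops above have been built into vllm._custom_ops; the pool size and tensor shape here are illustrative only (the connector sizes the pool from the "mem_pool_size" extra config):

import torch

from vllm.distributed.kv_transfer.tensor_memory_pool import TensorMemoryPool

# Pinned-host buddy-allocator pool (1 GiB here; illustrative only).
pool = TensorMemoryPool(max_block_size=1 << 30)

# Stage a GPU tensor (e.g. a received KV block that exceeds the GPU buffer
# threshold) into pinned host memory via the custom store_tensor op.
kv_block = torch.randn(2, 16, 128, dtype=torch.float16, device="cuda")
addr = pool.store_tensor(kv_block)

# Later, copy it back onto the GPU; load_tensor frees the host block after
# the copy, so each stored address can be consumed exactly once.
restored = pool.load_tensor(addr, kv_block.dtype, kv_block.shape,
                            torch.device("cuda"))
assert torch.equal(kv_block, restored)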