|
2 | 2 |
|
3 | 3 | import json
|
4 | 4 | import os
|
5 |
| -import pickle |
6 | 5 | from concurrent.futures import ThreadPoolExecutor
|
7 | 6 | from dataclasses import dataclass
|
8 | 7 | from typing import Optional, Union
|
9 | 8 |
|
10 | 9 | import torch
|
11 | 10 | import zmq
|
| 11 | +from safetensors.torch import load as safetensors_load |
| 12 | +from safetensors.torch import save as safetensors_save |
12 | 13 |
|
13 | 14 | from vllm.config import KVTransferConfig
|
14 | 15 | from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
@@ -237,14 +238,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int:
|
237 | 238 | return hash(tensor.data_ptr())
|
238 | 239 |
|
239 | 240 | def _send_impl(self, tensor: torch.Tensor) -> None:
|
240 |
| - """Implement the tensor sending logic.""" |
241 |
| - value_bytes = pickle.dumps(tensor) |
242 |
| - self.transfer_engine.send_bytes(value_bytes) |
| 241 | + """Implement the tensor sending logic using safetensors.""" |
| 242 | + self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) |
243 | 243 |
|
244 | 244 | def _recv_impl(self) -> torch.Tensor:
|
245 |
| - """Implement the tensor receiving logic.""" |
| 245 | + """Implement the tensor receiving logic using safetensors.""" |
246 | 246 | data = self.transfer_engine.recv_bytes()
|
247 |
| - return pickle.loads(data) |
| 247 | + return safetensors_load(data)["tensor"].to(self.device) |
248 | 248 |
|
249 | 249 | def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
250 | 250 | """Send tensor to the target process."""
|
|
0 commit comments