Skip to content

Commit 288ca11

Browse files
authored
[Security] Serialize using safetensors instead of pickle in Mooncake Pipe (#14228)
Signed-off-by: KuntaiDu <[email protected]>
1 parent c2bd219 commit 288ca11

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
import json
44
import os
5-
import pickle
65
from concurrent.futures import ThreadPoolExecutor
76
from dataclasses import dataclass
87
from typing import Optional, Union
98

109
import torch
1110
import zmq
11+
from safetensors.torch import load as safetensors_load
12+
from safetensors.torch import save as safetensors_save
1213

1314
from vllm.config import KVTransferConfig
1415
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
@@ -237,14 +238,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int:
237238
return hash(tensor.data_ptr())
238239

239240
def _send_impl(self, tensor: torch.Tensor) -> None:
240-
"""Implement the tensor sending logic."""
241-
value_bytes = pickle.dumps(tensor)
242-
self.transfer_engine.send_bytes(value_bytes)
241+
"""Implement the tensor sending logic using safetensors."""
242+
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
243243

244244
def _recv_impl(self) -> torch.Tensor:
245-
"""Implement the tensor receiving logic."""
245+
"""Implement the tensor receiving logic using safetensors."""
246246
data = self.transfer_engine.recv_bytes()
247-
return pickle.loads(data)
247+
return safetensors_load(data)["tensor"].to(self.device)
248248

249249
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
250250
"""Send tensor to the target process."""

0 commit comments

Comments
 (0)