
Commit 3c18115

Optimize vLLM weight reloading using collective_rpc
Use vLLM's collective_rpc API to reload weights without recreating the entire engine. This provides significant performance improvements:

- Weight reload: ~0.7-0.9s (vs ~7-10s for full engine recreation)
- Preserves KV cache, kernels, and memory allocations
- Reduces memory fragmentation

Changes:

- Update VLLMRolloutEngine.update_weights() to use collective_rpc("reload_weights") instead of recreating the engine

The reload mechanism saves updated weights to disk, then calls reload_weights() on all workers via RPC, maintaining bitwise determinism while avoiding expensive engine recreation.

Note: Requires VLLM_ALLOW_INSECURE_SERIALIZATION=1 environment variable for collective_rpc with custom functions.
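A minimal sketch of the mechanism described above (save weights to disk, then fan out reload_weights() via RPC). It assumes a VLLMRolloutEngine that owns a vllm.LLM handle as self.llm and stages weights in self.temp_model_dir; _save_weights_to_disk is a hypothetical helper standing in for the actual save step.

```python
import os

# Per the note above: needed for collective_rpc with custom functions,
# and must be set before the vLLM workers are spawned.
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

class VLLMRolloutEngine:
    def update_weights(self, vllm_compat_state: dict) -> None:
        # 1. Persist the updated weights where every worker can read them.
        self._save_weights_to_disk(vllm_compat_state)  # hypothetical helper
        # 2. Ask each worker to reload from disk. The engine, KV cache,
        #    kernels, and CUDA allocations stay alive, so this takes
        #    ~0.7-0.9s instead of the ~7-10s full engine recreation.
        self.llm.collective_rpc("reload_weights")
```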
1 parent 62b20b6 commit 3c18115

1 file changed (+4, -14 lines)

torchtitan/experiments/deterministic_vllm_rl/simple_rl.py

Lines changed: 4 additions & 14 deletions
```diff
@@ -145,21 +145,11 @@ def update_weights(self, vllm_compat_state: dict) -> None:
                 seed=42,  # Fixed seed for determinism
                 enforce_eager=True,
             )
+            print("✓ Created new vLLM engine")
         else:
-            # vLLM V1's reload_weights() is broken - it doesn't actually reload from disk
-            # The only reliable way is to recreate the engine
-            del self.llm
-            torch.cuda.empty_cache()
-
-            self.llm = LLM(
-                model=self.temp_model_dir,
-                trust_remote_code=True,
-                max_model_len=2048,
-                dtype="bfloat16",
-                gpu_memory_utilization=0.3,
-                seed=42,  # Fixed seed for determinism
-                enforce_eager=True,
-            )
+            # Use collective_rpc to call reload_weights on all workers
+            # This reloads weights from temp_model_dir without recreating the engine
+            self.llm.collective_rpc("reload_weights")
 
     @torch.no_grad()
     def generate(
```
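For context, a sketch of where the cheaper reload pays off in an RL loop; train_step, prompts, num_steps, and model_path are illustrative placeholders, not names taken from simple_rl.py.

```python
# Illustrative RL loop: the engine is built once, while weights change
# every iteration, so update_weights() is on the hot path.
engine = VLLMRolloutEngine(model_path)   # one-time engine build (~7-10s)
for step in range(num_steps):
    rollouts = engine.generate(prompts)  # sample with the current weights
    new_state = train_step(rollouts)     # produces a vLLM-compatible state dict
    engine.update_weights(new_state)     # ~0.7-0.9s via collective_rpc
```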
