examples/offline_inference/disaggregated_prefill.py
"""
This file demonstrates the example usage of disaggregated prefilling
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
and then transfer the KV cache between them.

Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
"""
import os
import time
from multiprocessing import Event, Process

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def run_prefill(prefill_done):
# We use GPU 0 for prefill node.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # The prefill node handles two requests, while the decode node handles
    # three. The decode node therefore only receives the KV cache for
    # requests 1 and 3; it reuses those caches and prefills request 2 itself.
prompts = [
"Hello, my name is",
# "Hi, your name is", # To trigger partial prefill of batched requests
"Tell me a very long story",
]
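    # max_tokens=1: the prefill node only needs to run the prefill phase to
    # produce KV caches, so it generates just a single token per request.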
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

ktc = KVTransferConfig.from_cli(
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
)
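    # The config above makes this instance the KV cache producer (kv_rank 0)
    # in a transfer group of kv_parallel_size=2, using the PyNcclConnector to
    # send KV caches to the decode node.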
    # Example: set GPU memory utilization to 0.8 for an A6000 GPU with 48GB of
    # memory. Reduce the value if your GPU has less memory.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8)

llm.generate(prompts, sampling_params)
print("Prefill node is finished.")
prefill_done.set()

    # Keep the prefill node running so the decode node can still fetch the KV
    # caches; the parent process terminates it once decoding is done.
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Script stopped by user.")


def run_decode(prefill_done):
# We use GPU 1 for decode node.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

prompts = [
"Hello, my name is",
"Hi, your name is",
"Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95)

ktc = KVTransferConfig.from_cli(
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
)
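    # The config above makes this instance the KV cache consumer (kv_rank 1)
    # in the same transfer group, so it receives the KV caches produced by the
    # prefill node.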
    # Example: set GPU memory utilization to 0.8 for an A6000 GPU with 48GB of
    # memory. Reduce the value if your GPU has less memory.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8)

    # Wait for the prefill node (the KV producer) to finish so that the KV
    # caches are ready before decoding starts.
print("Waiting for prefill node to finish...")
prefill_done.wait()

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
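    # Event used by the prefill process to signal that prefilling (and thus
    # KV cache production) has finished.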
prefill_done = Event()
prefill_process = Process(target=run_prefill, args=(prefill_done, ))
decode_process = Process(target=run_decode, args=(prefill_done, ))

# Start prefill node
prefill_process.start()

# Start decode node
decode_process.start()

# Terminate the prefill node when decode is finished
decode_process.join()
prefill_process.terminate()