Description
Your current environment
vllm version: vllm-0.6.6.post1
from pathlib import Path
import json
import asyncio
from vllm import LLM, AsyncLLMEngine, AsyncEngineArgs, SamplingParams
import torch
from transformers import AutoTokenizer
from time import time
from uuid import uuid4
import argparse
models = {
    "llama2-7b": "/path/to/llama"
}
def example_to_prompt(tokenizer, ex) -> str:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": ex["query"]}],
        add_generation_prompt=True,
        tokenize=False,
    )
    return prompt
async def run_query(engine, params, query: str):
    request_id = str(uuid4())  # request_id is expected to be a string
    outputs = engine.generate(query, params, request_id)
    final_output = None
    async for output in outputs:
        final_output = output
    responses = []
    for output in final_output.outputs:
        responses.append(output.text)
    return responses
async def process(engine, params, queries):
    tasks = [asyncio.create_task(run_query(engine, params, q)) for q in queries]
    results = []
    for task in asyncio.as_completed(tasks):
        result = await task
        results.append(result)
    return results
def main(args):
    tp_size = args.tensor_parallel_size
    pp_size = args.pipeline_parallel_size
    print(f"tp_size: {tp_size} pp_size: {pp_size}")
    enforce_eager = args.enforce_eager
    model_path = models["llama2-7b"]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # assert enforce_eager == False
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model=model_path,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            pipeline_parallel_size=pp_size,
            dtype=torch.bfloat16,
            # *** for debug ****
            enforce_eager=enforce_eager,
            # max_num_seqs=32,
            disable_log_requests=True,
            disable_custom_all_reduce=True,
            distributed_executor_backend="ray",
            gpu_memory_utilization=0.85,
        )
    )
    num_prompts = args.num_prompts
    sample = args.num_sampling
    params = SamplingParams(n=sample, temperature=1, skip_special_tokens=False)
    input_file = "/path/to/data"
    with open(input_file) as f:
        raw_ex_batch = [json.loads(line) for line in f]
    raw_ex_batch = raw_ex_batch[:num_prompts]
    output_file = open("/path/to/logfile", "w")
    print(f"load {len(raw_ex_batch)} examples", file=output_file, flush=True)
    total_s = 0
    total_token = 0
    max_len = 0
    for idx in range(1):
        texts = [example_to_prompt(tokenizer, ex) for ex in raw_ex_batch]
        print(texts, file=output_file, flush=True)
        tic = time()
        ret_batch = asyncio.run(process(engine, params, texts))
        toc = time()
        duration_s = f"{toc - tic:.5f} s".ljust(10)
        num_done = f"{idx}".ljust(8)
        print(f"{duration_s} Processed={num_done}", file=output_file, flush=True)
        total_s += toc - tic
    print(f"total time : {total_s}", file=output_file, flush=True)
    print(f"average time per batch : {total_s / (idx+1)}", file=output_file, flush=True)
    print(ret_batch)
    print(len(ret_batch))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="vLLM Black-box Tester")
    parser.add_argument('--tensor-parallel-size', type=int, default=1)
    parser.add_argument('--pipeline-parallel-size', type=int, default=1)
    parser.add_argument('--num-prompts', type=int, default=1)
    parser.add_argument('--num-sampling', type=int, default=1)
    parser.add_argument('--enforce-eager', action='store_true')
    args = parser.parse_args()
    main(args)
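For completeness, a minimal sketch of how the script above can be launched with pipeline parallelism in SPMD mode. The environment variable names (VLLM_USE_RAY_SPMD_WORKER, VLLM_USE_RAY_COMPILED_DAG), the repro.py filename, and the flag values are illustrative assumptions, not something confirmed elsewhere in this report:

import os
import subprocess

# Assumed toggles for the Ray SPMD worker and the compiled-DAG path;
# set in the environment before the engine is constructed in the child process.
env = dict(os.environ)
env["VLLM_USE_RAY_SPMD_WORKER"] = "1"
env["VLLM_USE_RAY_COMPILED_DAG"] = "1"

# Run the reproduction script with TP=1 and PP=2 across two GPUs.
subprocess.run(
    [
        "python", "repro.py",
        "--tensor-parallel-size", "1",
        "--pipeline-parallel-size", "2",
        "--num-prompts", "4",
    ],
    env=env,
    check=True,
)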
Model Input Dumps
No response
🐛 Describe the bug
Sorry to bother you @ruisearch42 @rkooo567
We tried to apply pipeline parallelism to llama2-7b in SPMD mode, but hit a device placement error as follows:
This might already be supported since #6837, although #7099 removed the DAG tests. Any progress you can share would be appreciated, thanks!
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.