
[Bug]: Pipeline parallel with Ray ADAG does not work #12026

@charles9304

Description

Your current environment

vllm version: vllm-0.6.6.post1

import json
import asyncio
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
import torch
from transformers import AutoTokenizer
from time import time
from uuid import uuid4

import argparse

models = {
    "llama2-7b": "/path/to/llama"
}


def example_to_prompt(tokenizer, ex) -> str:

    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": ex["query"]}],
        add_generation_prompt=True,
        tokenize=False,
    )

    return prompt

async def run_query(engine, params, query: str):
    # AsyncLLMEngine.generate expects the request id to be a string.
    request_id = str(uuid4())
    final_output = None
    # generate() yields incremental RequestOutputs; the last one yielded
    # is the finished result.
    async for output in engine.generate(query, params, request_id):
        final_output = output
    return [output.text for output in final_output.outputs]


async def process(engine, params, queries):
    # Fan out all queries concurrently and gather results as they complete.
    tasks = [asyncio.create_task(run_query(engine, params, q)) for q in queries]
    results = []
    for task in asyncio.as_completed(tasks):
        result = await task
        results.append(result)
    return results


def main(args):
    tp_size = args.tensor_parallel_size
    pp_size = args.pipeline_parallel_size
    print(f"tp_size: {tp_size} pp_size: {pp_size}")
    enforce_eager = args.enforce_eager
    model_path = models["llama2-7b"]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # assert enforce_eager == False
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model=model_path,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            pipeline_parallel_size=pp_size,
            dtype=torch.bfloat16,
            # *** for debug ****
            enforce_eager=enforce_eager,
            #max_num_seqs=32,
            disable_log_requests=True,
            disable_custom_all_reduce=True,
            distributed_executor_backend="ray",
            gpu_memory_utilization=0.85,
        )
    )
    num_prompts = args.num_prompts
    sample = args.num_sampling
    params = SamplingParams(n=sample, temperature=1, skip_special_tokens=False)

    input_file = "/path/to/data"
    with open(input_file) as f:
        raw_ex_batch = [json.loads(line) for line in f]
    raw_ex_batch = raw_ex_batch[:num_prompts]

    output_file = open("/path/to/logfile", "w")
    print(f"load {len(raw_ex_batch)} examples", file=output_file, flush=True)
    total_s = 0
    total_token = 0
    max_len = 0
    for idx in range(1):
        texts = [example_to_prompt(tokenizer, ex) for ex in raw_ex_batch]
        print(texts, file=output_file, flush=True)
        tic = time()
        ret_batch = asyncio.run(process(engine, params, texts))
        toc = time()

        duration_s = f"{toc - tic:.5f} s".ljust(10)
        num_done = f"{idx}".ljust(8)
        print(f"{duration_s} Processed={num_done}", file=output_file, flush=True)
        total_s += toc - tic
        print(f"total time : {total_s}", file=output_file, flush=True)
        print(f"average time per batch : {total_s / (idx+1)}", file=output_file, flush=True)
        print(ret_batch)
        print(len(ret_batch))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="vLLM Black-box Tester")

    parser.add_argument('--tensor-parallel-size', type=int, default=1)
    parser.add_argument('--pipeline-parallel-size', type=int, default=1)
    parser.add_argument('--num-prompts', type=int, default=1)
    parser.add_argument('--num-sampling', type=int, default=1)
    parser.add_argument('--enforce-eager', action='store_true')
    args = parser.parse_args()
    main(args)
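
For reference, a hypothetical invocation of this reproducer for a two-stage pipeline (the script name and flag values are placeholders, not taken from the report):

python repro.py --tensor-parallel-size 1 --pipeline-parallel-size 2 --num-prompts 8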

Model Input Dumps

No response

🐛 Describe the bug

Sorry to bother you @ruisearch42 @rkooo567

We tried to apply pipeline parallelism to llama2-7b in SPMD mode, but hit a device placement error, shown in the screenshot below:
[Screenshot: traceback showing the device placement error]

This might already be supported since #6837, though #7099 deleted the DAG tests. Any update on the current status would be appreciated, thanks!
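
For context, vLLM gates SPMD workers and Ray's compiled (accelerated) DAG behind environment variables, so the run above was presumably launched with something like the sketch below (an assumption; the exact settings are not shown in the report):

# Assumed setup: must be set before the engine spawns its Ray workers.
import os
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"   # run Ray workers in SPMD mode
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"  # enable Ray accelerated DAG (ADAG)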

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Metadata

Labels: bug (Something isn't working), ray (anything related with ray)

Status: Done
