Description
This RFC describes a proposal for interfaces and abstractions for distributed inference environments. I plan to solicit discussions for a week (until March 31st) before I begin to actually refactor the code.
Motivation
The current distributed inference environment in `vllm` is quite tangled, and we often see deadlocks and hangs (see #3455, #2770, #3559, to name a few). The problem becomes prominent when we try to upgrade to pytorch 2.2.0 (see #3442), because pytorch 2.2.0 upgrades from `nccl==2.18.1` to `2.19.3` (compare the dependencies in https://pypi.org/pypi/torch/2.1.2/json and https://pypi.org/pypi/torch/2.2.0/json), and `nccl==2.19.3` breaks `vllm` due to increased memory cost during cudagraph capture (from 10MB per graph to 100MB per graph, which adds up to several GBs because we capture dozens of cudagraphs).
TL;DR: distributed inference in the current codebase is a headache. If it works, hooray; if not, don't be surprised.
Proposal
Abstraction
I think we should have three levels of abstraction:
- Launcher, responsible for launching processes (potentially across multiple nodes). Currently this is `ray`, but we could also offer other choices, such as Python's native `multiprocessing` for single-node cases. See [Core] Multiprocessing executor for single-node multi-GPU deployment #3466 for an example.
- Coordinator, responsible for coordinating shared resources (e.g. filesystem usage) and broadcasting some messages. Currently we don't have this, and there are lots of hacks as ad-hoc implementations, e.g. using `filelock` to lock on filesystems ([Bugfix] use SoftLockFile instead of LockFile #3578), using TCP to initialize communication in `cupy` (Use CuPy for CUDA graphs #2811), and using MPI to initialize communication in AMD's `cupy` version ([ROCm] enable cupy in order to enable cudagraph mode for AMD GPUs #3123).
- Communicator, responsible for cross-device communication of large tensor data (e.g. performing allreduce). Currently we support `nccl`, and AMD also has its own communication library. Note that this level is vendor-specific, and vendors usually have their own way of doing cross-device communication.
The messiest level, and the missing one, is the Coordinator. More on this later.
Interface
Between each pair of consecutive abstraction levels lies an interface.
Interface between Launcher and Coordinator
After the Launcher launches processes, it needs to tell each process at least the following information:
- `launch_id`, used to distinguish the current launch from possibly concurrent launches (e.g. when 4 users want to set up 4 inference engines on the same node, each with 2 GPUs). Note: the `launch_id` can be used as a "random seed" to draw a value for `master_port`, instead of keeping only one default `master_port` value and having to kill all processes after the last run crashes. A reference implementation would hash the `launch_id` to a port number, then increase the port number until it finds the first free port. This is the strategy used by Jupyter Notebook/Lab Server.
- `world_size`, the number of processes participating in the current launch (may span multiple nodes)
- `local_world_size`, the number of processes participating in the current launch on the current node (not necessarily the same across nodes)
- `rank`, ranging from 0 (inclusive) to `world_size` (exclusive), unique within the launch for each process
- `local_rank`, ranging from 0 (inclusive) to `local_world_size` (exclusive), unique within each node; this can be used to assign devices within a node!
- `master_addr`, the IP address of the master node, which should be reachable from all nodes
- `master_port`, a free port on the master node, reserved for possible coordination
- Other custom information can be added, but the above are required.
How does the Launcher pass this information to each process? Basically we have two choices:
- through environment variables: the simplest way, but it rules out thread-level distributed inference, because environment variables are shared by all threads in a process. (However, thread-level distributed inference seems rare. Do we need to consider it?)
- through serialization and deserialization (e.g. passing bytes via a shared object store): the most general way, at the cost of extra complexity in designing the serialization format and extra runtime overhead in executing it
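Both choices can be sketched with one small container. The class name `LaunchInfo` and the convention of upper-casing each field name into an environment variable are assumptions for illustration; most of the resulting names (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `LOCAL_WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`) happen to match what `torchrun` sets, while `LAUNCH_ID` is hypothetical. JSON stands in for whatever serialization format we would actually pick.

```python
import json
import os
from dataclasses import asdict, dataclass, fields
from typing import Dict, Mapping


@dataclass(frozen=True)
class LaunchInfo:
    launch_id: str
    world_size: int
    local_world_size: int
    rank: int
    local_rank: int
    master_addr: str
    master_port: int

    def __post_init__(self) -> None:
        # Enforce the rank ranges described above.
        if not 0 <= self.rank < self.world_size:
            raise ValueError(f"rank {self.rank} not in [0, {self.world_size})")
        if not 0 <= self.local_rank < self.local_world_size:
            raise ValueError(
                f"local_rank {self.local_rank} not in [0, {self.local_world_size})"
            )

    # Choice 1: environment variables (one copy per process, shared by its threads).
    def to_env(self) -> Dict[str, str]:
        return {f.name.upper(): str(getattr(self, f.name)) for f in fields(self)}

    @classmethod
    def from_env(cls, env: Mapping[str, str] = os.environ) -> "LaunchInfo":
        kwargs = {}
        for f in fields(cls):
            raw = env[f.name.upper()]
            kwargs[f.name] = int(raw) if f.type is int else raw
        return cls(**kwargs)

    # Choice 2: serialize to bytes, e.g. for a shared object store.
    def to_bytes(self) -> bytes:
        return json.dumps(asdict(self)).encode("utf-8")

    @classmethod
    def from_bytes(cls, payload: bytes) -> "LaunchInfo":
        return cls(**json.loads(payload))
```

With the first choice, a Launcher would start each worker with something like `env={**os.environ, **info.to_env()}` and the worker would call `LaunchInfo.from_env()`; with the second, the worker would fetch the bytes and call `LaunchInfo.from_bytes()`.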
Interface between Coordinator and Communicator
Device communicators (e.g. `nccl`) often need to initialize communication by sharing some unique token (see the `nccl` documentation). In addition, processes sometimes need to coordinate resources within a node or across the cluster.
In light of the above considerations, the Coordinator should provide at least the following interfaces:
- `is_master()`: tell whether the current process is a master process, i.e. a convenience wrapper for the boilerplate code `rank == 0`
- `is_local_master()`: tell whether the current process is a local master process, i.e. a convenience wrapper for the boilerplate code `local_rank == 0`
- `broadcast(bytes, src)`: broadcast a message (in the form of `bytes`) from rank `src` to all processes. The semantics are standard; no further explanation needed.
- `barrier()`: block until all processes reach this point. Also a standard communication primitive.
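A Python rendering of this interface might look like the sketch below. The class names are hypothetical, and one design choice differs slightly from the wording above: since Python `bytes` are immutable, `broadcast` returns the received bytes instead of filling a buffer. The `world_size == 1` subclass illustrates that a single-GPU run can be just the degenerate case of the same interface.

```python
from abc import ABC, abstractmethod


class Coordinator(ABC):
    """Hypothetical base class mirroring the proposed Coordinator interface."""

    def __init__(self, rank: int, local_rank: int,
                 world_size: int, local_world_size: int) -> None:
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.local_world_size = local_world_size

    def is_master(self) -> bool:
        # Convenience wrapper for `rank == 0`.
        return self.rank == 0

    def is_local_master(self) -> bool:
        # Convenience wrapper for `local_rank == 0`.
        return self.local_rank == 0

    @abstractmethod
    def broadcast(self, data: bytes, src: int) -> bytes:
        """Send `data` from rank `src`; every rank returns the same bytes."""

    @abstractmethod
    def barrier(self) -> None:
        """Block until all processes reach this point."""


class SingleProcessCoordinator(Coordinator):
    """Degenerate world_size == 1 case: broadcast is identity, barrier is a no-op."""

    def __init__(self) -> None:
        super().__init__(rank=0, local_rank=0, world_size=1, local_world_size=1)

    def broadcast(self, data: bytes, src: int) -> bytes:
        if src != 0:
            raise ValueError(f"src must be 0 when world_size == 1, got {src}")
        return data

    def barrier(self) -> None:
        pass
```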
Note: more often than not, we want to execute something in just one process per node (e.g. creating directories, downloading files to the node). Inspired by this thread, we can write code like this:
```python
if is_local_master():
    do_something()  # download files, create directories, etc.
barrier()  # everyone waits until the local masters have finished
```
Furthermore, there are more complicated requirements like "only one process in each node does something, but this something differs across nodes", which essentially call for a `local_barrier()`: a function that blocks until all processes in the current node reach this point. It is debatable whether we want this (currently I don't see any such requirement in `vllm`).
Communicator interface
The following Communicator functionality is suggested (mostly borrowed from the `nccl` design):
- the master process gets a unique token that identifies the communication group
- the master process broadcasts the unique token to all ranks
- each process initializes communication with the unique token, its rank, and `world_size`
- an in-place allreduce function: `allreduce(char* input, size_t count, size_t dtype, size_t op)`. More functionality would be nice (e.g. out-of-place allreduce, broadcast/reduce/scatter, etc.), but in-place allreduce is all we need currently.
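The three initialization steps can be written once, independently of the vendor. In the sketch below, `Communicator`, `init_communicator`, `LoopbackCommunicator`, and `SoloCoordinator` are hypothetical names. `allreduce` mirrors the C signature above, with the buffer passed as an integer address (what `Tensor.data_ptr()` returns) and `dtype`/`op` as integer codes a vendor would define. The 128-byte token matches the size of NCCL's `ncclUniqueId` but is otherwise arbitrary.

```python
import os
from abc import ABC, abstractmethod
from typing import Optional, Protocol


class CoordinatorLike(Protocol):
    rank: int
    world_size: int

    def is_master(self) -> bool: ...
    def broadcast(self, data: bytes, src: int) -> bytes: ...


class Communicator(ABC):
    @abstractmethod
    def get_unique_token(self) -> bytes:
        """Create a token identifying a new communication group (master only)."""

    @abstractmethod
    def init(self, token: bytes, rank: int, world_size: int) -> None:
        """Join the group identified by `token`."""

    @abstractmethod
    def allreduce(self, ptr: int, count: int, dtype: int, op: int) -> None:
        """In-place allreduce of `count` elements starting at address `ptr`."""


def init_communicator(coor: CoordinatorLike, comm: Communicator) -> None:
    # Step 1: only the master asks the vendor library for a unique token.
    token = comm.get_unique_token() if coor.is_master() else b""
    # Step 2: the master broadcasts the token; every rank gets the same bytes back.
    token = coor.broadcast(token, src=0)
    # Step 3: every rank joins the group with the token, its rank and world_size.
    comm.init(token, coor.rank, coor.world_size)


class LoopbackCommunicator(Communicator):
    """Trivial world_size == 1 backend: a one-rank allreduce leaves data unchanged."""

    def __init__(self) -> None:
        self.token: Optional[bytes] = None
        self.world_size = 0

    def get_unique_token(self) -> bytes:
        return os.urandom(128)

    def init(self, token: bytes, rank: int, world_size: int) -> None:
        if world_size != 1:
            raise ValueError("LoopbackCommunicator only supports world_size == 1")
        self.token, self.world_size = token, world_size

    def allreduce(self, ptr: int, count: int, dtype: int, op: int) -> None:
        if self.token is None:
            raise RuntimeError("call init() first")
        # With a single rank there is nothing to combine: the buffer is the result.


class SoloCoordinator:
    """Minimal single-process stand-in satisfying CoordinatorLike."""
    rank, world_size = 0, 1

    def is_master(self) -> bool:
        return True

    def broadcast(self, data: bytes, src: int) -> bytes:
        return data
```

Keeping the token exchange in `init_communicator` means a vendor only implements the three abstract methods and never has to touch the Coordinator's transport.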
The intended usage would be something like this:
```python
# inside each process; Coordinator and Communicator are the proposed classes
import torch

coor = Coordinator()  # initialize the Coordinator, done by vllm
comm = Communicator(coor)  # the hardware vendor can use `coor` to initialize their communicator
data = torch.randn(1024, 1024, device=f"xpu:{coor.local_rank}")  # a 1024x1024 tensor on this process's device
comm.allreduce(data)  # in-place; the vendor can access the raw data via pytorch's [`Tensor.data_ptr`](https://pytorch.org/docs/stable/generated/torch.Tensor.data_ptr.html) mechanism
# please implement Communicator.__del__ to destroy the communicator, so that programs can exit gracefully
```
A reference implementation of Coordinator
A reference implementation of the Coordinator could be `torch.distributed` with the `gloo` backend, which is designed to communicate CPU tensors.
Other options include MPI and a custom-implemented TCP store. However, since we already live in the PyTorch framework, `torch.distributed` is a natural choice that adds no new dependency.
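A minimal sketch of such a Coordinator on top of `torch.distributed`, assuming PyTorch is installed. The class name and constructor arguments are hypothetical and simply mirror the launch information listed earlier. `broadcast` relies on `torch.distributed.broadcast_object_list`, which pickles the payload, so it should only be used among trusted processes.

```python
import torch.distributed as dist


class TorchDistributedCoordinator:
    """Hypothetical Coordinator backed by torch.distributed with the CPU-only gloo backend."""

    def __init__(self, rank: int, local_rank: int, world_size: int,
                 master_addr: str, master_port: int) -> None:
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        # gloo communicates CPU tensors, so this never touches NCCL or the GPU.
        dist.init_process_group(
            backend="gloo",
            init_method=f"tcp://{master_addr}:{master_port}",
            world_size=world_size,
            rank=rank,
        )

    def is_master(self) -> bool:
        return self.rank == 0

    def is_local_master(self) -> bool:
        return self.local_rank == 0

    def broadcast(self, data: bytes, src: int) -> bytes:
        # Non-source ranks pass a placeholder that broadcast_object_list overwrites in place.
        payload = [data if self.rank == src else None]
        dist.broadcast_object_list(payload, src=src)
        return payload[0]

    def barrier(self) -> None:
        dist.barrier()

    def close(self) -> None:
        dist.destroy_process_group()
```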
Note: `torch.distributed` can also be used as a fully functional communicator for GPU devices. However, `torch.distributed.all_reduce` does much more than a bare allreduce operation. It might initialize the autograd engine, keep track of gradients, or dispatch to different device kernels. Even under `torch.inference_mode`, its `c10` engine might perform additional operations that break features like cudagraph capture. Therefore, I prefer to call vendor-provided communication libraries directly to bypass the problem. After all, we just want an allreduce operation on dense tensors, without any hustle and bustle.
Benefits
With the above abstractions and interfaces in place, we get the following benefits:
- We are always in a distributed environment, just with different values of `world_size`. Distributed concerns will always be considered, so we can easily scale to multi-node environments (if any LLM needs this).
- Hardware vendors can plug in their communication libraries very easily. All they need to provide are: integration into pytorch's `torch.Tensor` (forward computation ops are enough), and a C library (an `.so` file would be enough) for calling communication ops on raw data (i.e. `char*` in C). And if they want to move quickly, a single `allreduce` op would be enough for inference. There is no need to wait until the whole functionality is completed within pytorch.
Things not to be considered
We don't aim for a fully-fledged distributed execution environment. And since inference tasks are almost stateless, we don't need to consider elasticity and fault tolerance. Unlike training, we don't need to save checkpoints, resume from previous failures, ...