2 changes: 1 addition & 1 deletion .github/workflows/_build_linux.yml
@@ -164,7 +164,7 @@ jobs:
python -m pip install -r requirements.txt
python -m pip install wheel
# Compile RDMA
export ENABLE_FD_RDMA=1
export FD_ENABLE_RDMA_COMPILE=1
bash build.sh 1 python false [${COMPILE_ARCH}]
ls ./dist/*.whl
'
16 changes: 15 additions & 1 deletion fastdeploy/config.py
@@ -902,6 +902,12 @@ def _set_cudagraph_sizes(self, max_capture_size: int = 0):
draft_capture_sizes.append(max_capture_size)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)

def filter_capture_size(self, tp_size: int = 1):
"""When TSP is used, capture size must be divisible by tp size."""
self.cudagraph_capture_sizes = [
draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
]

def to_json_string(self):
"""
Convert speculative_config to json string.
@@ -1628,7 +1634,15 @@ def postprocess(self):
if self.device_config is not None and self.device_config.device_type != "cuda":
self.graph_opt_config.use_cudagraph = False
logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")

if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
self.parallel_config.use_sequence_parallel_moe = False
logger.info(
"Warning: sequence parallel moe do not support max_num_seqs < tensor_parallel_size when cudagraph enabled. We set use_sequence_parallel_moe to False."
)
else:
# It will hang when real batch_size < tp_size
self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
Comment on lines +1637 to +1645 (Collaborator Author):
The key point about TSP + CUDAGraph is here.

if self.model_config.enable_mm and self.graph_opt_config.use_cudagraph:
self.cache_config.enable_prefix_caching = False
logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")
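Note (reviewer aside): a standalone sketch of what the new filter_capture_size guard keeps; the capture sizes and tp value below are assumed for illustration and are not taken from this PR.

    # Keep only capture sizes divisible by tp_size, as required when
    # sequence-parallel MoE (TSP) is combined with CUDAGraph above.
    cudagraph_capture_sizes = [1, 2, 4, 8, 16, 32, 48, 64]  # assumed draft sizes
    tp_size = 4                                              # assumed tensor_parallel_size
    filtered = [size for size in cudagraph_capture_sizes if size % tp_size == 0]
    print(filtered)  # [4, 8, 16, 32, 48, 64]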
6 changes: 4 additions & 2 deletions fastdeploy/engine/args_utils.py
@@ -512,8 +512,10 @@ def __post_init__(self):
raise ValueError(
"Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
)
if len(self.rdma_comm_ports) != self.tensor_parallel_size:
raise ValueError("The number of rdma comm ports must be equal to tensor parallel size.")
if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
raise ValueError(
f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
)

if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1:
if "ipc" in self.cache_transfer_protocol:
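Note (reviewer aside): the updated validation requires one RDMA comm port per rank, i.e. data_parallel_size * tensor_parallel_size ports in total. A minimal sketch with assumed values, mirroring the check above:

    data_parallel_size = 2                     # assumed
    tensor_parallel_size = 4                   # assumed
    rdma_comm_ports = list(range(7110, 7118))  # 8 assumed ports, one per rank

    assert len(rdma_comm_ports) == data_parallel_size * tensor_parallel_size, (
        f"The number of rdma comm ports must equal dp * tp = "
        f"{data_parallel_size * tensor_parallel_size}, but got {len(rdma_comm_ports)}."
    )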
16 changes: 10 additions & 6 deletions fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -570,10 +570,11 @@ def __init__(self, fd_config: FDConfig):
self.ernie = Ernie4_5_VLModel(fd_config=fd_config)

# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
if fd_config.graph_opt_config.use_cudagraph:
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

self.ori_vocab_size = fd_config.model_config.ori_vocab_size

@@ -783,10 +784,13 @@ def forward(
image_features=image_features,
image_token_num=vl_moe_meta.num_image_patch_id.item(),
)
self._input_embeddings.copy_(input_embeddings, False)

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.ernie(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
vl_moe_meta=vl_moe_meta,
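Note (reviewer aside): the same persistent-buffer change is repeated in ernie_vl_rm.py, paddleocr_vl.py, and qwen2_5_vl.py below. A simplified sketch of the pattern, with assumed shape and dtype, and assuming input_embeddings is already padded to max_capture_size rows on graphed decode steps:

    import paddle

    hidden_size = 1024     # assumed; model_config.hidden_size in the real code
    max_capture_size = 64  # assumed; graph_opt_config.max_capture_size in the real code

    # Allocated once so every CUDA graph replay reads decoder inputs from the
    # same device address instead of a freshly allocated tensor.
    decoder_input_buf = paddle.zeros([max_capture_size, hidden_size], dtype="bfloat16")

    def pick_decoder_input(input_embeddings, step_use_cudagraph):
        if step_use_cudagraph:
            # Non-blocking copy into the persistent buffer, mirroring
            # tensor.copy_(src, False) in the diff; the buffer itself is then
            # passed into the graph-captured transformer.
            decoder_input_buf.copy_(input_embeddings, False)
            return decoder_input_buf
        return input_embeddings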
16 changes: 10 additions & 6 deletions fastdeploy/model_executor/models/ernie_vl_rm.py
@@ -59,10 +59,11 @@ def __init__(self, fd_config: FDConfig):
self.head_dtype = paddle.bfloat16

# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
if fd_config.graph_opt_config.use_cudagraph:
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

self.rm_head = nn.Sequential(
(
@@ -112,10 +113,13 @@ def forward(
image_features=image_features,
image_token_num=vl_moe_meta.image_token_num.item(),
)
self._input_embeddings.copy_(input_embeddings, False)

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.ernie(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
vl_moe_meta=vl_moe_meta,
23 changes: 10 additions & 13 deletions fastdeploy/model_executor/models/paddleocr_vl/paddleocr_vl.py
@@ -132,10 +132,11 @@ def __init__(self, fd_config):
)

# Persistent buffers for CUDA graphs.
self._decoder_input_embeddings = paddle.zeros(
[fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
if fd_config.graph_opt_config.use_cudagraph:
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
@@ -242,15 +243,11 @@ def forward(

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.model(
input_embeddings=self._decoder_input_embeddings,
forward_meta=forward_meta,
)
else:
hidden_states = self.model(
input_embeddings=input_embeddings,
forward_meta=forward_meta,
)
hidden_states = self.model(
input_embeddings=input_embeddings,
forward_meta=forward_meta,
)

return hidden_states
16 changes: 10 additions & 6 deletions fastdeploy/model_executor/models/qwen2_5_vl/qwen2_5_vl.py
@@ -152,10 +152,11 @@ def __init__(self, fd_config: FDConfig):
self.model = Qwen2_5_VLModel(fd_config=fd_config)

# Persistent buffers for CUDA graphs.
self._input_embeddings = paddle.zeros(
[fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)
if fd_config.graph_opt_config.use_cudagraph:
self._decoder_input_embeddings = paddle.zeros(
[fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
dtype=fd_config.model_config.dtype,
)

self.ori_vocab_size = fd_config.model_config.ori_vocab_size

@@ -290,10 +291,13 @@ def forward(
input_embeddings = self.get_input_embeddings(
ids_remove_padding=ids_remove_padding, image_features=image_features
)
self._input_embeddings.copy_(input_embeddings, False)

if forward_meta.step_use_cudagraph:
self._decoder_input_embeddings.copy_(input_embeddings, False)
input_embeddings = self._decoder_input_embeddings

hidden_states = self.model(
input_embeddings=self._input_embeddings,
input_embeddings=input_embeddings,
ids_remove_padding=ids_remove_padding,
image_features=image_features,
forward_meta=forward_meta,
68 changes: 66 additions & 2 deletions setup.py
@@ -14,10 +14,12 @@
# limitations under the License.
"""

import glob
import os
import re
import subprocess
import sys
from functools import lru_cache
from pathlib import Path

import paddle
@@ -180,6 +182,68 @@ def get_device_type():
return "cpu"


def check_header(header_path):
return os.path.exists(header_path)


def check_library(lib_name):
# search /usr/lib /usr/lib64 /lib /lib64 etc.
paths = [
"/usr/lib",
"/usr/lib32",
"/usr/lib64",
"/usr/lib/x86_64-linux-gnu",
"/lib",
"/lib32",
"/lib64",
"/usr/local/lib",
"/usr/local/lib64",
]
for p in paths:
if glob.glob(os.path.join(p, lib_name)):
return True
return False


def check_rdma_packages():
results = {}

# libibverbs-dev
results["libibverbs header"] = check_header("/usr/include/infiniband/verbs.h")
results["libibverbs library"] = check_library("libibverbs.so*") or check_library("libibverbs.so")

# librdmacm-dev
results["librdmacm header"] = check_header("/usr/include/rdma/rdma_cma.h")
results["librdmacm library"] = check_library("librdmacm.so*") or check_library("librdmacm.so")

print("===== RDMA Library Check Results =====")
for k, v in results.items():
status = "FOUND" if v else "NOT FOUND"
print(f"{k:25}: {status}")

print("\n== Summary ==")
if all(results.values()):
print("All required RDMA libraries are installed.")
return True
else:
print("Some RDMA libraries are missing. Suggested commands:")
print("\nUbuntu/Debian:")
print(" sudo apt-get install -y libibverbs-dev librdmacm-dev")
print("\nCentOS/RHEL:")
print(" sudo yum install -y libibverbs-devel librdmacm-devel")
return False


@lru_cache(maxsize=1)
def rdma_comm_supported():
supported = (
get_device_type() in ["gpu", "xpu"]
and check_rdma_packages()
and os.getenv("FD_ENABLE_RDMA_COMPILE", "1") == "1"
)
return supported


def get_name():
"""get package name"""
return "fastdeploy-" + get_device_type()
@@ -237,10 +301,10 @@ def write_version_to_file():
version=None,
)
]
if os.getenv("ENABLE_FD_RDMA", "0") == "1"
if rdma_comm_supported()
else []
),
cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
cmdclass=cmdclass_dict if rdma_comm_supported() else {},
zip_safe=False,
classifiers=[
"Programming Language :: Python :: 3",
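Note (reviewer aside): a condensed sketch of how the new build gate composes; device type, installed RDMA dev packages, and FD_ENABLE_RDMA_COMPILE (default "1") together replace the old ENABLE_FD_RDMA switch. Names mirror setup.py, but this is an illustration rather than the exact code path.

    import os

    def rdma_extension_enabled(device_type, rdma_packages_ok):
        # Mirrors rdma_comm_supported(): build the RDMA comm extension only for
        # gpu/xpu targets, with ibverbs/rdmacm dev packages present, unless
        # FD_ENABLE_RDMA_COMPILE is explicitly set to "0".
        return (
            device_type in ("gpu", "xpu")
            and rdma_packages_ok
            and os.getenv("FD_ENABLE_RDMA_COMPILE", "1") == "1"
        )

    print(rdma_extension_enabled("gpu", rdma_packages_ok=True))  # True
    print(rdma_extension_enabled("cpu", rdma_packages_ok=True))  # False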