[Core] Upgrade to pytorch 2.2, remove cupy dependency, avoid nccl 2.19 bug #3442
Files changed:
MANIFEST.in

```diff
@@ -4,3 +4,4 @@ include CMakeLists.txt
 recursive-include cmake *
 recursive-include csrc *
+recursive-include vllm/lib *
```
requirements-build.txt

```diff
@@ -3,5 +3,5 @@ cmake>=3.21
 ninja
 packaging
 setuptools>=49.4.0
-torch==2.1.2
+torch==2.2.1
 wheel
```
setup.py

```diff
@@ -11,6 +11,12 @@
 from shutil import which
 import torch
 from torch.utils.cpp_extension import CUDA_HOME
+import zipfile
+import shutil
+import logging
+import tempfile
+
+logger = logging.getLogger(__name__)
 
 ROOT_DIR = os.path.dirname(__file__)
 
@@ -188,6 +194,48 @@ def _install_punica() -> bool:
     return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
 
 
+if _is_cuda():
+    # tricky part: nccl 2.19 has a bug that increases the memory overhead
+    # of cudagraph. However, pytorch has a binary dependency on nccl 2.19,
+    # so simply `pip install nvidia-nccl-cu12==2.18.3` will break pytorch.
+    # We therefore have to manually download nccl 2.18 and keep the library
+    # in a secret place.
+
+    # Define the URL of the file and the directory to unzip to
+    file_url = ('https://files.pythonhosted.org/packages/44/6e/'
+                '3c9cd7007072f8a63dae7b5eddd1cc1525fd357377467ce3a4749b02d5ff'
+                '/nvidia_nccl_cu12-2.18.3-py3-none-manylinux1_x86_64.whl')
+
+    logger.info('Installing NVIDIA NCCL library...')
+
+    target_dir = os.path.dirname(os.path.abspath(__file__)) + "/vllm/lib/"
+    with tempfile.TemporaryDirectory() as temp_dir:
+        local_zip_path = (
+            f"{temp_dir}/"
+            "nvidia_nccl_cu12-2.18.3-py3-none-manylinux1_x86_64.whl")
+        # make sure the target directory exists
+        os.makedirs(target_dir, exist_ok=True)
+        # Check if the file is already downloaded
+        if os.path.exists(target_dir + "nvidia"):
+            logger.info('library already exists.')
+        else:
+            # Download the file
+            logger.info('Downloading file...')
+            os.system(f"wget {file_url} -q -P {temp_dir}/")
+            # Unzip the file
+            logger.info('Unzipping file...')
+            with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
+                zip_ref.extractall(temp_dir)
+            shutil.rmtree(f"{temp_dir}/nvidia_nccl_cu12-2.18.3.dist-info")
+            os.remove(local_zip_path)
+            # Move the unzipped files to the target directory
+            logger.info('Moving files...')
+            os.system(f"mv {temp_dir}/nvidia {target_dir}")
+            so_path = f"{target_dir}/nvidia/nccl/lib/libnccl.so.2"
+            os.rename(so_path, so_path.replace(".so.2", ".so.2.18.3"))
+
+
 def get_hipcc_rocm_version():
     # Run the hipcc --version command
     result = subprocess.run(['hipcc', '--version'],
```

Review comments on the NCCL download block:

On `file_url = (...)`:
- consider using a constant?
- Actually I wonder if we can support an env var, so that we can also decide to load an arbitrary `.so` instead of always downloading our package when we build.
- w.r.t. downloading our package when we build: we have to do this because the nccl brought in by torch==2.2.0 does not work.

On `local_zip_path = (...)`:
- Does vllm currently support amd arch (the wheel is only for x86)?
- Do you mean …

On `if os.path.exists(target_dir + "nvidia"):`:
- nit, but:

```python
if os.path.exists(target_dir + "nvidia"):
    break
# Download the file
logger.info('Downloading file...')
....
....
```

- We have no choice here, we cannot …
setup.py (continued)

```diff
@@ -330,7 +378,10 @@ def get_requirements() -> List[str]:
     ext_modules.append(CMakeExtension(name="vllm._C"))
 
 package_data = {
-    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
+    "vllm": [
+        "py.typed", "model_executor/layers/fused_moe/configs/*.json",
+        "lib/nvidia/nccl/lib/libnccl.so.2.18.3"
+    ]
 }
 if os.environ.get("VLLM_USE_PRECOMPILED"):
     package_data["vllm"].append("*.so")
@@ -362,6 +413,8 @@ def get_requirements() -> List[str]:
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
+    cmdclass={
+        "build_ext": cmake_build_ext if not _is_neuron() else build_ext,
+    },
     package_data=package_data,
 )
```
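The renamed `libnccl.so.2.18.3` listed in `package_data` ships inside the wheel under `vllm/lib/`, so the runtime wrapper (`vllm.model_executor.parallel_utils.pynccl`, exercised by the test below) can pick it up. The snippet below is only a hedged sketch of that lookup idea; the PR's actual loading logic is not shown in this diff and may differ.

```python
# Simplified, assumed sketch of loading the bundled NCCL library with ctypes;
# the real logic lives in vllm.model_executor.parallel_utils.pynccl and is
# not reproduced in this diff.
import ctypes
import os

import vllm

_pkg_dir = os.path.dirname(os.path.abspath(vllm.__file__))
_so_path = os.path.join(_pkg_dir, "lib", "nvidia", "nccl", "lib",
                        "libnccl.so.2.18.3")

# Prefer the bundled 2.18.3 library; otherwise let the dynamic linker find one.
nccl = ctypes.CDLL(_so_path if os.path.exists(_so_path) else "libnccl.so.2")
```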
tests/distributed/test_pynccl.py (new file)

```python
# this script is not run with `pytest`.
# It is run with `torchrun`.
import os
import multiprocessing
import pytest
import torch
from vllm.model_executor.parallel_utils.pynccl import (
    NCCLCommunicator,
    ncclGetUniqueId,
)


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = os.environ.copy()
        env['RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()


def update_env(fn):

    def wrapper(env):
        import os
        os.environ.update(env)
        fn()

    return wrapper


@update_env
def worker_fn():
    comm = NCCLCommunicator()
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    comm.all_reduce(tensor)
    result = tensor.mean().cpu().item()
    assert result == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    distributed_run(worker_fn, 2)


@update_env
def worker_fn_with_cudagraph():
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        comm = NCCLCommunicator()
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, stream=comm.stream):
            comm.all_reduce(a)
        comm.stream.synchronize()
        assert a.mean().cpu().item() == comm.world_size**0
        graph.replay()
        comm.stream.synchronize()
        assert a.mean().cpu().item() == comm.world_size**2


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    distributed_run(worker_fn_with_cudagraph, 2)


def test_ncclGetUniqueId():
    unique_id = ncclGetUniqueId()
    # `list(unique_id.internal)` is something like this:
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # as long as the function doesn't raise an exception, we're good
    assert unique_id is not None
```
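One thing worth noting about `distributed_run` above: `Process.join()` does not propagate exceptions from the workers, so an assertion failure inside a worker only changes the child's exit code. A variant like the following (an illustrative sketch, not part of this PR) would surface such failures in the parent test:

```python
# Illustrative variant of distributed_run that fails the parent test when a
# worker process dies with a non-zero exit code. Not part of this PR.
import multiprocessing
import os


def distributed_run_checked(fn, world_size):
    processes = []
    for rank in range(world_size):
        env = os.environ.copy()
        env.update({
            'RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
        # a worker that raised (e.g. a failed assert) exits non-zero
        assert p.exitcode == 0, f"worker exited with code {p.exitcode}"
```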
Review discussion on CUDA build architectures:

- Note that pytorch 2.2.0 has 9.0a support by default: https://github.com/pytorch/pytorch/blob/19d27a13ea052230d9fb565a5b82e683e28d1697/Dockerfile#L60, while our docker image does not support 9.0a.
- Doesn't this make our build system always compile the CUDA kernels for all architectures? If I remember correctly, we only compiled the kernels for a single architecture by detecting the GPUs equipped on the user's machine (I'm not sure this is still true after we changed our build system to CMake, though), to reduce compile time. Exceptionally, we targeted all architectures when building docker images or PyPI wheels.
- This is used in the docker image. It seems CMake inherits the build architectures from pytorch by default, so I have to change it (to avoid the 9.0a architecture, which is not supported by the nvcc in our docker image).
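As a concrete illustration of the point above, one way to pin the architectures explicitly is to inspect what the installed PyTorch wheel was built for and drop the 9.0a variant before building. This is a hedged sketch: it assumes the build honors the standard `TORCH_CUDA_ARCH_LIST` override used by torch-based extension builds; how this PR's CMake actually selects architectures is not shown in this section.

```python
# Hedged sketch: build an explicit arch list without the 9.0a variant.
# Assumes TORCH_CUDA_ARCH_LIST is honored by the build; how this PR's CMake
# picks architectures is not shown in this diff.
import os

import torch

built = torch.cuda.get_arch_list()  # e.g. ['sm_80', 'sm_86', 'sm_90', 'sm_90a']
# keep only plain sm_XY entries, dropping arch-specific variants such as sm_90a
caps = sorted({a[3:] for a in built if a.startswith("sm_") and a[3:].isdigit()})
arch_list = " ".join(f"{c[:-1]}.{c[-1]}" for c in caps)  # '90' -> '9.0'
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
print(arch_list)  # e.g. '8.0 8.6 9.0'
```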