Closed
Changes shown from 25 of 39 commits.
Commits
09133e9
integrate aiter kernels: Linear, Norm, MOE
vllmellm Feb 26, 2025
ead17c7
maintain a consistent import pattern
vllmellm Feb 26, 2025
2527956
add aiter fp8 block scaled moe kernel
vllmellm Feb 26, 2025
814702a
bugfix: fix import paths and wrong env variables
vllmellm Feb 26, 2025
024cfc5
rename importing module names from amd/rocm aiter package to avoid c…
vllmellm Feb 28, 2025
7cfe429
bugfix on wrong env variable spelling and add missing statement cond…
vllmellm Mar 1, 2025
41e7e4f
enabled VLLM_ROCM_USE_AITER in unit-tests
vllmellm Mar 1, 2025
5f668ea
include the AMD AITER package in rocm_base docker file
vllmellm Mar 1, 2025
8c5eb52
integrate AITER paged attention
vllmellm Mar 3, 2025
77cb436
bugfixes and disable rocm aiter paged attention
vllmellm Mar 3, 2025
942aa5b
Merge remote-tracking branch origin/main into aiter-integration
vllmellm Mar 4, 2025
4c41781
revert back the custom pa condition
vllmellm Mar 4, 2025
c09a740
enable AITER tgemm.mm per tensor scaled mm unittest
tjtanaa Mar 4, 2025
e19b7f5
bugfix: shuffle the weights when using aiter fmoe block scaled kernel
vllmellm Mar 4, 2025
11ac580
fix wrong environment variable in unit tests
vllmellm Mar 4, 2025
0865124
add aiter block gemm kernel and refactor aiter envs conditions
tjtanaa Mar 5, 2025
623dadb
add dispatch tests
vllmellm Mar 5, 2025
459bb02
add dispatch tests
vllmellm Mar 5, 2025
acc27ff
add dispatch tests
vllmellm Mar 5, 2025
11b6aba
bugfixes in layernorm and fix spelling mistakes
vllmellm Mar 5, 2025
0a6b8a0
enable rocm aiter paged attention
vllmellm Mar 5, 2025
1474828
bugfix: add the missing argument in dispatch
vllmellm Mar 5, 2025
b78114a
update rocm AITER commit version
vllmellm Mar 6, 2025
d20d757
bug fix
vllmellm Mar 6, 2025
7754c2e
add more comments for code documentation
vllmellm Mar 6, 2025
5e31c3e
disable some model tests
vllmellm Mar 7, 2025
d21c912
move rocm-aiter env flag checks to vllm.platforms.current_platform
vllmellm Mar 7, 2025
59f0208
bugfixes after refactoring the aiter modules enablement in current pl…
vllmellm Mar 7, 2025
17b4d6a
update AMD CI to skip certain test cases
vllmellm Mar 7, 2025
c32c31f
refactor dispatching for w8a8 scaled-mm
vllmellm Mar 7, 2025
e412824
Merge remote-tracking branch 'origin/main' into aiter-integration
vllmellm Mar 8, 2025
a5d7339
fix cutlass flag bug
vllmellm Mar 8, 2025
ce30f63
revert test requirements
vllmellm Mar 8, 2025
375e9db
revert test requirements
vllmellm Mar 8, 2025
fd3f4e3
revert test requirements
vllmellm Mar 8, 2025
f64bfe0
addressing PR review comments: fix isort ignores, revert back missing…
vllmellm Mar 11, 2025
c1297e5
add missing comment in fp8_utils
vllmellm Mar 11, 2025
529714c
Merge remote-tracking branch 'origin/main' into aiter-integration
vllmellm Mar 12, 2025
3859abc
Merge remote-tracking branch 'origin/main' into aiter-integration
vllmellm Mar 13, 2025
27 changes: 20 additions & 7 deletions Dockerfile.rocm_base
@@ -1,17 +1,19 @@
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLASLT_BRANCH="db8e93b4"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_BRANCH="6c0e7463"
ARG PYTORCH_VISION_BRANCH="v0.21.0"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="e1ec015"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

@@ -108,7 +110,7 @@ RUN git clone ${FA_REPO}
RUN cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
&& GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
@@ -129,7 +131,17 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl

ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter

ARG BASE_IMAGE
ARG HIPBLAS_COMMON_BRANCH
ARG HIPBLASLT_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
@@ -155,4 +167,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
22 changes: 19 additions & 3 deletions requirements-test.txt
@@ -23,6 +23,10 @@ anyio==4.6.2.post1
# via httpx
argcomplete==3.5.1
# via datamodel-code-generator
async-timeout==4.0.3
# via
# aiohttp
# redis
attrs==24.2.0
# via
# aiohttp
@@ -116,6 +120,10 @@ encodec==0.1.1
# via vocos
evaluate==0.4.3
# via lm-eval
exceptiongroup==1.2.2
# via
# anyio
# pytest
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@@ -544,9 +552,7 @@ sentence-transformers==3.2.1
sentencepiece==0.2.0
# via mistral-common
setuptools==75.8.0
# via
# pytablewriter
# torch
# via pytablewriter
six==1.16.0
# via
# python-dateutil
@@ -591,6 +597,12 @@ timm==1.0.11
# via -r requirements-test.in
tokenizers==0.21.0
# via transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.2.1
# via
# black
# pytest
torch==2.5.1
# via
# -r requirements-test.in
@@ -651,13 +663,17 @@ typepy==1.3.2
# tabledata
typing-extensions==4.12.2
# via
# anyio
# bitsandbytes
# black
# huggingface-hub
# librosa
# mistral-common
# multidict
# pqdm
# pydantic
# pydantic-core
# rich
# torch
tzdata==2024.2
# via pandas
22 changes: 17 additions & 5 deletions tests/kernels/test_moe.py
@@ -202,11 +202,15 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,

@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
def test_mixtral_moe(dtype: torch.dtype, use_rocm_aiter: bool, monkeypatch):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@@ -242,10 +246,18 @@ def test_mixtral_moe(dtype: torch.dtype):
        torch.bfloat16: 1e-2,
    }

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               vllm_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])
    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
        torch.testing.assert_close(hf_states.flatten(0, 1),
                                   vllm_states,
                                   rtol=0.01,
                                   atol=100)
    else:
        torch.testing.assert_close(hf_states.flatten(0, 1),
                                   vllm_states,
                                   rtol=mixtral_moe_tol[dtype],
                                   atol=mixtral_moe_tol[dtype])


@pytest.mark.parametrize("m", [1, 33, 64, 222])
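The ROCm-gated parametrization plus monkeypatch.setenv pattern used in test_mixtral_moe above is repeated in other tests touched by this PR (see test_granite.py below). A shared fixture could express it once; the sketch below is purely illustrative and not part of this diff, and the fixture name and ids are made up:

```python
# Hypothetical conftest.py helper (not in this PR): run a test once without
# AITER on every platform, and additionally with AITER enabled on ROCm.
import pytest

from vllm.platforms import current_platform


@pytest.fixture(params=[True, False] if current_platform.is_rocm() else [False],
                ids=lambda use_aiter: f"aiter={use_aiter}")
def use_rocm_aiter(request, monkeypatch):
    if request.param:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    return request.param
```

Any test that accepts use_rocm_aiter as an argument would then get the same environment handling without repeating it.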
93 changes: 92 additions & 1 deletion tests/model_executor/test_enabled_custom_ops.py
@@ -1,13 +1,26 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch.nn.functional as F

from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.fused_moe.fused_moe import (
    dispatch_fused_experts_func, dispatch_topk_func, rocm_aiter_fused_experts,
    rocm_aiter_topk_softmax, torch_vllm_inplace_fused_experts,
Review comment (Collaborator): will this cause import error for non-rocm platform?
Reply (vllmellm, Contributor, Mar 7, 2025): No, the rocm_aiter_* functions are wrappers that are only called when the ROCm platform is detected and the AITER-specific env vars are set to True.
Review comment (Collaborator): @vllmellm @tjtanaa Please rebase to resolve the conflict and then we will put the "ready" label to finalize the review. Thank you!
    torch_vllm_outplace_fused_experts, vllm_topk_softmax)
from vllm.model_executor.layers.layernorm import (
    RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
    rocm_aiter_rmsnorm2d_fwd_with_add)
from vllm.model_executor.layers.linear import (
    dispatch_unquantized_linear_func, rocm_aiter_tgemm_mm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    cutlass_scaled_mm, dispatch_w8a8_blockscale_func,
    rocm_aiter_gemm_a8w8_blockscale, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform


# Registered subclass for test
@@ -87,3 +100,81 @@ def test_enabled_ops_invalid(env: str):
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    topk_func = dispatch_topk_func()

    if current_platform.is_rocm() and int(use_rocm_aiter):
        assert topk_func == rocm_aiter_topk_softmax
    else:
        assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("inplace", [True, False])
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
                                monkeypatch):

    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    fused_experts_func = dispatch_fused_experts_func(inplace)
    if current_platform.is_rocm() and int(use_rocm_aiter):
        assert fused_experts_func == rocm_aiter_fused_experts
    elif inplace:
        assert fused_experts_func == torch_vllm_inplace_fused_experts
    else:
        assert fused_experts_func == torch_vllm_outplace_fused_experts


@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_block_gemm", ["0", "1"])
def test_block_gemm_dispatch(use_cutlass: bool, use_rocm_aiter: str,
                             use_rocm_aiter_block_gemm: str, monkeypatch):

    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_BLOCK_GEMM",
                       use_rocm_aiter_block_gemm)
    block_scale_func = dispatch_w8a8_blockscale_func(use_cutlass)

    if use_cutlass:
        assert block_scale_func == cutlass_scaled_mm
    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
            use_rocm_aiter_block_gemm):
        assert block_scale_func == rocm_aiter_gemm_a8w8_blockscale
    else:
        assert block_scale_func == w8a8_block_fp8_matmul


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
                           use_rocm_aiter_norm: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_NORM", use_rocm_aiter_norm)
    rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)

    if not add_residual:
        assert rms_norm_func == rms_norm
    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
            use_rocm_aiter_norm):
        assert rms_norm_func == rocm_aiter_rmsnorm2d_fwd_with_add
    else:
        assert rms_norm_func == fused_add_rms_norm


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_linear", ["0", "1"])
def test_unquantized_linear_dispatch(use_rocm_aiter: str,
                                     use_rocm_aiter_linear: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", use_rocm_aiter_linear)
    linear_func = dispatch_unquantized_linear_func()
    if current_platform.is_rocm() and int(use_rocm_aiter) and int(
            use_rocm_aiter_linear):
        assert linear_func == rocm_aiter_tgemm_mm
    else:
        assert linear_func == F.linear
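The dispatch tests above also answer the review question earlier in this file: each dispatcher returns the rocm_aiter_* wrapper only when the platform is ROCm and the relevant flags are set, and falls back to the existing implementation otherwise, so importing the wrappers is harmless on other platforms. The sketch below shows that shape for the RMSNorm case; the candidate function names come from the imports in this test file, but the dispatcher body is illustrative (the real code reads the flags through vLLM's env handling rather than os.environ):

```python
# Illustrative stand-in for dispatch_cuda_rmsnorm_func as exercised by
# test_rms_norm_dispatch above; not the actual vLLM implementation.
import os

from vllm.model_executor.layers.layernorm import (
    fused_add_rms_norm, rms_norm, rocm_aiter_rmsnorm2d_fwd_with_add)
from vllm.platforms import current_platform


def _aiter_enabled(per_kernel_flag: str) -> bool:
    """True only on ROCm with both the master and the per-kernel flag set."""
    return (current_platform.is_rocm()
            and os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get(per_kernel_flag, "0") == "1")


def pick_rmsnorm_func(add_residual: bool):
    if not add_residual:
        return rms_norm  # plain RMSNorm; the AITER path only covers fused add
    if _aiter_enabled("VLLM_ROCM_USE_AITER_NORM"):
        return rocm_aiter_rmsnorm2d_fwd_with_add  # AITER fused add + RMSNorm
    return fused_add_rms_norm  # default fused implementation
```

The same structure recurs in the other dispatchers tested above: cutlass_scaled_mm / rocm_aiter_gemm_a8w8_blockscale / w8a8_block_fp8_matmul for the block-scaled GEMM, and rocm_aiter_tgemm_mm / F.linear for the unquantized linear path.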
19 changes: 10 additions & 9 deletions tests/models/decoder_only/language/test_granite.py
@@ -5,6 +5,8 @@
"""
import pytest

from vllm.platforms import current_platform

from ...utils import check_logprobs_close

MODELS = [
@@ -18,15 +20,14 @@
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                dtype: str, max_tokens: int, num_logprobs: int,
                use_rocm_aiter: bool, monkeypatch) -> None:
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)