From 980de1369583df6ca8653bb407260d325e545948 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Mon, 29 Apr 2024 05:48:56 +0000 Subject: [PATCH 01/21] Add IPEX Paged Att. --- vllm/attention/backends/torch_sdpa.py | 11 ++- vllm/attention/ops/ipex_attn.py | 129 ++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 vllm/attention/ops/ipex_attn.py diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9b50adec524..bcfc81207c4 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,5 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" +import os + from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type @@ -8,8 +10,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.utils import is_cpu + +if is_cpu() and os.getenv("VLLM_CPU_IPEX", 0): + from vllm.attention.ops.ipex_attn import PagedAttention +else: + from vllm.attention.ops.paged_attn import PagedAttention class TorchSDPABackend(AttentionBackend): diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py new file mode 100644 index 00000000000..4aa1623c610 --- /dev/null +++ b/vllm/attention/ops/ipex_attn.py @@ -0,0 +1,129 @@ +from typing import Dict, List, Optional, Tuple, Type + +import torch +import intel_extension_for_pytorch.llm.modules as ipex_modules + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten().int() + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + kv_scale: float, + *args, + ) -> torch.Tensor: + output = torch.empty_like(query) + block_size = value_cache.shape[3] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads,1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, + query.contiguous(), + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes + ) + + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + prompt_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_subquery_len: int, + alibi_slopes: Optional[torch.Tensor], + *args, + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + *args, + ) -> None: + pass + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + raise NotImplementedError \ No newline at end of file From 648d4c0c8bca545a10109f16fd30b820f556bf6b Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Tue, 14 May 2024 02:17:33 +0000 Subject: [PATCH 02/21] Fix --- vllm/attention/ops/ipex_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 4aa1623c610..7ff0c5a1110 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -69,7 +69,7 @@ def forward_decode( *args, ) -> torch.Tensor: output = torch.empty_like(query) - block_size = value_cache.shape[3] + block_size = value_cache.shape[1] head_mapping = torch.arange( 0, num_kv_heads, @@ -78,7 +78,7 @@ def forward_decode( ).view(num_kv_heads,1).repeat_interleave(query.size(1) // num_kv_heads).flatten() ipex_modules.PagedAttention.single_query_cached_kv_attention( output, - query.contiguous(), + query, key_cache, value_cache, head_mapping, From cc00133d3d9b5bb25ba90b9dd2d3cd939e1c2c8e Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Tue, 14 May 2024 07:59:56 +0000 Subject: [PATCH 03/21] Fix env --- vllm/attention/backends/torch_sdpa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index bcfc81207c4..168d5d146a4 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -13,7 +13,7 @@ from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu -if is_cpu() and os.getenv("VLLM_CPU_IPEX", 0): +if is_cpu() and os.getenv("VLLM_CPU_IPEX", "0") == "1": from vllm.attention.ops.ipex_attn import PagedAttention else: from vllm.attention.ops.paged_attn import PagedAttention From 5e8b064744da45cfa23b8caf0b8b4c657a54f66a Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 17 May 2024 05:31:22 +0000 Subject: [PATCH 04/21] Refactor QKV shape in torch_sdpa to use fast code path. Co-authored-by: Jianan Gu --- vllm/attention/backends/torch_sdpa.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 168d5d146a4..f8f23ad8f82 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -204,13 +204,13 @@ def forward( attn_metadata.attn_bias): end = start + seq_len sub_out = scaled_dot_product_attention( - query[:, start:end, :], - key[:, start:end, :], - value[:, start:end, :], + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], attn_mask=mask, dropout_p=0.0, is_causal=not self.need_mask, - scale=self.scale).movedim(query.dim() - 2, 0) + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) output[start:end, :, :] = sub_out start = end else: @@ -255,7 +255,7 @@ def _make_alibi_bias( num_heads = alibi_slopes.shape[0] bias = bias[None, :].repeat((num_heads, 1, 1)) - bias.mul_(alibi_slopes[:, None, None]) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) inf_mask = torch.empty( (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) From 686a41bc5684f0093e1c4cd6f455aee4771eaa1a Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Wed, 22 May 2024 05:44:27 +0000 Subject: [PATCH 05/21] Refine --- requirements-cpu.txt | 1 + vllm/attention/backends/torch_sdpa.py | 8 +++--- vllm/attention/ops/ipex_attn.py | 39 +++++++++------------------ vllm/envs.py | 6 +++++ 4 files changed, 24 insertions(+), 30 deletions(-) diff --git a/requirements-cpu.txt b/requirements-cpu.txt index b739642d8d3..9a7d5d204ad 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,5 @@ # Dependencies for x86_64 CPUs torch == 2.3.0+cpu +intel-extension-for-pytorch >= 2.3.0 triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index f8f23ad8f82..e0cc4e601dc 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,19 +1,18 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" -import os - from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention +import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu -if is_cpu() and os.getenv("VLLM_CPU_IPEX", "0") == "1": +if is_cpu() and envs.VLLM_CPU_IPEX: from vllm.attention.ops.ipex_attn import PagedAttention else: from vllm.attention.ops.paged_attn import PagedAttention @@ -210,7 +209,8 @@ def forward( attn_mask=mask, dropout_p=0.0, is_causal=not self.need_mask, - scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + scale=self.scale).squeeze(0).movedim( + query.dim() - 2, 0) output[start:end, :, :] = sub_out start = end else: diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 7ff0c5a1110..ac9d2c004e0 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -1,7 +1,8 @@ -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple -import torch import intel_extension_for_pytorch.llm.modules as ipex_modules +import torch + class PagedAttention: @@ -46,12 +47,8 @@ def write_to_paged_cache( *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( - key, - value, - key_cache, - value_cache, - slot_mapping.flatten().int() - ) + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) @staticmethod def forward_decode( @@ -75,21 +72,13 @@ def forward_decode( num_kv_heads, device="cpu", dtype=torch.int32, - ).view(num_kv_heads,1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, - query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes - ) - + output, query, key_cache, value_cache, head_mapping, scale, + block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + return output @staticmethod @@ -116,7 +105,7 @@ def swap_blocks( src_to_dst: Dict[int, int], *args, ) -> None: - pass + raise NotImplementedError @staticmethod def copy_blocks( @@ -124,6 +113,4 @@ def copy_blocks( src_to_dists: Dict[int, List[int]], *args, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/vllm/envs.py b/vllm/envs.py index bef343d0842..c66686b537e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -27,6 +27,7 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_CPU_IPEX: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_TARGET_DEVICE: str = "cuda" @@ -203,6 +204,11 @@ "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + # Using Paged Attention kernel from intel-extension-for-pytorch + # for the CPU backend. + "VLLM_CPU_IPEX": + lambda: bool(int(os.getenv("VLLM_CPU_IPEX", 0))), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. From 706d14e0afa975962786a5ffdae3c1da588d188e Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Wed, 22 May 2024 06:21:28 +0000 Subject: [PATCH 06/21] Update doc --- .../getting_started/cpu-installation.rst | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 5270253cae9..b82ec91ea2f 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -10,6 +10,7 @@ Table of contents: #. :ref:`Requirements ` #. :ref:`Quick start using Dockerfile ` #. :ref:`Build from source ` +#. :ref:`Intel Extension for PyTorch ` #. :ref:`Performance tips ` .. _cpu_backend_requirements: @@ -18,7 +19,7 @@ Requirements ------------ * OS: Linux -* Compiler: gcc/g++>=12.3.0 (recommended) +* Compiler: gcc/g++>=12.3.0 (optional, recommended) * Instruction set architecture (ISA) requirement: AVX512 is required. .. _cpu_backend_quick_start_dockerfile: @@ -41,7 +42,7 @@ Quick start using Dockerfile Build from source ----------------- -- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: +- First, install recommended compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: .. code-block:: console @@ -70,6 +71,17 @@ Build from source - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building. +.. _ipex_guidance: + +Intel Extension for PyTorch +--------------------------- + +- `Intel Extension for PyTorch (IPEX) ` extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. IPEX can be enabled in the CPU backend by ``VLLM_CPU_IPEX=1``, for example: + +.. code-block:: console + + $ VLLM_CPU_IPEX=1 python examples/offline_inference.py + .. _cpu_backend_performance_tips: Performance tips @@ -77,6 +89,15 @@ Performance tips - vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. +- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run: + +.. code-block:: console + + $ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library + $ find / -name *libtcmalloc* # find the dynamic link library path + $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD + $ python examples/offline_inference.py # run vLLM + - vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription. - If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading. From 1647c27a64622d45258143c3d26ee1c7de06f21c Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Wed, 22 May 2024 06:44:53 +0000 Subject: [PATCH 07/21] Update docker image. --- Dockerfile.cpu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 403a1cd0391..2d7744230e7 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -3,9 +3,11 @@ FROM ubuntu:22.04 AS cpu-test-1 RUN apt-get update -y \ - && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc + RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy @@ -21,6 +23,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ -RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks CMD ["/bin/bash"] From afe626253777057208d041a48b1018d576758871 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Wed, 22 May 2024 08:02:07 +0000 Subject: [PATCH 08/21] Fix doc --- docs/source/getting_started/cpu-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index b82ec91ea2f..9b8a3e2b9e3 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -76,7 +76,7 @@ Build from source Intel Extension for PyTorch --------------------------- -- `Intel Extension for PyTorch (IPEX) ` extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. IPEX can be enabled in the CPU backend by ``VLLM_CPU_IPEX=1``, for example: +- `Intel Extension for PyTorch (IPEX) `_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. IPEX can be enabled in the CPU backend by ``VLLM_CPU_IPEX=1``, for example: .. code-block:: console From 76d319ac9a2323dd2a318f28fe665d895651ae07 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Thu, 23 May 2024 01:56:42 +0000 Subject: [PATCH 09/21] trigger From 62708ef8bea4c05d4cb95ce4a898046eed511d7e Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Thu, 23 May 2024 03:01:33 +0000 Subject: [PATCH 10/21] trigger From f822617391250cd497787695671f38687c7793d3 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 24 May 2024 05:34:35 +0000 Subject: [PATCH 11/21] fix --- vllm/attention/ops/ipex_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index ac9d2c004e0..f463e2d3cf6 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -75,8 +75,8 @@ def forward_decode( ).view(num_kv_heads, 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, query, key_cache, value_cache, head_mapping, scale, - block_tables, context_lens, block_size, max_context_len, + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes) return output From 5fffea9e32f959c532d6588d09c922bdc7e6145c Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 24 May 2024 07:10:18 +0000 Subject: [PATCH 12/21] Fix --- Dockerfile.cpu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 2d7744230e7..d4db0876a22 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -8,6 +8,9 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc +# TODO: Remove it after the 2.3.1 offical release +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgit1e79114-cp310-cp310-linux_x86_64.whl + RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy From 0cda257bab0c4186947829241cf3dc96f850509a Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 24 May 2024 07:12:41 +0000 Subject: [PATCH 13/21] Fix --- Dockerfile.cpu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index d4db0876a22..61d1a3ee02f 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -8,7 +8,7 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc -# TODO: Remove it after the 2.3.1 offical release +# TODO: Remove it after the 2.3.1 official release RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgit1e79114-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ From b00a5a9e5a6b63cfcd142f25427f3e9a5ad8f16c Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Tue, 28 May 2024 03:14:17 +0000 Subject: [PATCH 14/21] update --- docs/source/getting_started/cpu-installation.rst | 6 ++---- requirements-cpu.txt | 1 - vllm/attention/backends/torch_sdpa.py | 8 +++++--- vllm/envs.py | 6 ------ 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 9b8a3e2b9e3..a9544e8a59a 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -76,11 +76,9 @@ Build from source Intel Extension for PyTorch --------------------------- -- `Intel Extension for PyTorch (IPEX) `_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. IPEX can be enabled in the CPU backend by ``VLLM_CPU_IPEX=1``, for example: +- `Intel Extension for PyTorch (IPEX) `_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. -.. code-block:: console - - $ VLLM_CPU_IPEX=1 python examples/offline_inference.py +- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed. .. _cpu_backend_performance_tips: diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 9a7d5d204ad..b739642d8d3 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,5 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.3.0+cpu -intel-extension-for-pytorch >= 2.3.0 triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index e0cc4e601dc..4b08cce99af 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -6,14 +6,16 @@ import torch from torch.nn.functional import scaled_dot_product_attention -import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu -if is_cpu() and envs.VLLM_CPU_IPEX: - from vllm.attention.ops.ipex_attn import PagedAttention +if is_cpu(): + try: + from vllm.attention.ops.ipex_attn import PagedAttention + except ImportError: + from vllm.attention.ops.paged_attn import PagedAttention else: from vllm.attention.ops.paged_attn import PagedAttention diff --git a/vllm/envs.py b/vllm/envs.py index c66686b537e..bef343d0842 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -27,7 +27,6 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 - VLLM_CPU_IPEX: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_TARGET_DEVICE: str = "cuda" @@ -204,11 +203,6 @@ "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), - # Using Paged Attention kernel from intel-extension-for-pytorch - # for the CPU backend. - "VLLM_CPU_IPEX": - lambda: bool(int(os.getenv("VLLM_CPU_IPEX", 0))), - # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. From b88142aaacb5876398d2599c2f48986637bba711 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Thu, 30 May 2024 05:43:26 +0000 Subject: [PATCH 15/21] Fix Fix Fix trigger --- Dockerfile.cpu | 2 +- vllm/attention/ops/ipex_attn.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 61d1a3ee02f..725cc2a9436 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -9,7 +9,7 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc # TODO: Remove it after the 2.3.1 official release -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgit1e79114-cp310-cp310-linux_x86_64.whl +RUN unset http_proxy && pip install http://mlpc.intel.com/downloads/cpu/ipex-2.3.100/rc0_vllm_0528/intel_extension_for_pytorch-2.3.0+gitac44227-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index f463e2d3cf6..5a5317b6500 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -3,6 +3,8 @@ import intel_extension_for_pytorch.llm.modules as ipex_modules import torch +from vllm import _custom_ops as ops + class PagedAttention: @@ -30,9 +32,9 @@ def split_kv_cache( num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size) + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size) + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) return key_cache, value_cache @staticmethod @@ -66,7 +68,7 @@ def forward_decode( *args, ) -> torch.Tensor: output = torch.empty_like(query) - block_size = value_cache.shape[1] + block_size = value_cache.shape[2] head_mapping = torch.arange( 0, num_kv_heads, @@ -113,4 +115,6 @@ def copy_blocks( src_to_dists: Dict[int, List[int]], *args, ) -> None: - raise NotImplementedError + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) From fea13c9b228ac42ce5439e1e08d9517a591ca251 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 31 May 2024 05:12:56 +0000 Subject: [PATCH 16/21] Revert "Fix" This reverts commit 58c036ad079bab6d4a7beccae735c096e2818e37. --- Dockerfile.cpu | 2 +- vllm/attention/ops/ipex_attn.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 725cc2a9436..61d1a3ee02f 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -9,7 +9,7 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc # TODO: Remove it after the 2.3.1 official release -RUN unset http_proxy && pip install http://mlpc.intel.com/downloads/cpu/ipex-2.3.100/rc0_vllm_0528/intel_extension_for_pytorch-2.3.0+gitac44227-cp310-cp310-linux_x86_64.whl +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgit1e79114-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 5a5317b6500..f463e2d3cf6 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -3,8 +3,6 @@ import intel_extension_for_pytorch.llm.modules as ipex_modules import torch -from vllm import _custom_ops as ops - class PagedAttention: @@ -32,9 +30,9 @@ def split_kv_cache( num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size) value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size) return key_cache, value_cache @staticmethod @@ -68,7 +66,7 @@ def forward_decode( *args, ) -> torch.Tensor: output = torch.empty_like(query) - block_size = value_cache.shape[2] + block_size = value_cache.shape[1] head_mapping = torch.arange( 0, num_kv_heads, @@ -115,6 +113,4 @@ def copy_blocks( src_to_dists: Dict[int, List[int]], *args, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) + raise NotImplementedError From ce00ff0dfb68641d25dc2198295cf0cbd886e59e Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Tue, 4 Jun 2024 05:58:21 +0000 Subject: [PATCH 17/21] Revert "Revert "Fix"" This reverts commit 3861c15e282062c8c5165ce01aa93972280ca92a. --- Dockerfile.cpu | 2 +- vllm/attention/ops/ipex_attn.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 61d1a3ee02f..725cc2a9436 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -9,7 +9,7 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc # TODO: Remove it after the 2.3.1 official release -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgit1e79114-cp310-cp310-linux_x86_64.whl +RUN unset http_proxy && pip install http://mlpc.intel.com/downloads/cpu/ipex-2.3.100/rc0_vllm_0528/intel_extension_for_pytorch-2.3.0+gitac44227-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index f463e2d3cf6..5a5317b6500 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -3,6 +3,8 @@ import intel_extension_for_pytorch.llm.modules as ipex_modules import torch +from vllm import _custom_ops as ops + class PagedAttention: @@ -30,9 +32,9 @@ def split_kv_cache( num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, -1, num_kv_heads, head_size) + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, -1, num_kv_heads, head_size) + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) return key_cache, value_cache @staticmethod @@ -66,7 +68,7 @@ def forward_decode( *args, ) -> torch.Tensor: output = torch.empty_like(query) - block_size = value_cache.shape[1] + block_size = value_cache.shape[2] head_mapping = torch.arange( 0, num_kv_heads, @@ -113,4 +115,6 @@ def copy_blocks( src_to_dists: Dict[int, List[int]], *args, ) -> None: - raise NotImplementedError + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) From 5779f7007fe287af493a9e78a0da126c530df137 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Tue, 4 Jun 2024 05:59:33 +0000 Subject: [PATCH 18/21] Update IPEX --- Dockerfile.cpu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 725cc2a9436..31ec1054b67 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -9,7 +9,7 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc # TODO: Remove it after the 2.3.1 official release -RUN unset http_proxy && pip install http://mlpc.intel.com/downloads/cpu/ipex-2.3.100/rc0_vllm_0528/intel_extension_for_pytorch-2.3.0+gitac44227-cp310-cp310-linux_x86_64.whl +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgitac44227-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy From 3930932899326453e185d1cfe6cbf2972612dc19 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Fri, 7 Jun 2024 05:12:42 +0000 Subject: [PATCH 19/21] update --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f333e9fee52..1cd1461edf4 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs and AMD GPUs +- Support NVIDIA GPUs, AMD GPUs and Intel CPUs - (Experimental) Prefix caching support - (Experimental) Multi-lora support From 6c77c9ed34c9c151c98b35610e090cc0d369b0d7 Mon Sep 17 00:00:00 2001 From: "jiang1.li" Date: Thu, 13 Jun 2024 03:04:13 +0000 Subject: [PATCH 20/21] update torch --- Dockerfile.cpu | 3 +-- requirements-cpu.txt | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 31ec1054b67..777bb08296e 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -8,8 +8,7 @@ RUN apt-get update -y \ RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc -# TODO: Remove it after the 2.3.1 official release -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.0%2Bgitac44227-cp310-cp310-linux_x86_64.whl +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy diff --git a/requirements-cpu.txt b/requirements-cpu.txt index b739642d8d3..8b7d86e6862 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.3.0+cpu +torch == 2.3.1+cpu triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file From bdf030a820508bc79c21646e59104fc68bb18035 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 13 Jun 2024 15:54:46 +0800 Subject: [PATCH 21/21] Update README.md Co-authored-by: Woosuk Kwon --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1cd1461edf4..8a6b7894904 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD GPUs and Intel CPUs +- Support NVIDIA GPUs, AMD GPUs, and Intel CPUs - (Experimental) Prefix caching support - (Experimental) Multi-lora support