[Hardware][Intel] Optimize CPU backend and add more performance tips #4971
Merged
Commits (21, all by bigPYJ1151):
980de13  Add IPEX Paged Att.
648d4c0  Fix
cc00133  Fix env
5e8b064  Refactor QKV shape in torch_sdpa to use fast code path.
686a41b  Refine
706d14e  Update doc
1647c27  Update docker image.
afe6262  Fix doc
76d319a  trigger
62708ef  trigger
f822617  fix
5fffea9  Fix
0cda257  Fix
b00a5a9  update
b88142a  Fix
fea13c9  Revert "Fix"
ce00ff0  Revert "Revert "Fix""
5779f70  Update IPEX
3930932  update
6c77c9e  update torch
bdf030a  Update README.md
New file, 120 added lines:

```python
from typing import Dict, List, Optional, Tuple

import intel_extension_for_pytorch.llm.modules as ipex_modules
import torch

from vllm import _custom_ops as ops


class PagedAttention:
    # Paged attention ops for the vLLM CPU backend, backed by IPEX
    # (Intel Extension for PyTorch) kernels.

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        # Layout: index 0 holds the key cache, index 1 the value cache; each
        # block stores block_size * num_kv_heads * head_size elements.
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
        *args,
    ) -> None:
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache,
            slot_mapping.flatten().int())

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
        *args,
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        block_size = value_cache.shape[2]
        # Map each query head to its KV head (supports grouped/multi-query
        # attention, where several query heads share one KV head).
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output, query.contiguous(), key_cache, value_cache, head_mapping,
            scale, block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)

        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
        *args,
    ) -> torch.Tensor:
        # Prefix-prefill attention is not implemented for this backend.
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
        *args,
    ) -> None:
        # Block swapping is not implemented for this backend.
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
        *args,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
```
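For orientation, here is a hypothetical usage sketch of this wrapper (not part of the diff). It assumes an IPEX build that provides the `llm.modules` API, that the module lives at `vllm.attention.ops.ipex_attn` (the file path is not shown above), and that tensors follow the usual vLLM conventions: query/key/value shaped `[num_tokens, num_heads, head_size]`, with `slot_mapping` entries equal to `block_number * block_size + block_offset`.

```python
# Hypothetical usage sketch; shapes and module path are assumptions.
import torch

from vllm.attention.ops.ipex_attn import PagedAttention  # assumed path

num_blocks, block_size = 128, 16
num_heads, num_kv_heads, head_size = 8, 8, 128
scale = head_size ** -0.5

# Allocate the CPU KV cache in the layout this backend expects:
# (2, num_blocks, block_size * num_kv_heads * head_size).
kv_shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                             num_kv_heads, head_size)
kv_cache = torch.zeros(kv_shape, dtype=torch.bfloat16, device="cpu")
key_cache, value_cache = PagedAttention.split_kv_cache(
    kv_cache, num_kv_heads, head_size)

# Write one token's key/value for a single sequence into block 0, slot 0.
key = torch.randn(1, num_kv_heads, head_size, dtype=torch.bfloat16)
value = torch.randn(1, num_kv_heads, head_size, dtype=torch.bfloat16)
slot_mapping = torch.tensor([0], dtype=torch.long)
PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache,
                                    slot_mapping, "auto", 1.0)

# Single-token decode attention over the cached context.
query = torch.randn(1, num_heads, head_size, dtype=torch.bfloat16)
block_tables = torch.tensor([[0]], dtype=torch.int32)
context_lens = torch.tensor([1], dtype=torch.int32)
out = PagedAttention.forward_decode(query, key_cache, value_cache,
                                    block_tables, context_lens,
                                    max_context_len=1,
                                    kv_cache_dtype="auto",
                                    num_kv_heads=num_kv_heads,
                                    scale=scale,
                                    alibi_slopes=None,
                                    kv_scale=1.0)
print(out.shape)  # expected: torch.Size([1, 8, 128])
```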
Review discussion:

Can't we simply require users to use IPEX? In which case do we have to use the PagedAttention kernel in vLLM?
Yes, once the APIs in IPEX become stable we will add IPEX to the requirements so users can use it directly. We want to keep the native kernel here to evaluate the latest features (e.g., 8-bit KV cache) before IPEX supports them in a public release.
I see. Thanks!
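The dispatch pattern the discussion implies could look roughly like the sketch below: prefer the IPEX kernels when `intel_extension_for_pytorch` is importable, otherwise fall back to vLLM's native CPU kernels. The module paths are assumptions for illustration, not the PR's actual wiring.

```python
# Minimal sketch of an optional-IPEX fallback; module paths are assumed.
try:
    import intel_extension_for_pytorch  # noqa: F401
    _USE_IPEX = True
except ImportError:
    _USE_IPEX = False

if _USE_IPEX:
    from vllm.attention.ops.ipex_attn import PagedAttention  # IPEX kernels
else:
    from vllm.attention.ops.paged_attn import PagedAttention  # native kernels
```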