[Hardware][Intel] Generate custom activation ops using torch.compile for CPU backend. #5446
Changes from all commits
9983bb0
3b44ad6
f9fb3b7
9b8e97f
e910051
1aaccff
1b39c10
7ea50d1
```diff
@@ -13,7 +13,8 @@
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 
 logger = init_logger(__name__)
@@ -78,6 +79,10 @@ def __init__(
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+
     def load_model(self) -> None:
         self.model = get_model(
             model_config=self.model_config,
```
```diff
@@ -89,6 +94,25 @@ def load_model(self) -> None:
             scheduler_config=self.scheduler_config,
             cache_config=self.cache_config)
 
+        # Apply torch.compile to custom ops
+        from vllm.model_executor.custom_op import CustomOp
+
+        def replace_model(model: torch.nn.Module):
+            for _, m in model.named_children():
+                if isinstance(m, CustomOp):
+                    m.forward_native = torch.compile(m.forward_native,
+                                                     dynamic=True,
+                                                     options={
+                                                         "fx_graph_cache":
+                                                         True,
+                                                         "cpp_wrapper": True,
+                                                         "dce": True
+                                                     })
+                replace_model(m)
+            return
+
+        replace_model(self.model)
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
```
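The hunk above walks the model tree recursively and, for every CustomOp child it finds, swaps forward_native for a torch.compile-wrapped version; the options dict switches on Inductor's FX graph cache, C++ wrapper codegen, and dead-code elimination. Below is a self-contained sketch of the same traversal on a toy model; ToyCustomOp and ToyMLP are illustrative stand-ins, not vLLM's actual classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCustomOp(nn.Module):
    """Illustrative stand-in for vllm.model_executor.custom_op.CustomOp."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU-and-mul style activation written in plain PyTorch ops, which
        # Inductor can lower to fused CPU code.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


class ToyMLP(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 16)
        self.act = ToyCustomOp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.proj(x))


def compile_custom_ops(model: nn.Module) -> None:
    # Same traversal shape as the PR: recurse into every child and compile
    # the native path of any matching op in place.
    for _, m in model.named_children():
        if isinstance(m, ToyCustomOp):
            m.forward_native = torch.compile(m.forward_native,
                                             dynamic=True,
                                             options={"fx_graph_cache": True})
        compile_custom_ops(m)


model = ToyMLP()
x = torch.randn(2, 8)
ref = model(x)
compile_custom_ops(model)
# The compiled path should match eager execution up to float rounding.
torch.testing.assert_close(model(x), ref)
```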
```diff
@@ -134,26 +158,29 @@ def _prepare_prompt(
                 multi_modal_kwargs_list[k].append(v)
 
             # Compute the slot mapping.
-            block_table = seq_group_metadata.block_tables[seq_id]
-            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
-            # where start_idx is max(0, seq_len - sliding_window).
-            # For example, if the prompt len is 10, sliding window is 8, and
-            # block size is 4, the first two tokens are masked and the slot
-            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
-            start_idx = 0
-            if self.sliding_window is not None:
-                start_idx = max(0, seq_len - self.sliding_window)
-
-            for i in range(computed_len, seq_len):
-                if i < start_idx:
-                    slot_mapping.append(_PAD_SLOT_ID)
-                    continue
-
-                block_number = block_table[i //
-                                           self.block_size]  # type: ignore
-                block_offset = i % self.block_size  # type: ignore
-                slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+            if seq_group_metadata.block_tables is None:
+                block_table = []
+            else:
+                block_table = seq_group_metadata.block_tables[seq_id]
+            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID
+            # , where start_idx is max(0, seq_len - sliding_window).
+            # For example, if the prompt len is 10, sliding window is 8, and
+            # block size is 4, the first two tokens are masked and the slot
+            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
+            start_idx = 0
+            if self.sliding_window is not None:
+                start_idx = max(0, seq_len - self.sliding_window)
+
+            for i in range(computed_len, seq_len):
+                if i < start_idx:
+                    slot_mapping.append(_PAD_SLOT_ID)
+                    continue
+
+                block_number = block_table[i //
+                                           self.block_size]  # type: ignore
+                block_offset = i % self.block_size  # type: ignore
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)
 
         multi_modal_kwargs = {
             k: torch.cat(v, dim=0).to(self.device)
```
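For reference, the slot mapping computed above places token i at physical KV-cache slot block_table[i // block_size] * block_size + i % block_size, and pads positions that fall outside the sliding window. A minimal sketch reproducing the worked example in the comment, assuming _PAD_SLOT_ID = -1 and a block table of [0, 1, 0] (the third logical block reusing physical block 0, as sliding-window attention allows):

```python
# Standalone reproduction of the slot-mapping example in the diff comment.
_PAD_SLOT_ID = -1


def compute_slot_mapping(block_table, seq_len, block_size, sliding_window):
    start_idx = max(0, seq_len - sliding_window) if sliding_window else 0
    slot_mapping = []
    for i in range(seq_len):
        if i < start_idx:
            # Tokens outside the sliding window never hit the KV cache.
            slot_mapping.append(_PAD_SLOT_ID)
            continue
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot_mapping.append(block_number * block_size + block_offset)
    return slot_mapping


# Prompt length 10, sliding window 8, block size 4 -> first two tokens masked.
print(compute_slot_mapping([0, 1, 0], 10, 4, 8))
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]
```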
```diff
@@ -361,3 +388,45 @@ def execute_model(
             sampling_metadata=sampling_metadata,
         )
         return output
+
+    @torch.inference_mode()
+    def profile_run(self) -> None:
+        # Warming up the model with batchsize = [1, max_num_seqs,
+        # max_num_batched_tokens], to generate corresponding operators
+        # using torch.compile.
+
+        model_config = self.model_config
+        vlm_config = self.vision_language_config
+
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+        max_num_batched_tokens = min(
+            self.scheduler_config.max_num_batched_tokens,
+            self.model_config.max_model_len,
+        )
+        max_num_seqs = self.scheduler_config.max_num_seqs
+
+        assert self.lora_config is None
+
+        # Run the model with the dummy inputs.
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        kv_caches = [None] * num_layers
+
+        for seq_len in [1, max_num_seqs, max_num_batched_tokens]:
+            if vlm_config:
+                seq_data, dummy_multi_modal_data = (
+                    MULTIMODAL_REGISTRY.dummy_data_for_profiling(
+                        seq_len, model_config, vlm_config))
+            else:
+                seq_data = SequenceData([0] * seq_len)
+                dummy_multi_modal_data = None
+            seq = SequenceGroupMetadata(
+                request_id=str(0),
+                is_prompt=True,
+                seq_data={0: seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+                lora_request=None,
+                multi_modal_data=dummy_multi_modal_data,
+            )
+            self.execute_model([seq], kv_caches)
+        return
```
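The new profile_run drives execute_model once at sequence lengths 1, max_num_seqs, and max_num_batched_tokens so that every compiled custom op is traced before serving begins; with dynamic=True, later requests at other lengths should reuse the generated kernels rather than pay compilation latency on the hot path. A minimal sketch of that warm-up idea on a single compiled function; the sizes below are arbitrary examples, not vLLM's defaults.

```python
import time

import torch
import torch.nn.functional as F


@torch.compile(dynamic=True)
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


# Warm up at a few representative lengths, mirroring profile_run's
# [1, max_num_seqs, max_num_batched_tokens] sweep.
for warm_len in (1, 256, 4096):
    silu_and_mul(torch.randn(warm_len, 128))

# A call at a new length should hit the dynamic-shape kernel and not need
# to recompile, so the serving path stays fast.
start = time.perf_counter()
silu_and_mul(torch.randn(777, 128))
print(f"post-warm-up call took {time.perf_counter() - start:.4f}s")
```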