diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 5f122f6f74..490c819e38 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -30,7 +30,6 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import vllm_version_is
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -141,14 +140,8 @@ def reorder_batch(self, input_batch: "InputBatch",
 
     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table = (self.runner.input_batch.block_table.
-                           get_device_tensor()[:num_reqs])
-        else:
-            block_table = self.runner.input_batch.block_table[
-                0].get_device_tensor()
-            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
-                block_table[:num_reqs])
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
 
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index da17cf21da..2d522e4963 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -16,7 +16,6 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -239,12 +238,8 @@ def build(self,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table = (self.runner.input_batch.block_table.
-                           get_device_tensor()[:num_reqs])
-        else:
-            block_table = (self.runner.input_batch.block_table[0].
-                           get_device_tensor()[:num_reqs])
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 18037439c4..78b9d54e97 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -157,6 +157,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_att_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 GPUModelRunner.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
 
@@ -189,6 +207,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
+
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
@@ -535,10 +564,7 @@ def _process_reqs(
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        else:
-            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -952,16 +978,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
        kv_caches: Dict[str, torch.Tensor] = {}
-        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-                kv_cache_config=kv_cache_config,
-            )
 
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
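
For context, a minimal standalone sketch of the flattened block-table lookup that the simplified _process_reqs path performs after this change. This is illustrative only, not part of the patch: the block table, request indices, and token positions are toy values, and plain addition stands in for the patch's np.add(..., out=...).

import numpy as np

block_size = 4
max_num_blocks_per_req = 3

# Toy block table: 2 requests x up to 3 blocks each, holding physical block ids.
block_table_cpu = np.array([[7, 2, 5],
                            [1, 9, 0]])

# Three scheduled tokens: positions 0 and 5 of request 0, position 2 of request 1.
req_indices = np.array([0, 0, 1])
positions_np = np.array([0, 5, 2])

# Same arithmetic as the consolidated path: index into the flattened table,
# then combine the physical block number with the in-block offset.
block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions_np // block_size)
block_numbers = block_table_cpu.flatten()[block_table_indices]
block_offsets = positions_np % block_size
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)  # [28  9  6]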