From a369a4743ece62308b8da54b273f78a293bb9320 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 24 Jan 2025 09:45:26 +0800
Subject: [PATCH] Revert "[core] separate builder init and builder prepare
 for each batch (#12253)"

This reverts commit 66818e5b63818a286756653816772faa8622ad89.
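
The reverted change kept one builder alive on the model runner and re-armed
it with a per-batch prepare() call; with this revert, a fresh builder is
constructed for every batch again. A minimal sketch of the restored
lifecycle (illustrative only: `Builder` and `run_batch` are hypothetical
stand-ins, not vLLM classes):

    from typing import List, Optional

    class Builder:
        """Stand-in for a per-batch model-input builder."""

        def __init__(self, runner,
                     finished_requests_ids: Optional[List[str]] = None):
            # All per-batch state is (re)created here; the object is
            # discarded once build() returns.
            self.runner = runner
            self.finished_requests_ids = finished_requests_ids
            self.seq_group_metadata_list: List[object] = []

        def add_seq_group(self, seq_group_metadata) -> None:
            self.seq_group_metadata_list.append(seq_group_metadata)

        def build(self) -> List[object]:
            return self.seq_group_metadata_list  # stand-in for real tensors

    def run_batch(runner, seq_group_metadata_list,
                  finished_requests_ids=None):
        # Post-revert pattern: one new builder per batch.
        builder = Builder(runner, finished_requests_ids)
        for sgm in seq_group_metadata_list:
            builder.add_seq_group(sgm)
        return builder.build()
        # The reverted commit instead cached `self.builder` on the runner
        # and called `self.builder.prepare(finished_requests_ids)` here.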
---
 vllm/attention/backends/abstract.py         | 11 +++----
 vllm/attention/backends/flash_attn.py       | 11 +++----
 vllm/attention/backends/flashinfer.py       | 14 ++++----
 vllm/attention/backends/placeholder_attn.py |  8 ++---
 vllm/attention/backends/torch_sdpa.py       |  5 +--
 vllm/attention/backends/utils.py            | 13 ++++----
 vllm/worker/cpu_model_runner.py             | 24 ++++--------
 vllm/worker/model_runner.py                 | 36 +++++++---------
 vllm/worker/model_runner_base.py            |  5 ---
 vllm/worker/xpu_model_runner.py             | 10 ++----
 10 files changed, 47 insertions(+), 90 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 8027a52b82ff..e55c3f7e8795 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -65,6 +65,11 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
+    @classmethod
+    def make_metadata_builder(cls, *args,
+                              **kwargs) -> "AttentionMetadataBuilder":
+        return cls.get_builder_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -213,12 +218,6 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
 
     @abstractmethod
     def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
-        """Create the builder, remember some configuration and parameters."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def prepare(self) -> None:
-        """Prepare for one batch."""
         raise NotImplementedError
 
     @abstractmethod
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 1be099283e47..566811573bd1 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -380,12 +380,6 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
-    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -399,6 +393,11 @@ def prepare(self):
         self.num_decode_tokens = 0
         self.has_prefix_cache_hit = False
 
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool, prefix_cache_hit: bool):
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 3135b0b40534..be869a84b253 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -489,14 +489,6 @@ def advance_step(self,
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
-    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -509,6 +501,12 @@ def prepare(self):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
     # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
     # for the precise definition of the following fields.
     # An example:
diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py
index 826311896d1d..d2dc0d6cf0a5 100644
--- a/vllm/attention/backends/placeholder_attn.py
+++ b/vllm/attention/backends/placeholder_attn.py
@@ -255,11 +255,6 @@ class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-    def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
@@ -270,6 +265,9 @@ def prepare(self):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index c3b2398b4e63..617a5a5abeef 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -282,10 +282,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
-        self.input_builder = input_builder
-
-    def prepare(self):
-        self.input_data = self.input_builder.input_data
+        self.input_data = input_builder.input_data
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 84fe89b7df36..8ceeaf48bb19 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -122,13 +122,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     _metadata_cls: Type[TAttentionMetadata]
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
-    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -141,6 +134,12 @@ def prepare(self):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 4b429b67b36f..abbf6450ab7f 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -144,7 +144,9 @@ def __init__(self,
                  runner: "CPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
+
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -154,17 +156,10 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
-        if self.runner.attn_backend is not None:
-            # spec decode (e.g. Medusa) does not have atten backend
-            attn_backend = self.runner.attn_backend
-            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
-
-    def prepare(self,
-                finished_requests_ids: Optional[List[str]] = None) -> None:
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
-        self.att_metadata_builder.prepare()
+        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
+            self)
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -436,7 +431,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
-    builder: ModelInputForCPUBuilder
 
     def __init__(
         self,
@@ -483,10 +477,6 @@ def __init__(
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
-        if hasattr(self, "_builder_cls"):
-            # multi-step model runner does not have `_builder_cls`
-            self.builder = self._builder_cls(weakref.proxy(self))
-
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
@@ -532,10 +522,10 @@ def _prepare_model_input_tensors(
             metadata for possible additional steps, e.g., sampling.
 
         """
-        self.builder.prepare(finished_requests_ids)
-        self.builder.set_seq_group_list(seq_group_metadata_list)
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        builder.set_seq_group_list(seq_group_metadata_list)
 
-        return self.builder.build()  # type: ignore
+        return builder.build()  # type: ignore
 
     # sampler property will be used by spec_decode_worker
     @property
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index cf2f1c6b3b87..7d483b3a9f5a 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -455,13 +455,17 @@ def __init__(self,
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
+        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True
 
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
         # Attention metadata inputs.
-        if self.attn_backend is not None:
-            # spec decode (e.g. Medusa) does not have atten backend
-            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
-                weakref.proxy(self))
+        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
+            weakref.proxy(self))
 
         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
@@ -473,17 +477,6 @@ def __init__(self,
         self.block_aligned_sliding_window = \
             self.sliding_window_blocks * self.block_size
 
-    def prepare(self,
-                finished_requests_ids: Optional[List[str]] = None) -> None:
-        self.finished_requests_ids = finished_requests_ids
-
-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
-
-        self.attn_metadata_builder.prepare()
-
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -998,7 +991,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
     _model_input_cls: Type[TModelInputForGPU]
     _builder_cls: Type[ModelInputForGPUBuilder]
-    builder: ModelInputForGPUBuilder
 
     def __init__(
         self,
@@ -1099,10 +1091,6 @@ def __init__(
             SamplingMetadataCache() \
                 if self.parallel_config.pipeline_parallel_size == 1 else None
 
-        if hasattr(self, "_builder_cls"):
-            # multi-step model runner does not have `_builder_cls`
-            self.builder = self._builder_cls(weakref.proxy(self))
-
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1208,13 +1196,13 @@ def _prepare_model_input_tensors(
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        self.builder.prepare(finished_requests_ids)
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            self.builder.add_seq_group(seq_group_metadata)
+            builder.add_seq_group(seq_group_metadata)
 
-        self.builder.reset_cached_inter_data()
+        builder.reset_cached_inter_data()
 
-        return self.builder.build()  # type: ignore
+        return builder.build()  # type: ignore
 
     @contextmanager
     def set_in_profile_run(self):
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index aef4bdcdd4bf..acfd6d0b03f6 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -200,11 +200,6 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
     """A builder to create ModelRunnerInputBase objects.
     """
 
-    @abstractmethod
-    def prepare(self,
-                finished_requests_ids: Optional[List[str]] = None) -> None:
-        raise NotImplementedError
-
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""
""" - @abstractmethod - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - raise NotImplementedError - @abstractmethod def add_seq_group(self, seq_group_metadata): """TBA""" diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index b7b7b7227b22..ffe8c3219dbe 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -113,6 +113,7 @@ def __init__(self, runner: "XPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend @@ -120,10 +121,6 @@ def __init__(self, self.block_size = self.runner.block_size self.device = self.runner.device - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) @@ -413,8 +410,6 @@ def __init__( SamplingMetadataCache() \ if self.parallel_config.pipeline_parallel_size == 1 else None - self.builder = self._builder_cls(weakref.proxy(self)) - def load_model(self) -> None: with DeviceMemoryProfiler() as m: self.model = get_model(vllm_config=self.vllm_config) @@ -524,8 +519,7 @@ def _prepare_model_input_tensors( metadata for possible additional steps, e.g., sampling. """ - builder = self.builder - builder.prepare(finished_requests_ids) + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: builder.add_seq_group(seq_group_metadata)