@@ -457,17 +457,13 @@ def __init__(self,
457
457
self .enable_prompt_adapter = (self .runner .prompt_adapter_config
458
458
is not None )
459
459
self .multi_modal_input_mapper = self .runner .multi_modal_input_mapper
460
- self .finished_requests_ids = finished_requests_ids
461
460
self .decode_only = True
462
461
463
- # Intermediate data (data in CPU before going to GPU) for
464
- # the current sequence group.
465
- self .inter_data_list : List [
466
- ModelInputForGPUBuilder .InterDataForSeqGroup ] = []
467
-
468
462
# Attention metadata inputs.
469
- self .attn_metadata_builder = self .attn_backend .make_metadata_builder (
470
- weakref .proxy (self ))
463
+ if self .attn_backend is not None :
464
+ # spec decode (e.g. Medusa) does not have atten backend
465
+ self .attn_metadata_builder = self .attn_backend .get_builder_cls ()(
466
+ weakref .proxy (self ))
471
467
472
468
# Engine/Model configurations.
473
469
self .chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ def __init__(self,
479
475
self .block_aligned_sliding_window = \
480
476
self .sliding_window_blocks * self .block_size
481
477
478
def prepare(self,
            finished_requests_ids: Optional[List[str]] = None) -> None:
    """Reset the builder's per-batch state before a new scheduling step.

    Args:
        finished_requests_ids: Ids of requests that completed since the
            previous iteration, or None when there are none to report.
    """
    # Remember which requests finished so downstream consumers can
    # release their cached state.
    self.finished_requests_ids = finished_requests_ids

    # Start from an empty list of per-sequence-group intermediate data
    # (CPU-side staging before transfer to the GPU).
    self.inter_data_list: List[
        ModelInputForGPUBuilder.InterDataForSeqGroup] = []

    # Let the attention metadata builder clear its own per-batch state.
    self.attn_metadata_builder.prepare()
488
+
482
489
def _compute_lens (self , inter_data : InterDataForSeqGroup , seq_idx : int ,
483
490
seq_group_metadata : SequenceGroupMetadata ):
484
491
"""Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
993
1000
"""
994
1001
_model_input_cls : Type [TModelInputForGPU ]
995
1002
_builder_cls : Type [ModelInputForGPUBuilder ]
1003
+ builder : ModelInputForGPUBuilder
996
1004
997
1005
def __init__ (
998
1006
self ,
@@ -1093,6 +1101,10 @@ def __init__(
1093
1101
SamplingMetadataCache () \
1094
1102
if self .parallel_config .pipeline_parallel_size == 1 else None
1095
1103
1104
+ if hasattr (self , "_builder_cls" ):
1105
+ # multi-step model runner does not have `_builder_cls`
1106
+ self .builder = self ._builder_cls (weakref .proxy (self ))
1107
+
1096
1108
def load_model (self ) -> None :
1097
1109
logger .info ("Starting to load model %s..." , self .model_config .model )
1098
1110
with DeviceMemoryProfiler () as m :
@@ -1226,13 +1238,13 @@ def _prepare_model_input_tensors(
1226
1238
1227
1239
If cuda graph is required, this API automatically pads inputs.
1228
1240
"""
1229
- builder = self ._builder_cls ( weakref . proxy ( self ), finished_requests_ids )
1241
+ self .builder . prepare ( finished_requests_ids )
1230
1242
for seq_group_metadata in seq_group_metadata_list :
1231
- builder .add_seq_group (seq_group_metadata )
1243
+ self . builder .add_seq_group (seq_group_metadata )
1232
1244
1233
- builder .reset_cached_inter_data ()
1245
+ self . builder .reset_cached_inter_data ()
1234
1246
1235
- return builder .build () # type: ignore
1247
+ return self . builder .build () # type: ignore
1236
1248
1237
1249
@contextmanager
1238
1250
def set_in_profile_run (self ):
0 commit comments