                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -217,17 +218,20 @@ def __init__(
         act_layer: type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(in_features,
                                         hidden_features,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
         self.act = act_layer()
         self.fc2 = RowParallelLinear(hidden_features,
                                      in_features,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_parallel, _ = self.fc1(x)
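
Note on the `disable_tp` switch above: a minimal plain-PyTorch sketch (not vLLM code) of why skipping TP in this MLP is numerically safe. `fc1` is column-sharded, the elementwise activation applies per shard, `fc2` is row-sharded, and the all-reduce sum matches the unsharded forward that each DP rank now computes with full weights. The shapes and the `gelu` stand-in are illustrative assumptions.

```python
# Plain-PyTorch sketch (not vLLM code): TP-sharded MLP vs. the disable_tp path.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)        # (batch, in_features)
w1 = torch.randn(16, 8)      # fc1 weight: (hidden_features, in_features)
w2 = torch.randn(8, 16)      # fc2 weight: (in_features, hidden_features)
act = torch.nn.functional.gelu

full = act(x @ w1.T) @ w2.T  # disable_tp=True: one rank, full weights

partials = []
for rank in range(2):        # simulate tp_size=2
    rows = slice(rank * 8, (rank + 1) * 8)
    # column-parallel fc1 shard, elementwise act, row-parallel fc2 shard
    partials.append(act(x @ w1[rows].T) @ w2[:, rows].T)

assert torch.allclose(full, sum(partials), atol=1e-5)  # sum = all-reduce
```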
@@ -293,25 +297,28 @@ def __init__(
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
+        self.tp_size = (1 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)
 
         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
                                         quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
+                                        prefix=f"{prefix}.qkv",
+                                        disable_tp=use_data_parallel)
         self.proj = RowParallelLinear(input_size=projection_size,
                                       output_size=embed_dim,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)
 
         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
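
With `use_data_parallel=True` the code above pins `tp_size` to 1, so the head partitioning degenerates to "every rank holds every head" and the TP-disabled qkv/proj layers carry full weights. A tiny sketch of the arithmetic (`divide` here mimics `dist_utils.divide`, which asserts exact divisibility):

```python
# Sketch of the head-partitioning arithmetic (mimics dist_utils.divide).
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, "heads must split evenly"
    return numerator // denominator

num_heads = 16
print(divide(num_heads, 1))  # DP encoder (tp_size pinned to 1): 16 heads/rank
print(divide(num_heads, 4))  # TP encoder with tp_size=4: 4 heads/rank
```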
@@ -453,6 +460,7 @@ def __init__(
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -465,12 +473,14 @@ def __init__(
                                          num_heads=num_heads,
                                          projection_size=dim,
                                          quant_config=quant_config,
-                                         prefix=f"{prefix}.attn")
+                                         prefix=f"{prefix}.attn",
+                                         use_data_parallel=use_data_parallel)
         self.mlp = Qwen2VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_layer=act_layer,
                                   quant_config=quant_config,
-                                  prefix=f"{prefix}.mlp")
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=use_data_parallel)
 
     def forward(
         self,
@@ -531,6 +541,7 @@ def __init__(
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -542,13 +553,15 @@ def __init__(
                                  self.hidden_size,
                                  bias=True,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mlp.0"),
+                                 prefix=f"{prefix}.mlp.0",
+                                 disable_tp=use_data_parallel),
             nn.GELU(),
             RowParallelLinear(self.hidden_size,
                               d_model,
                               bias=True,
                               quant_config=quant_config,
-                              prefix=f"{prefix}.mlp.2"),
+                              prefix=f"{prefix}.mlp.2",
+                              disable_tp=use_data_parallel),
         ])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
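
For context on what this merger consumes: `hidden_size = context_dim * spatial_merge_size**2` because each 2x2 window of patch embeddings is flattened into one vector before `mlp.0`. A hedged sketch (assumes the upstream patch ordering already places each window's patches contiguously, as Qwen2-VL's does):

```python
# Hedged sketch of the spatial merge feeding mlp.0 above.
import torch

context_dim, merge_size = 1280, 2
patches = torch.randn(16, context_dim)                  # 16 patch embeddings
merged = patches.view(-1, context_dim * merge_size**2)  # -> (4, 5120)
print(merged.shape)  # 4 merged tokens of width hidden_size
```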
@@ -600,6 +613,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -613,6 +627,9 @@ def __init__(
         num_heads = vision_config.num_heads
         mlp_ratio = vision_config.mlp_ratio
 
+        self.use_data_parallel = use_data_parallel
+        self.out_hidden_size = vision_config.hidden_size
+
         self.spatial_merge_size = spatial_merge_size
         self.num_heads = num_heads
         self.embed_dim = embed_dim
@@ -634,7 +651,8 @@ def __init__(
                              mlp_ratio=mlp_ratio,
                              norm_layer=norm_layer,
                              quant_config=quant_config,
-                             prefix=f"{prefix}.blocks.{layer_idx}")
+                             prefix=f"{prefix}.blocks.{layer_idx}",
+                             use_data_parallel=use_data_parallel)
             for layer_idx in range(depth)
         ])
         self.merger = Qwen2VisionPatchMerger(
@@ -643,6 +661,7 @@ def __init__(
             norm_layer=norm_layer,
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
+            use_data_parallel=use_data_parallel,
         )
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
@@ -659,8 +678,9 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
         pos_ids = []
+        max_grid_size = 0
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
@@ -678,8 +698,8 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
             ).permute(0, 2, 1, 3).flatten()
             pos_ids.append(
                 torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            max_grid_size = max(max_grid_size, h, w)
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
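
The running max replaces the old tensor reduction now that `grid_thw` arrives as a plain list; a quick equivalence check with toy values:

```python
# Equivalence check: running max over (h, w) == old grid_thw[:, 1:].max().
import torch

grid_thw = [[1, 4, 6], [2, 3, 8]]  # [t, h, w] per image/video

max_grid_size = 0
for t, h, w in grid_thw:
    max_grid_size = max(max_grid_size, h, w)

assert max_grid_size == int(torch.tensor(grid_thw)[:, 1:].max()) == 8
```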
@@ -698,7 +718,7 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
@@ -708,8 +728,9 @@ def forward(
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
-                                             grid_thw[:, 0]).cumsum(
+        grid_thw_ = torch.tensor(grid_thw)
+        cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
+                                             grid_thw_[:, 0]).cumsum(
                                                  dim=0, dtype=torch.int32)
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
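
Worked example of the `cu_seqlens` computation above, with assumed toy shapes: each `[t, h, w]` entry contributes `t` frames of `h*w` patches, and the padded cumulative sum gives the per-frame boundaries the varlen attention kernels consume.

```python
# Worked example: cu_seqlens for one 1x4x6 image and one 2x2x2 video.
import torch
import torch.nn.functional as F

grid_thw_ = torch.tensor([[1, 4, 6], [2, 2, 2]])
cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2],
                                     grid_thw_[:, 0]).cumsum(
                                         dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
```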
@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         "model.": "language_model.model.",
     })
 
+    supports_encoder_tp_data = True
+
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
@@ -1239,6 +1262,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self.config = config
         self.multimodal_config = multimodal_config
 
@@ -1249,6 +1273,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=self._maybe_ignore_quant_config(quant_config),
                 prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
             )
         else:
             self.visual = None
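
Usage-side view of this wiring, as a hedged sketch: `mm_encoder_tp_mode` is the multimodal-config knob checked above, and `"data"` flips the whole vision tower into the DP path. The exact engine-arg spelling may differ across vLLM versions.

```python
# Hedged usage sketch: DP vision encoder alongside a TP language model.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",  # replicate ViT weights; shard images by rank
)
```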
@@ -1357,7 +1382,15 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"]
         else:
             pixel_values = image_input["pixel_values"]
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
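
Conceptual, single-process sketch of what `run_dp_sharded_mrope_vision_model` is standing in for (NOT its implementation; the real helper lives in `vllm.multimodal.utils`, balances work by patch count, and all-gathers across actual DP ranks): whole images go to ranks, each rank runs the full TP-disabled tower on its slice, and outputs are reassembled in the original order.

```python
# Conceptual sketch only; round-robin assignment is an illustrative assumption.
import torch

def encode(patches: torch.Tensor) -> torch.Tensor:     # stand-in for self.visual
    return patches.mean(dim=-1, keepdim=True)

def dp_sharded_encode(pixel_values, grid_thw_list, world_size=2):
    sizes = [t * h * w for t, h, w in grid_thw_list]   # patches per image
    offsets = [0]
    for s in sizes:
        offsets.append(offsets[-1] + s)
    outputs = [None] * len(sizes)
    for rank in range(world_size):                     # simulate each DP rank
        for i in range(rank, len(sizes), world_size):  # whole images per rank
            outputs[i] = encode(pixel_values[offsets[i]:offsets[i + 1]])
    return torch.cat(outputs)                          # original order restored

x = torch.randn(1 * 2 * 2 + 1 * 4 * 4, 8)             # two images' patch rows
print(dp_sharded_encode(x, [[1, 2, 2], [1, 4, 4]]).shape)  # torch.Size([20, 1])
```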
@@ -1377,7 +1410,14 @@ def _process_video_input(
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size