@@ -506,3 +506,231 @@ def restore_torch_device_after_vllm_init():
    current_device = torch.cuda.current_device()
    if origin_device != current_device:
        torch.cuda.set_device(origin_device)
+
+
+def patch_vllm_memory_leak():
+    import vllm
+    if version.parse(vllm.__version__) != version.parse('0.7.3'):
+        return
+
+    def patch_vllm_abort_seq_group():
+        from vllm.core.scheduler import Scheduler
+        from typing import Iterable, Dict
+        from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus
+
+        def new_abort_seq_group(
+            self,
+            request_id: Union[str, Iterable[str]],
+            seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+        ) -> None:
+            if isinstance(request_id, str):
+                request_id = (request_id, )
+            request_ids = set(request_id)
+            seq_id_to_seq_group = seq_id_to_seq_group or {}
+            for state_queue in [self.waiting, self.running, self.swapped]:
+                aborted_groups: List[SequenceGroup] = []
+                for seq_group in state_queue:
+                    # When n>1, seq_group.request_id looks like
+                    # foo_parallel_sample_0, while request_ids is just foo, and we
+                    # should resolve it as real_request_id to match.
+                    if seq_group.request_id in seq_id_to_seq_group:
+                        real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
+                    else:
+                        real_request_id = seq_group.request_id
+                    if real_request_id in request_ids:
+                        # Appending aborted group into pending list.
+                        aborted_groups.append(seq_group)
+                        # We can't remove real_request_id in request_ids here,
+                        # because there may be other seq groups sharing the same
+                        # real_request_id
+                for aborted_group in aborted_groups:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(aborted_group)
+                    # Remove the aborted request from the Mamba cache.
+                    self._finished_requests_ids.append(aborted_group.request_id)
+                    for seq in aborted_group.get_seqs():
+                        if seq.is_finished():
+                            continue
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
+                    if aborted_group.request_id in seq_id_to_seq_group:
+                        del seq_id_to_seq_group[aborted_group.request_id]
+
+                    self._free_seq_group_cross_attn_blocks(aborted_group)
+
+        origin_method = Scheduler.abort_seq_group
+        Scheduler._old_abort_seq_group = origin_method
+        Scheduler.abort_seq_group = new_abort_seq_group
+
+    def patch_vllm_engine():
+        from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
+        from vllm.outputs import PoolingRequestOutput, RequestOutput
+        from vllm.sequence import ExecuteModelRequest
+
+        def new_abort_request(self, request_id) -> None:
+            for scheduler in self.scheduler:
+                scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
+
+        origin_method = LLMEngine.abort_request
+        LLMEngine._old_abort_request = origin_method
+        LLMEngine.abort_request = new_abort_request
+
+        def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+            if self.parallel_config.pipeline_parallel_size > 1:
+                raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
+                                          'as performance will be severely degraded otherwise.')
+
+            # For llm_engine, there is no pipeline parallel support, so the engine
+            # used is always 0.
+            virtual_engine = 0
+
+            # These are cached outputs from previous iterations. None if on first
+            # iteration
+            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
+            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
+            scheduler_outputs = cached_outputs.scheduler_outputs
+            allow_async_output_proc = cached_outputs.allow_async_output_proc
+
+            ctx = self.scheduler_contexts[virtual_engine]
+
+            # Clear outputs for each new scheduler iteration
+            ctx.request_outputs.clear()
+
+            # Skip the scheduler if there are any remaining steps in the seq groups.
+            # This ensures that the scheduler is only called again when the current
+            # batch has completed.
+            # The scheduler is also skipped if a single request caused the last
+            # engine step to fail, and the previous schedule needs to be rerun.
+            if not self._has_remaining_steps(seq_group_metadata_list):
+                # Schedule iteration
+                (seq_group_metadata_list, scheduler_outputs,
+                 allow_async_output_proc) = self.scheduler[virtual_engine].schedule()
+
+                ctx.seq_group_metadata_list = seq_group_metadata_list
+                ctx.scheduler_outputs = scheduler_outputs
+
+                finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()
+                # When n>1, elements in self.seq_id_to_seq_group should be deleted
+                # here, otherwise memory leaks.
+                for finished_request_id in finished_requests_ids:
+                    if finished_request_id in self.seq_id_to_seq_group:
+                        del self.seq_id_to_seq_group[finished_request_id]
+
+                # Maybe switch from async mode to sync mode
+                if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+
+                if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
+                    # cache the scheduler outputs for the next iteration if we have
+                    # lookahead slots
+                    self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
+                                                                 scheduler_outputs, allow_async_output_proc)
+            else:
+                finished_requests_ids = list()
+
+            assert seq_group_metadata_list is not None
+            assert scheduler_outputs is not None
+
+            if not scheduler_outputs.is_empty():
+
+                # Check if we have a cached last_output from the previous iteration.
+                # For supporting PP this is probably the best way to pass the
+                # sampled_token_ids, as a separate broadcast over all the PP stages
+                # will cause one virtual engine's microbatch to block the pipeline.
+                last_sampled_token_ids = \
+                    self._get_last_sampled_token_ids(virtual_engine)
+
+                execute_model_req = ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                    running_queue_size=scheduler_outputs.running_queue_size,
+                    finished_requests_ids=finished_requests_ids,
+                    # We use ExecuteModelRequest to pass the last sampled_token_ids
+                    # to each of the non-last PP stages for in-place prepare_input.
+                    last_sampled_token_ids=last_sampled_token_ids)
+
+                if allow_async_output_proc:
+                    execute_model_req.async_callback = self.async_callbacks[virtual_engine]
+
+                outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)
+
+                # We need to do this here so that last step's sampled_token_ids can
+                # be passed to the next iteration for PP.
+                if self.scheduler_config.is_multi_step:
+                    self._update_cached_scheduler_output(virtual_engine, outputs)
+            else:
+                # Nothing scheduled => If there is pending async postprocessor,
+                # then finish it here.
+                if len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+                # No outputs in this case
+                outputs = []
+
+            # Finish the current step for all the sequence groups.
+            if self.scheduler_config.is_multi_step:
+                for seq_group in seq_group_metadata_list:
+                    seq_group.finish_step()
+
+            if not self._has_remaining_steps(seq_group_metadata_list):
+                # clear the cache if we have finished all the steps.
+                if self.scheduler_config.is_multi_step:
+                    self.cached_scheduler_outputs[0] = SchedulerOutputState()
+
+                # is_first_step_output is True only when the num_steps of all
+                # the sequences are 1. When the num_steps > 1,
+                # multi_step_model_runner does the first-step output append.
+                is_first_step_output: bool = False if not seq_group_metadata_list \
+                    else seq_group_metadata_list[0].state.num_steps == 1
+
+                # Add results to the output_queue
+                ctx.append_output(
+                    outputs=outputs,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    is_async=allow_async_output_proc,
+                    is_last_step=True,
+                    is_first_step_output=is_first_step_output)
+
+                if outputs and allow_async_output_proc:
+                    assert len(outputs) == 1, ('Async postprocessor expects only a single output set')
+
+                    self._advance_to_next_step(outputs[0], seq_group_metadata_list,
+                                               scheduler_outputs.scheduled_seq_groups)
+
+                # Check if need to run the usual non-async path
+                if not allow_async_output_proc:
+                    self._process_model_outputs(ctx=ctx)
+
+                    # Log stats.
+                    self.do_log_stats(scheduler_outputs, outputs)
+
+                    # Tracing
+                    self.do_tracing(scheduler_outputs)
+            else:
+                # Multi-step case
+                return ctx.request_outputs
+
+            if not self.has_unfinished_requests():
+                # Drain async postprocessor (if exists)
+                if len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+                assert len(ctx.output_queue) == 0
+
+                # Stop the execute model loop in parallel workers until there are
+                # more requests to process. This avoids waiting indefinitely in
+                # torch.distributed ops which may otherwise timeout, and unblocks
+                # the RPC thread in the workers so that they can process any other
+                # queued control plane messages, such as add/remove lora adapters.
+                self.model_executor.stop_remote_worker_execution_loop()
+
+            return ctx.request_outputs
+
+        origin_method = LLMEngine.step
+        LLMEngine._old_step = origin_method
+        LLMEngine.step = new_step
+
+    patch_vllm_abort_seq_group()
+    patch_vllm_engine()
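
For reference, a minimal usage sketch (an assumption, not part of the diff above): it presumes patch_vllm_memory_leak() is importable from the same module as restore_torch_device_after_vllm_init() and that the default v0 LLMEngine path is in use. The patch is a no-op on any vLLM version other than 0.7.3, and applying it before the engine is constructed is the safe ordering:

    from vllm import LLM, SamplingParams

    patch_vllm_memory_leak()                      # swaps in the patched Scheduler.abort_seq_group and LLMEngine.abort_request/step
    llm = LLM(model='<your-model-id>')            # placeholder model id
    params = SamplingParams(n=2, max_tokens=16)   # n>1 is what creates the foo_parallel_sample_0 style request ids
    llm.generate(['hello world'] * 4, params)
    # With the patch, entries for finished requests are dropped from
    # llm.llm_engine.seq_id_to_seq_group inside step(), so the dict no longer
    # grows across batches.
    print(len(llm.llm_engine.seq_id_to_seq_group))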