Commit 39fd287

fix vllm memory leak (#3515)
1 parent f4b9566 commit 39fd287

File tree

3 files changed: +232 -3 lines changed

swift/llm/infer/infer_engine/__init__.py
swift/llm/infer/infer_engine/grpo_vllm_engine.py
swift/llm/infer/infer_engine/utils.py

swift/llm/infer/infer_engine/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
     from .infer_client import InferClient
     from .infer_engine import InferEngine
     from .base import BaseInferEngine
-    from .utils import prepare_generation_config, AdapterRequest, set_device_context
+    from .utils import prepare_generation_config, AdapterRequest, set_device_context, patch_vllm_memory_leak
 else:
     _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
     _import_structure = {
@@ -22,7 +22,7 @@
         'infer_client': ['InferClient'],
         'infer_engine': ['InferEngine'],
         'base': ['BaseInferEngine'],
-        'utils': ['prepare_generation_config', 'AdapterRequest', 'set_device_context'],
+        'utils': ['prepare_generation_config', 'AdapterRequest', 'set_device_context', 'patch_vllm_memory_leak'],
     }
 
     import sys

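With the lazy-import table above extended, patch_vllm_memory_leak becomes importable from the infer_engine package alongside the other utils exports. A minimal usage sketch (assuming ms-swift with this commit and vLLM installed; names outside the diff are illustrative):

from swift.llm.infer.infer_engine import patch_vllm_memory_leak

# Apply the monkey-patches before any vLLM engine is created.
# On vLLM versions other than 0.7.3 the call simply returns.
patch_vllm_memory_leak()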
swift/llm/infer/infer_engine/grpo_vllm_engine.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 from swift.plugin import Metric
 from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
 from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest
+from .utils import AdapterRequest, patch_vllm_memory_leak
 
 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -54,6 +54,7 @@ def __init__(
         distributed_executor_backend: Optional[str] = None,
         engine_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
+        patch_vllm_memory_leak()
         self.use_async_engine = use_async_engine
         self.processor = get_model_tokenizer(
             model_id_or_path,

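The call is placed at the very top of GRPOVllmEngine.__init__, so the Scheduler and LLMEngine classes are already patched by the time the underlying vLLM engine is constructed. A rough sketch of the same ordering outside of swift (only patch_vllm_memory_leak comes from this commit; the engine construction below is illustrative, not the project's API):

from swift.llm.infer.infer_engine.utils import patch_vllm_memory_leak

def build_llm(model_id_or_path: str):
    patch_vllm_memory_leak()  # install the Scheduler/LLMEngine patches (vLLM 0.7.3 only)
    from vllm import LLM      # construct the engine after patching
    return LLM(model=model_id_or_path)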
swift/llm/infer/infer_engine/utils.py

Lines changed: 228 additions & 0 deletions
@@ -506,3 +506,231 @@ def restore_torch_device_after_vllm_init():
     current_device = torch.cuda.current_device()
     if origin_device != current_device:
         torch.cuda.set_device(origin_device)
+
+
+def patch_vllm_memory_leak():
+    import vllm
+    if version.parse(vllm.__version__) != version.parse('0.7.3'):
+        return
+
+    def patch_vllm_abort_seq_group():
+        from vllm.core.scheduler import Scheduler
+        from typing import Iterable, Dict
+        from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus
+
+        def new_abort_seq_group(
+            self,
+            request_id: Union[str, Iterable[str]],
+            seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+        ) -> None:
+            if isinstance(request_id, str):
+                request_id = (request_id, )
+            request_ids = set(request_id)
+            seq_id_to_seq_group = seq_id_to_seq_group or {}
+            for state_queue in [self.waiting, self.running, self.swapped]:
+                aborted_groups: List[SequenceGroup] = []
+                for seq_group in state_queue:
+                    # When n>1, seq_group.request_id looks like
+                    # foo_parallel_sample_0, while request_ids is just foo, and we
+                    # should resolve it as real_request_id to match.
+                    if seq_group.request_id in seq_id_to_seq_group:
+                        real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
+                    else:
+                        real_request_id = seq_group.request_id
+                    if real_request_id in request_ids:
+                        # Appending aborted group into pending list.
+                        aborted_groups.append(seq_group)
+                        # We can't remove real_request_id in request_ids here,
+                        # because there may be other seq groups sharing the same
+                        # real_request_id
+                for aborted_group in aborted_groups:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(aborted_group)
+                    # Remove the aborted request from the Mamba cache.
+                    self._finished_requests_ids.append(aborted_group.request_id)
+                    for seq in aborted_group.get_seqs():
+                        if seq.is_finished():
+                            continue
+                        seq.status = SequenceStatus.FINISHED_ABORTED
+                        self.free_seq(seq)
+                    if aborted_group.request_id in seq_id_to_seq_group:
+                        del seq_id_to_seq_group[aborted_group.request_id]
+
+                    self._free_seq_group_cross_attn_blocks(aborted_group)
+
+        origin_method = Scheduler.abort_seq_group
+        Scheduler._old_abort_seq_group = origin_method
+        Scheduler.abort_seq_group = new_abort_seq_group
+
+    def patch_vllm_engine():
+        from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
+        from vllm.outputs import PoolingRequestOutput, RequestOutput
+        from vllm.sequence import ExecuteModelRequest
+
+        def new_abort_request(self, request_id) -> None:
+            for scheduler in self.scheduler:
+                scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
+
+        origin_method = LLMEngine.abort_request
+        LLMEngine._old_abort_request = origin_method
+        LLMEngine.abort_request = new_abort_request
+
+        def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+            if self.parallel_config.pipeline_parallel_size > 1:
+                raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
+                                          'as performance will be severely degraded otherwise.')
+
+            # For llm_engine, there is no pipeline parallel support, so the engine
+            # used is always 0.
+            virtual_engine = 0
+
+            # These are cached outputs from previous iterations. None if on first
+            # iteration
+            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
+            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
+            scheduler_outputs = cached_outputs.scheduler_outputs
+            allow_async_output_proc = cached_outputs.allow_async_output_proc
+
+            ctx = self.scheduler_contexts[virtual_engine]
+
+            # Clear outputs for each new scheduler iteration
+            ctx.request_outputs.clear()
+
+            # Skip the scheduler if there are any remaining steps in the seq groups.
+            # This ensures that the scheduler is only called again when the current
+            # batch has completed.
+            # The scheduler is also skipped if a single request caused the last
+            # engine step to fail, and the previous schedule needs to be rerun.
+            if not self._has_remaining_steps(seq_group_metadata_list):
+                # Schedule iteration
+                (seq_group_metadata_list, scheduler_outputs,
+                 allow_async_output_proc) = self.scheduler[virtual_engine].schedule()
+
+                ctx.seq_group_metadata_list = seq_group_metadata_list
+                ctx.scheduler_outputs = scheduler_outputs
+
+                finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()
+                # When n>1, elements in self.seq_id_to_seq_group should be deleted
+                # here, otherwise memory leaks.
+                for finished_request_id in finished_requests_ids:
+                    if finished_request_id in self.seq_id_to_seq_group:
+                        del self.seq_id_to_seq_group[finished_request_id]
+
+                # Maybe switch from async mode to sync mode
+                if not allow_async_output_proc and len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+
+                if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
+                    # cache the scheduler outputs for the next iteration if we have
+                    # lookahead slots
+                    self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
+                                                                 scheduler_outputs, allow_async_output_proc)
+            else:
+                finished_requests_ids = list()
+
+            assert seq_group_metadata_list is not None
+            assert scheduler_outputs is not None
+
+            if not scheduler_outputs.is_empty():
+
+                # Check if we have a cached last_output from the previous iteration.
+                # For supporting PP this is probably the best way to pass the
+                # sampled_token_ids, as a separate broadcast over all the PP stages
+                # will cause one virtual engine's microbatch to block the pipeline.
+                last_sampled_token_ids = \
+                    self._get_last_sampled_token_ids(virtual_engine)
+
+                execute_model_req = ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                    running_queue_size=scheduler_outputs.running_queue_size,
+                    finished_requests_ids=finished_requests_ids,
+                    # We use ExecuteModelRequest to pass the last sampled_token_ids
+                    # to each of the non-last PP stages for in-place prepare_input.
+                    last_sampled_token_ids=last_sampled_token_ids)
+
+                if allow_async_output_proc:
+                    execute_model_req.async_callback = self.async_callbacks[virtual_engine]
+
+                outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)
+
+                # We need to do this here so that last step's sampled_token_ids can
+                # be passed to the next iteration for PP.
+                if self.scheduler_config.is_multi_step:
+                    self._update_cached_scheduler_output(virtual_engine, outputs)
+            else:
+                # Nothing scheduled => If there is pending async postprocessor,
+                # then finish it here.
+                if len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+                # No outputs in this case
+                outputs = []
+
+            # Finish the current step for all the sequence groups.
+            if self.scheduler_config.is_multi_step:
+                for seq_group in seq_group_metadata_list:
+                    seq_group.finish_step()
+
+            if not self._has_remaining_steps(seq_group_metadata_list):
+                # clear the cache if we have finished all the steps.
+                if self.scheduler_config.is_multi_step:
+                    self.cached_scheduler_outputs[0] = SchedulerOutputState()
+
+                # is_first_step_output is True only when the num_steps of all
+                # the sequences are 1. When the num_steps > 1,
+                # multi_step_model_runner does the first-step output append.
+                is_first_step_output: bool = False if not seq_group_metadata_list \
+                    else seq_group_metadata_list[0].state.num_steps == 1
+
+                # Add results to the output_queue
+                ctx.append_output(
+                    outputs=outputs,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    is_async=allow_async_output_proc,
+                    is_last_step=True,
+                    is_first_step_output=is_first_step_output)
+
+                if outputs and allow_async_output_proc:
+                    assert len(outputs) == 1, ('Async postprocessor expects only a single output set')
+
+                    self._advance_to_next_step(outputs[0], seq_group_metadata_list,
+                                               scheduler_outputs.scheduled_seq_groups)
+
+                # Check if need to run the usual non-async path
+                if not allow_async_output_proc:
+                    self._process_model_outputs(ctx=ctx)
+
+                    # Log stats.
+                    self.do_log_stats(scheduler_outputs, outputs)
+
+                    # Tracing
+                    self.do_tracing(scheduler_outputs)
+            else:
+                # Multi-step case
+                return ctx.request_outputs
+
+            if not self.has_unfinished_requests():
+                # Drain async postprocessor (if exists)
+                if len(ctx.output_queue) > 0:
+                    self._process_model_outputs(ctx=ctx)
+                assert len(ctx.output_queue) == 0
+
+                # Stop the execute model loop in parallel workers until there are
+                # more requests to process. This avoids waiting indefinitely in
+                # torch.distributed ops which may otherwise timeout, and unblocks
+                # the RPC thread in the workers so that they can process any other
+                # queued control plane messages, such as add/remove lora adapters.
+                self.model_executor.stop_remote_worker_execution_loop()
+
+            return ctx.request_outputs
+
+        origin_method = LLMEngine.step
+        LLMEngine._old_step = origin_method
+        LLMEngine.step = new_step
+
+    patch_vllm_abort_seq_group()
+    patch_vllm_engine()

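Each patch keeps a handle to the method it replaces (Scheduler._old_abort_seq_group, LLMEngine._old_abort_request, LLMEngine._old_step). A hypothetical helper, not part of this commit, could use those handles to restore stock vLLM behaviour, for example in tests:

def unpatch_vllm_memory_leak():
    # Illustrative only: restore the original vLLM 0.7.3 methods saved by
    # patch_vllm_memory_leak(); assumes the patch has already been applied.
    from vllm.core.scheduler import Scheduler
    from vllm.engine.llm_engine import LLMEngine

    if hasattr(Scheduler, '_old_abort_seq_group'):
        Scheduler.abort_seq_group = Scheduler._old_abort_seq_group
    if hasattr(LLMEngine, '_old_abort_request'):
        LLMEngine.abort_request = LLMEngine._old_abort_request
    if hasattr(LLMEngine, '_old_step'):
        LLMEngine.step = LLMEngine._old_step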