diff --git a/swift/llm/infer/infer_engine/__init__.py b/swift/llm/infer/infer_engine/__init__.py index e8ec85f2e7..1af1b76bd1 100644 --- a/swift/llm/infer/infer_engine/__init__.py +++ b/swift/llm/infer/infer_engine/__init__.py @@ -11,7 +11,7 @@ from .infer_client import InferClient from .infer_engine import InferEngine from .base import BaseInferEngine - from .utils import prepare_generation_config, AdapterRequest, set_device_context + from .utils import prepare_generation_config, AdapterRequest, set_device_context, patch_vllm_memory_leak else: _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')} _import_structure = { @@ -22,7 +22,7 @@ 'infer_client': ['InferClient'], 'infer_engine': ['InferEngine'], 'base': ['BaseInferEngine'], - 'utils': ['prepare_generation_config', 'AdapterRequest', 'set_device_context'], + 'utils': ['prepare_generation_config', 'AdapterRequest', 'set_device_context', 'patch_vllm_memory_leak'], } import sys diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 3820253a24..e9c334f558 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -10,7 +10,7 @@ from swift.plugin import Metric from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig from .patch import patch_auto_config, patch_auto_tokenizer -from .utils import AdapterRequest +from .utils import AdapterRequest, patch_vllm_memory_leak try: # After setting the environment variables, import vllm. This way of writing allows lint to pass. @@ -54,6 +54,7 @@ def __init__( distributed_executor_backend: Optional[str] = None, engine_kwargs: Optional[Dict[str, Any]] = None, ) -> None: + patch_vllm_memory_leak() self.use_async_engine = use_async_engine self.processor = get_model_tokenizer( model_id_or_path, diff --git a/swift/llm/infer/infer_engine/utils.py b/swift/llm/infer/infer_engine/utils.py index ebab195a27..b72ccec51c 100644 --- a/swift/llm/infer/infer_engine/utils.py +++ b/swift/llm/infer/infer_engine/utils.py @@ -506,3 +506,231 @@ def restore_torch_device_after_vllm_init(): current_device = torch.cuda.current_device() if origin_device != current_device: torch.cuda.set_device(origin_device) + + +def patch_vllm_memory_leak(): + import vllm + if version.parse(vllm.__version__) != version.parse('0.7.3'): + return + + def patch_vllm_abort_seq_group(): + from vllm.core.scheduler import Scheduler + from typing import Iterable, Dict + from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus + + def new_abort_seq_group( + self, + request_id: Union[str, Iterable[str]], + seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, + ) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + seq_id_to_seq_group = seq_id_to_seq_group or {} + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + # When n>1, seq_group.request_id looks like + # foo_parallel_sample_0, while request_ids is just foo, and we + # should resolve it as real_request_id to match. + if seq_group.request_id in seq_id_to_seq_group: + real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id + else: + real_request_id = seq_group.request_id + if real_request_id in request_ids: + # Appending aborted group into pending list. 
+ aborted_groups.append(seq_group) + # We can't remove real_request_id in request_ids here, + # because there may be other seq groups sharing the same + # real_request_id + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. + state_queue.remove(aborted_group) + # Remove the aborted request from the Mamba cache. + self._finished_requests_ids.append(aborted_group.request_id) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + if aborted_group.request_id in seq_id_to_seq_group: + del seq_id_to_seq_group[aborted_group.request_id] + + self._free_seq_group_cross_attn_blocks(aborted_group) + + origin_method = Scheduler.abort_seq_group + Scheduler._old_abort_seq_group = origin_method + Scheduler.abort_seq_group = new_abort_seq_group + + def patch_vllm_engine(): + from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState + from vllm.outputs import PoolingRequestOutput, RequestOutput + from vllm.sequence import ExecuteModelRequest + + def new_abort_request(self, request_id) -> None: + for scheduler in self.scheduler: + scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) + + origin_method = LLMEngine.abort_request + LLMEngine._old_abort_request = origin_method + LLMEngine.abort_request = new_abort_request + + def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine ' + 'as performance will be severely degraded otherwise.') + + # For llm_engine, there is no pipeline parallel support, so the engine + # used is always 0. + virtual_engine = 0 + + # These are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + # The scheduler is also skipped if a single request caused the last + # engine step to fail, and the previous schedule needs to be rerun. + if not self._has_remaining_steps(seq_group_metadata_list): + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids() + # When n>1, elements in self.seq_id_to_seq_group should be deleted + # here, otherwise memory leaks. 
+ for finished_request_id in finished_requests_ids: + if finished_request_id in self.seq_id_to_seq_group: + del self.seq_id_to_seq_group[finished_request_id] + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list, + scheduler_outputs, allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[virtual_engine] + + outputs = self.model_executor.execute_model(execute_model_req=execute_model_req) + + # We need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + # No outputs in this case + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps. + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[0] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. 
+ is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + # Add results to the output_queue + ctx.append_output( + outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ('Async postprocessor expects only a single output set') + + self._advance_to_next_step(outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + # Check if need to run the usual non-async path + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + self.model_executor.stop_remote_worker_execution_loop() + + return ctx.request_outputs + + origin_method = LLMEngine.step + LLMEngine._old_step = origin_method + LLMEngine.step = new_step + + patch_vllm_abort_seq_group() + patch_vllm_engine()
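
A minimal usage sketch for the new helper, assuming vllm==0.7.3 and the export added to swift/llm/infer/infer_engine/__init__.py above; the model id, prompt, and sampling settings are illustrative only and not taken from this patch:

    from swift.llm.infer.infer_engine import patch_vllm_memory_leak

    # Apply the monkey-patch before any vLLM engine is constructed, mirroring the
    # call added at the top of GRPOVllmEngine.__init__. The function is a no-op on
    # vLLM versions other than 0.7.3 because of the version check it performs.
    patch_vllm_memory_leak()

    from vllm import LLM, SamplingParams

    llm = LLM(model='Qwen/Qwen2.5-7B-Instruct')   # illustrative model id
    params = SamplingParams(n=2, max_tokens=16)   # n > 1 is the case the patch targets
    outputs = llm.generate(['Hello'], params)
    print(outputs[0].outputs[0].text)

With the patch applied, Scheduler.abort_seq_group and LLMEngine.step also drop finished or aborted parallel-sampling entries from LLMEngine.seq_id_to_seq_group, so long-running rollouts (for example GRPO training) no longer accumulate stale sequence-group references.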