diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 36d805a32db7..98fe36d0fb79 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -87,6 +87,11 @@ def parse_args(): default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--enable-dbo", + action="store_true", + help=("Enable microbatched execution"), + ) parser.add_argument( "--compilation-config", type=int, @@ -113,6 +118,7 @@ def main( max_model_len, compilation_config, gpu_memory_utilization, + enable_dbo, quantization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -167,6 +173,7 @@ def start(rank): max_num_seqs=max_num_seqs, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, + enable_dbo=enable_dbo, quantization=quantization, compilation_config=compilation_config, ) @@ -227,6 +234,7 @@ def start(rank): args.max_model_len, args.compilation_config, args.gpu_memory_utilization, + args.enable_dbo, args.quantization, ), ) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 3fc1011d5042..c74dbb3ebb17 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -6,7 +6,7 @@ from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.utils import create_common_attn_metadata -from vllm.v1.attention.backends.utils import (UbatchSlice, +from vllm.v1.attention.backends.utils import (UBatchSlice, _make_metadata_with_slice, slice_query_start_locs, split_attn_metadata) @@ -106,7 +106,7 @@ def mixed_small_metadata(): def test_make_metadata_with_slice_decode_batch(small_decode_metadata): """Test slicing decode batch metadata""" # Split first request only - ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1)) + ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1)) result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) @@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): """Test slicing mixed batch metadata""" - ubatch_slice = UbatchSlice(slice(1, 3), + ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7)) # Requests 1-3, tokens 1-7 result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) @@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): num_tokens = large_decode_metadata.num_reqs mid_point = num_tokens // 2 ubatch_slices = [ - UbatchSlice(slice(0, mid_point), slice(0, mid_point)), - UbatchSlice(slice(mid_point, num_tokens), slice(mid_point, + UBatchSlice(slice(0, mid_point), slice(0, mid_point)), + UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, num_tokens)), ] diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index ddedc61aae29..ccab04628a16 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -365,7 +365,9 @@ def create_deterministic_logits(token_ids): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + proposer.runner.attn_groups[0][0].metadata_builders = [ + attn_metadata_builder + ] result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, @@ -489,7 
+491,9 @@ def create_deterministic_logits(token_ids, k: int): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + proposer.runner.attn_groups[0][0].metadata_builders = [ + attn_metadata_builder + ] # Setup inputs for the proposer. target_token_ids = torch.randint(0, diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 6bb0fef23719..535802585d18 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2848,6 +2848,14 @@ def __post_init__(self): "when cudagraph_mode piecewise cudagraphs is used, "\ f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + if self.parallel_config.enable_dbo: + a2a_backend = envs.VLLM_ALL2ALL_BACKEND + assert a2a_backend == "deepep_low_latency", \ + "Microbatching currently only supports the deepep_low_latency "\ + f"all2all backend. {a2a_backend} is not supported. To fix, set "\ + "the VLLM_ALL2ALL_BACKEND environment variable to "\ + "deepep_low_latency and install the DeepEP kernels." + if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 231406bf6052..8e92e54a9678 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -137,6 +137,14 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_dbo: bool = False + """Enable microbatching for the model executor.""" + + dbo_decode_token_threshold: int = 32 + """The threshold for microbatching. If the number of tokens in the + batch is greater than this threshold, microbatching will be used. + Otherwise, the batch will be executed without microbatching.""" + ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7c0f30b9aab8..427fd040fcb7 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -251,9 +251,4 @@ def get_handle(self, kwargs): logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) - # It is dangerous to set num sms outside this function. num_sms is not - # a part of the hash-key that identifies this object. If we are in a - # situation where we make objects with different num_sms, the hash key - # in get_or_create must be updated. 
- handle.set_num_sms(self.num_sms) return handle diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 20d998d613d4..4831cb5348c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -327,6 +327,9 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_dbo: bool = ParallelConfig.enable_dbo + dbo_decode_token_threshold: int = \ + ParallelConfig.dbo_decode_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb expert_placement_strategy: ExpertPlacementStrategy = \ @@ -695,6 +698,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-dbo", + **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--dbo-decode-token-threshold", + **parallel_kwargs["dbo_decode_token_threshold"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--eplb-config", @@ -1339,6 +1347,8 @@ def create_engine_config( data_parallel_backend=self.data_parallel_backend, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, + enable_dbo=self.enable_dbo, + dbo_decode_token_threshold=self.dbo_decode_token_threshold, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, expert_placement_strategy=self.expert_placement_strategy, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index b3ddd7b9a739..3b535423f7bc 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -97,6 +98,53 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, dist.all_reduce(num_tokens_tensor, group=group) return num_tokens_tensor.cpu() + @staticmethod + def should_ubatch_across_dp( + should_ubatch: bool, orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, dp_size: int, + dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]: + """ + 1. Decides if each DP rank is going to microbatch. Either all ranks + run with microbatching or none of them do. If this function decides + not to run with microbatching, it will "abort", meaning that no padding + information will be returned to the caller, i.e. it will return (False, None) + + 2. Determines the total number of tokens that each rank will run. + All ranks will be padded out so that they run with the same number + of tokens + + Returns: tuple[ + should_ubatch: Are all DP ranks going to microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including padding. 
Will be + None if should_ubatch if False + ] + """ + + device = current_platform.device_type + tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32) + tensor[0][dp_rank] = orig_num_tokens_per_ubatch + tensor[1][dp_rank] = padded_num_tokens_per_ubatch + tensor[2][dp_rank] = 1 if should_ubatch else 0 + + from vllm.distributed.parallel_state import get_dp_group + dist.all_reduce(tensor, group=get_dp_group().device_group) + + result: bool = bool(torch.all(tensor[2] == 1).item()) + if not result: + return result, None + + orig_num_tokens_tensor = tensor[0, :] + padded_num_tokens_tensor = tensor[1, :] + + orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) + padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) + if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): + logger.debug("Aborting ubatching %s %s", orig_min_num_tokens, + padded_max_num_tokens) + return False, None + return result, padded_num_tokens_tensor.cpu() + @staticmethod def make( parallel_config: ParallelConfig, @@ -119,14 +167,15 @@ def make( # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert (num_tokens_across_dp is None - or num_tokens_across_dp[dp_rank] == batchsize) + assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank] + == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}" if num_tokens_across_dp is None: num_tokens_across_dp = DPMetadata.num_tokens_across_dp( batchsize, dp_size, dp_rank) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) - return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) + return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu, + num_tokens_across_dp) @contextmanager def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): @@ -179,9 +228,12 @@ class ForwardContext: Type AttentionMetadata for v0, Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata - set dynamically for each forward pass + Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one + for each microbatch. + Set dynamically for each forward pass """ - attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"], + list[dict[str, "AttentionMetadata"]]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass @@ -191,6 +243,8 @@ class ForwardContext: cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE batch_descriptor: Optional[BatchDescriptor] = None + ubatch_slices: Optional[UBatchSlices] = None + def __post_init__(self): assert self.cudagraph_runtime_mode in [ CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ @@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext: return _forward_context +def create_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + dp_metadata: Optional[DPMetadata] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None): + return ForwardContext(no_compile_layers=vllm_config.compilation_config. 
+ static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices) + + +@contextmanager +def override_forward_context(forward_context: Optional[ForwardContext]): + """A context manager that overrides the current forward context. + This is used to override the forward context for a specific + forward pass. + """ + global _forward_context + prev_context = _forward_context + _forward_context = forward_context + try: + yield + finally: + _forward_context = prev_context + + @contextmanager def set_forward_context( attn_metadata: Any, @@ -216,7 +303,8 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -225,6 +313,7 @@ def set_forward_context( need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() + dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1 and ( attn_metadata is not None or num_tokens is not None): @@ -232,20 +321,14 @@ def set_forward_context( attn_metadata, num_tokens or 0, num_tokens_across_dp) - global _forward_context - prev_context = _forward_context - _forward_context = ForwardContext( - no_compile_layers=vllm_config.compilation_config. 
- static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) + forward_context = create_forward_context(attn_metadata, vllm_config, + virtual_engine, dp_metadata, + cudagraph_runtime_mode, + batch_descriptor, ubatch_slices) try: - yield + with override_forward_context(forward_context): + yield finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: @@ -282,5 +365,3 @@ def set_forward_context( logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) - - _forward_context = prev_context diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 2a3ae478f3ea..92cbb1742974 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -191,7 +191,7 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> Callable: + ) -> tuple[Callable, mk.ReceiverType]: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -217,13 +217,14 @@ def prepare_async( a1q_scale = None a1_post_scale = a1_scale - return self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config) + return (lambda *args: None, + self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config)) def prepare( self, @@ -237,10 +238,11 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, - topk_ids, num_experts, expert_map, - apply_router_weight_on_input, - quant_config) + (_, receiver) = self.prepare_async(a1, a1_scale, a2_scale, + topk_weights, topk_ids, num_experts, + expert_map, + apply_router_weight_on_input, + quant_config) return receiver() def finalize( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 1849e49e0ab5..61f8297f0f14 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -11,6 +11,9 @@ TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, + dbo_maybe_run_recv_hook, + dbo_register_recv_hook, dbo_yield) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -55,7 +58,7 @@ def __init__(self, # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. 
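+    # With DBO there can be two microbatches in flight, so one handle is kept per ubatch id (indexed by dbo_current_ubatch_id() in finalize below).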
- self.handle = None + self.handles: list[Optional[tuple]] = [None, None] self.num_dispatchers_ = num_dispatchers def num_dispatchers(self) -> int: @@ -123,13 +126,15 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.ReceiverType: + ) -> tuple[Callable, mk.ReceiverType]: hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ (f"Hidden Size {hidden_size} not in supported list of hidden sizes" f"{self.SUPPORTED_HIDDEN_SIZES}") + a2a_idx = dbo_current_ubatch_id() + if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ "DeepEP kernels quantize the inputs in blocks of shape 128" @@ -148,7 +153,7 @@ def prepare_async( a1 = a1 * topk_weights.to(a1.dtype) # Dispatch - expert_x, expert_num_tokens, self.handle, event, hook = \ + expert_x, expert_num_tokens, handle, _, hook= \ self.buffer.low_latency_dispatch(a1, topk_ids, self.max_tokens_per_rank, @@ -156,21 +161,19 @@ def prepare_async( use_fp8=self.use_fp8_dispatch, async_finish=False, return_recv_hook=True) + self.handles[a2a_idx] = handle - return lambda: self._receiver(hook, expert_x, expert_num_tokens, - a1_scale, a1.dtype, quant_config) + return (hook, lambda: self._receiver(expert_x, expert_num_tokens, + a1_scale, a1.dtype, quant_config)) def _receiver( self, - hook: Callable, expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], expert_num_tokens: torch.Tensor, a1_scale, a1_dtype, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - hook() - expert_x, expert_x_scale = self._do_quant( expert_x, a1_scale, a1_dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) @@ -192,10 +195,12 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights, - topk_ids, num_experts, expert_map, - apply_router_weight_on_input, - quant_config) + hook, receiver = self.prepare_async(a1, a1_scale, a2_scale, + topk_weights, topk_ids, + num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + hook() return receiver() def finalize( @@ -210,7 +215,11 @@ def finalize( assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") - assert self.handle is not None + + a2a_idx = dbo_current_ubatch_id() + do_recv_hook = dbo_enabled() + handle = self.handles[a2a_idx] + assert handle is not None combine_topk_weights = topk_weights if apply_router_weight_on_input: @@ -218,12 +227,16 @@ def finalize( combine_topk_weights = torch.ones_like(topk_weights) # TODO (varun) : Enable zero copy mode - _, event, hook = self.buffer.low_latency_combine( + dbo_maybe_run_recv_hook() + _, _, recv_hook = self.buffer.low_latency_combine( fused_expert_output, topk_ids, combine_topk_weights, - self.handle, + handle, async_finish=False, zero_copy=False, - return_recv_hook=False, + return_recv_hook=do_recv_hook, out=output) + if recv_hook is not None: + dbo_register_recv_hook(recv_hook) + dbo_yield() diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c62897c91816..d22bb253f4a7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -38,6 +38,7 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up) +from 
vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -992,16 +993,28 @@ def __init__( if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_config.use_flashinfer_cutlass_kernels): - self.batched_hidden_states = torch.zeros( - (moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + if vllm_config.parallel_config.enable_dbo: + self.batched_hidden_states = torch.zeros( + (2, moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (2, moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + else: + self.batched_hidden_states = torch.zeros( + (moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) @property def shared_experts(self) -> Optional[torch.nn.Module]: @@ -1708,14 +1721,29 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - assert (self.batched_hidden_states.size(0) # type: ignore + assert self.batched_hidden_states is not None + assert self.batched_router_logits is not None + # This is only true when DBO has been enabled in the config. 
+ # Both tensors will have an outer dimension for the ubatch id + if self.batched_hidden_states.dim() == 3: + assert self.batched_router_logits.dim() == 3 + batch_buffer_idx = dbo_current_ubatch_id() + batched_hidden_states = self.batched_hidden_states[ + batch_buffer_idx, :] + batched_router_logits = self.batched_router_logits[ + batch_buffer_idx, :] + else: + batched_hidden_states = self.batched_hidden_states + batched_router_logits = self.batched_router_logits + + assert (batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (batched_router_logits.size(0) # type: ignore >= chunk_size) - staged_hidden_states = self.batched_hidden_states[: - chunk_size, :] # type: ignore - staged_router_logits = self.batched_router_logits[: - chunk_size, :] # type: ignore + staged_hidden_states = batched_hidden_states[: + chunk_size, :] # type: ignore + staged_router_logits = batched_router_logits[: + chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 281563c3bfca..33799b58d199 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv +from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook, + dbo_register_recv_hook, dbo_yield) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -226,7 +228,7 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> ReceiverType: + ) -> tuple[Callable, ReceiverType]: """ Perform any quantization (and/or) dispatching needed for this kernel but do not wait for results from other workers. @@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, return None +class SharedResizableBuffer: + + def __init__(self): + self.buffer = None + + def get(self, shape: tuple[int, ...], device: torch.device, + dtype: torch.dtype): + shape_numel = prod(shape) + if self.buffer is None or self.buffer.numel() < shape_numel: + self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) + assert self.buffer.device == device, \ + f"Buffer device mismatch: {self.buffer.device} != {device}" + assert self.buffer.dtype == dtype, \ + f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}" + return self.buffer[:shape_numel].view(*shape) + + @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module): layer due to any layer specific state that may be used by the component objects. """ + fused_out_buffer = SharedResizableBuffer() + workspace13_buffer = SharedResizableBuffer() + workspace2_buffer = SharedResizableBuffer() def __init__( self, @@ -559,12 +581,12 @@ def _do_fused_experts( # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. 
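+        # The workspaces below now come from the class-level SharedResizableBuffers defined above, so the allocations are reused across calls instead of being re-created each time.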
- workspace13 = torch.empty(prod(workspace13_shape), - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device=a1.device, - dtype=workspace_dtype) + workspace13 = self.workspace13_buffer.get(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = self.workspace2_buffer.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") @@ -656,9 +678,9 @@ def _maybe_chunk_fused_experts( (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) - fused_out = torch.empty(fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + fused_out = self.fused_out_buffer.get(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) def slice_input_tensors( chunk_idx: int @@ -801,8 +823,10 @@ def forward( shared_output: torch.Tensor - if (not self.prepare_finalize.supports_async() - or self.shared_experts is None): + if not self.prepare_finalize.supports_async(): + # We shouldn't be running an a2a kernel that doesn't + # support async prepare/finalize + assert not dbo_enabled() # Run shared experts serially with dispatch. if self.shared_experts is not None: @@ -822,7 +846,8 @@ def forward( ) else: # Overlap shared expert compute with all2all dispatch. - receiver = self.prepare_finalize.prepare_async( + dbo_maybe_run_recv_hook() + hook, receiver = self.prepare_finalize.prepare_async( a1, a1_scale, a2_scale, @@ -834,8 +859,16 @@ def forward( self.fused_experts.quant_config, ) - assert self.shared_experts is not None - shared_output = self.shared_experts(a1) + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + + # If DBO is being used, register the hook with the ubatch context + # and call it in dbo_maybe_run_recv_hook instead of passing it to + # the receiver. 
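+            # When DBO is disabled, dbo_register_recv_hook/dbo_yield are expected to be no-ops, so the hook is invoked directly below.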
+ dbo_register_recv_hook(hook) + dbo_yield() + if not dbo_enabled(): + hook() (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = receiver() diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 2ae79e69f555..b8c1c14317c4 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Callable, Optional, Union import pplx_kernels as pplx import torch @@ -103,7 +103,7 @@ def prepare_async( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> mk.ReceiverType: + ) -> tuple[Callable, mk.ReceiverType]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -214,41 +214,33 @@ def prepare_async( do_recv=False, ) - return lambda: self._receiver( + hook = lambda: self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) + + return (hook, lambda: self._receiver( expert_num_tokens, expert_x, expert_x_scale, - a1q, - a1q_scale, - topk_ids, - bound_m, orig_a_scale_block_shape, - ) + )) def _receiver( self, expert_num_tokens: torch.Tensor, expert_x: torch.Tensor, expert_x_scale: Optional[torch.Tensor], - a1q: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - bound_m: Optional[torch.Tensor], orig_a_scale_block_shape: Optional[int], ) -> mk.PrepareResultType: - self.a2a.dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=topk_ids, - bound_m=bound_m, - do_send=False, - do_recv=True, - ) - if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 @@ -270,7 +262,7 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - receiver = self.prepare_async( + hook, receiver = self.prepare_async( a1, a1_scale, a2_scale, @@ -281,6 +273,7 @@ def prepare( apply_router_weight_on_input, quant_config, ) + hook() return receiver() def finalize( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index ead70c910a8f..63326d19194f 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -28,6 +28,7 @@ get_kv_connector_cache_layout) from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) KVCacheLayoutType = Literal["NHD", "HND"] @@ -81,12 +82,6 @@ class CommonAttentionMetadata: encoder_seq_lens: Optional[np.ndarray] = None -@dataclass -class UbatchSlice: - request_slice: slice - token_slice: slice - - def slice_query_start_locs( query_start_loc: torch.Tensor, request_slice: slice, @@ -103,7 +98,7 @@ def slice_query_start_locs( def _make_metadata_with_slice( - ubatch_slice: UbatchSlice, + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: """ This function creates a new CommonAttentionMetadata that corresponds to @@ -133,6 
+128,11 @@ def _make_metadata_with_slice( torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()) + # This is to account for the case where we are in a dummy + # run and query_start_loc_cpu is full of 0s + if max_query_len == 0: + max_query_len = attn_metadata.max_query_len + block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] @@ -152,12 +152,12 @@ def _make_metadata_with_slice( def split_attn_metadata( - ubatch_slices: list[UbatchSlice], + ubatch_slices: list[UBatchSlice], common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ Creates a new CommonAttentionMetadata instance that corresponds to the - requests for each UbatchSlice in ubatch_slices. + requests for each UBatchSlice in ubatch_slices. Note: This function does not modify common_attn_metadata """ diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7132d507c722..5154b29405b6 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -27,6 +27,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -179,9 +180,11 @@ def propose( assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ - .build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) + ubatch_id = dbo_current_ubatch_id() + attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -355,8 +358,9 @@ def propose_tree( hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: + ubatch_id = dbo_current_ubatch_id() tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builder + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index d5ec19b86b06..619ed88ab5b2 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -64,8 +64,13 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: if not self.attn_groups[0]: return - mb = getattr(self.attn_groups[0][0], "metadata_builder", None) - if not isinstance(mb, TorchSDPAMetadataBuilderV1): + mb = getattr(self.attn_groups[0][0], "metadata_builders", None) + if isinstance(mb, list): + if not isinstance(mb[0], TorchSDPAMetadataBuilderV1): + return + mb[0].reorder_batch(self.input_batch, scheduler_output) + return + elif not isinstance(mb, TorchSDPAMetadataBuilderV1): # Encoder-only / rerank models do not benefit from reordering, # so we safely skip here. 
return diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d4d1f814afc0..2ae748dee43c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,6 +15,7 @@ import torch.distributed import torch.nn as nn from tqdm import tqdm +from typing_extensions import TypeAlias import vllm.envs as envs from vllm.attention import Attention, AttentionType @@ -55,11 +56,12 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) +from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills) + reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # yapf conflicts with isort for this block # yapf: disable @@ -85,9 +87,12 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split +from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp from .utils import (AttentionGroup, MultiModalBudget, @@ -105,6 +110,11 @@ logger = init_logger(__name__) +AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] +# list when ubatching is enabled +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], + AttnMetadataDict] + # Wrapper for ModelRunnerOutput to support overlapped execution. class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): @@ -274,6 +284,7 @@ def __init__( # Request states. 
self.requests: dict[str, CachedRequestState] = {} + self.comm_stream = torch.cuda.Stream() # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside @@ -872,10 +883,11 @@ def _get_encoder_seq_lens( return encoder_seq_lens def _prepare_inputs( - self, - scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata], int]: + self, scheduler_output: "SchedulerOutput" + ) -> tuple[PerLayerAttnMetadata, torch.Tensor, + Optional[SpecDecodeMetadata], np.ndarray, + Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], + Optional[torch.Tensor]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -947,6 +959,15 @@ def _prepare_inputs( self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_tokens_padded = num_tokens_unpadded + self.get_local_padding( + num_tokens_unpadded) + ubatch_slices, num_tokens_after_padding = \ + ubatch_split(max_num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + self.vllm_config) + self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) @@ -1001,7 +1022,9 @@ def _prepare_inputs( logits_indices_padded = self._prepare_kv_sharing_fast_prefill( logits_indices) - attn_metadata: dict[str, Any] = {} + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] # Used in the below loop. query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] @@ -1075,7 +1098,7 @@ def _prepare_inputs( for attn_group in self.attn_groups[kv_cache_group_id]: # Prepare for cascade attention if enabled & beneficial. 
common_prefix_len = 0 - builder = attn_group.metadata_builder + builder = attn_group.get_metadata_builder() if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, @@ -1093,13 +1116,27 @@ def _prepare_inputs( num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], ) - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list): + assert common_attn_metadata.max_query_len == 1 + attn_metadata_i = (attn_group.get_metadata_builder( + ubatch_id=ubid).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + assert isinstance(attn_metadata, dict) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i # Hot-Swap lora model if self.lora_config: @@ -1107,7 +1144,8 @@ def _prepare_inputs( return (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + max_num_scheduled_tokens, ubatch_slices, + num_tokens_after_padding) def _compute_cascade_attn_prefix_len( self, @@ -1508,7 +1546,7 @@ def _extract_encoder_inputs( def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. - if isinstance(self.model, CUDAGraphWrapper): + if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): return self.model.unwrap() return self.model @@ -1675,6 +1713,17 @@ def eplb_step(self, def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + """ + Determines the total number of tokens that each rank will run. + All ranks will be padded out so that they run with the same number + of tokens + + Returns: tuple[ + num_pad_tokens: The number of tokens that will be added to the batch + num_tokens_after_padding: A tensor containing the total number of + tokens for each DP rank including padding. + ] + """ dp_size = self.vllm_config.parallel_config.data_parallel_size dp_rank = self.vllm_config.parallel_config.data_parallel_rank @@ -1698,6 +1747,39 @@ def get_dp_padding(self, dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + def get_local_padding(self, num_tokens_unpadded: int) -> int: + + num_tokens_padded = num_tokens_unpadded + + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_tokens_padded = self.vllm_config.pad_for_cudagraph( + num_tokens_unpadded) + else: + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. 
\ + enable_sequence_parallelism and tp_size > 1: + num_tokens_padded = round_up(num_tokens_unpadded, tp_size) + + num_pad_tokens = num_tokens_padded - num_tokens_unpadded + return num_pad_tokens + + # This is where the second ubatch is adjusted to account for the padding. + # Should be called after attention metadata creation. This just pads + # the second ubatch slice out to the total number of tokens + # (num_tokens + padding) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, + num_total_tokens: int): + padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, + num_total_tokens) + ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, + padded_second_ubatch_slice) + def _pool( self, hidden_states: torch.Tensor, @@ -1758,15 +1840,22 @@ def _preprocess( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, + ubatch_slices: Optional[UBatchSlices] = None, + num_tokens_after_padding: Optional[torch.Tensor] = None, ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, Optional[IntermediateTensors], dict[str, Any]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad + if ubatch_slices: + assert num_tokens_after_padding is not None + num_input_tokens = int(num_tokens_after_padding[0].item() * 2) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif ubatch_slices is None: + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding( + num_input_tokens) + num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1821,7 +1910,7 @@ def _preprocess( return ( num_scheduled_tokens, num_input_tokens, - num_tokens_across_dp, + num_tokens_after_padding, input_ids, inputs_embeds, positions, @@ -2027,7 +2116,8 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = self._prepare_inputs(scheduler_output) + max_query_len, ubatch_slices, num_tokens_after_padding + ) = self._prepare_inputs(scheduler_output) finally: if self.prepare_inputs_event is not None: @@ -2042,7 +2132,11 @@ def execute_model( positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors) + ) = self._preprocess(scheduler_output, intermediate_tensors, + ubatch_slices, num_tokens_after_padding) + + if ubatch_slices is not None: + num_input_tokens = num_input_tokens // 2 uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( @@ -2062,6 +2156,7 @@ def execute_model( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output): @@ -2441,10 +2536,18 @@ def load_model(self, eep_scale_up: bool = False) -> None: # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. 
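+        # When DBO is enabled, the model is wrapped in UBatchWrapper instead; UBatchWrapper creates its own CUDAGraphWrapper internally when full cudagraphs are requested.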
- if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ + and not self.parallel_config.enable_dbo: self.model = CUDAGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + elif self.parallel_config.enable_dbo: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.FULL, self.device) + else: + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.NONE, self.device) def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ @@ -2642,6 +2745,7 @@ def _dummy_run( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, + allow_microbatching: bool = False, skip_eplb: bool = False, is_profile: bool = False, create_mixed_batch: bool = False, @@ -2667,12 +2771,30 @@ def _dummy_run( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ + ubatch_enabled = self.parallel_config.enable_dbo + num_tokens_across_dp = None + num_pad = 0 + should_ubatch = False + if ubatch_enabled: + should_ubatch = num_tokens >= \ + self.parallel_config.dbo_decode_token_threshold and \ + allow_microbatching + + (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch( + num_tokens, num_tokens, should_ubatch, self.vllm_config) + + # Currently the dummy run should only be ubatching during + # cuda graph capture, meaning all DP ranks should already + # have the same batch size + if num_tokens_across_dp is not None: + assert int(num_tokens_across_dp[0]) == num_tokens // 2 + assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + if not should_ubatch: + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad # If cudagraph_mode.decode_mode() == FULL and @@ -2690,6 +2812,10 @@ def _dummy_run( # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens + if allow_microbatching: + assert self.uniform_decode_query_len == 1 + assert uniform_decode is True + assert max_query_len == 1 # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -2728,12 +2854,28 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - attn_metadata: Optional[dict[str, Any]] = None + ubatch_slices = None + # We currently only microbatch if the number of tokens is + # over a certain threshold. + if should_ubatch: + # We only support decode-only cudagraphs + assert num_reqs == num_tokens + assert num_tokens % 2 == 0 + ubatch_slices = [ + UBatchSlice(slice(0, num_reqs // 2), slice(0, + num_tokens // 2)), + UBatchSlice(slice(num_reqs // 2, num_reqs), + slice(num_tokens // 2, num_tokens)) + ] + + attn_metadata: Optional[PerLayerAttnMetadata] = None # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] if create_mixed_batch: # In the mixed batch mode (used for FI warmup), we use @@ -2766,12 +2908,26 @@ def _dummy_run( slot_mapping=self.input_batch. 
block_table[kv_cache_group_id].slot_mapping[:num_tokens], causal=True) - for attn_group in self.attn_groups[kv_cache_group_id]: - attn_metadata_i = attn_group.metadata_builder\ - .build_for_cudagraph_capture(common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list): + assert common_attn_metadata.max_query_len == 1 + attn_metadata_i = (attn_group\ + .get_metadata_builder(ubatch_id=ubid)\ + .build_for_cudagraph_capture(common_attn_metadata)) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][ + layer_name] = attn_metadata_i + else: + assert type(attn_metadata) is dict + attn_metadata_i = attn_group.get_metadata_builder()\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, remove_lora): @@ -2818,13 +2974,16 @@ def _dummy_run( f"Cudagraph runtime mode mismatch at dummy_run. " f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + if ubatch_slices is not None: + num_tokens = num_tokens // 2 with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3096,6 +3255,7 @@ def freeze_gc(): set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() @@ -3153,6 +3313,35 @@ def _capture_cudagraphs(self, compilation_cases: list[int], desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", cudagraph_runtime_mode.name)) + enable_dbo = self.parallel_config.enable_dbo + # DBO Only supports running Full cudagraphs with uniform + # decode lengths + if enable_dbo and uniform_decode: + for num_tokens in compilation_cases: + # If the number of tokens is greater than the microbatching + # threshold, don't generate a microbatched cudagraph + if (num_tokens + < self.parallel_config.dbo_decode_token_threshold): + continue + + # Warmup + for _ in range( + self.compilation_config.cudagraph_num_of_warmups): + force_attention = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=True, + allow_microbatching=True, + skip_eplb=True) + + # Graph Capture + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True, + allow_microbatching=True, + skip_eplb=True) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: for _ in range(self.compilation_config.cudagraph_num_of_warmups): @@ -3219,14 +3408,23 @@ def create_attn_groups( ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] for attn_backend, layer_names in 
attn_backends_map.items(): - attn_metadata_builder_i = attn_backend.get_builder_cls()( + attn_metadata_builders = [] + attn_metadata_builders.append(attn_backend.get_builder_cls()( kv_cache_spec, layer_names, self.vllm_config, self.device, - ) + )) + if self.parallel_config.enable_dbo: + attn_metadata_builders.append( + attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + )) attn_group = AttentionGroup(attn_backend, - attn_metadata_builder_i, + attn_metadata_builders, layer_names) attn_groups.append(attn_group) return attn_groups @@ -3246,11 +3444,10 @@ def initialize_cudagraph_capture(self) -> None: min_cg_builder_name = None for attn_group in self._attn_group_iterator(): - builder = attn_group.metadata_builder + builder = attn_group.get_metadata_builder() if builder.cudagraph_support.value < min_cg_support.value: min_cg_support = builder.cudagraph_support min_cg_builder_name = builder.__class__.__name__ - # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported @@ -3316,7 +3513,7 @@ def calculate_reorder_batch_threshold(self) -> None: is compatible (e.g., decode threshold is the same) """ for group in self._attn_group_iterator(): - attn_metadata_builder_i = group.metadata_builder + attn_metadata_builder_i = group.get_metadata_builder() # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py new file mode 100644 index 000000000000..5012ad0483c8 --- /dev/null +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import threading +from typing import Any, Callable, Optional + +import torch + +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import (create_forward_context, get_forward_context, + override_forward_context) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class UbatchMetadata: + context: UBatchContext + input_ids: torch.Tensor + positions: torch.Tensor + inputs_embeds: Optional[torch.Tensor] + intermediate_tensors: Optional[IntermediateTensors] + num_tokens: int + + +@dataclasses.dataclass +class CUDAGraphMetaData: + cudagraph: torch.cuda.CUDAGraph + ubatch_metadata: UbatchMetadata + outputs: Optional[Any] = None + + +class UBatchWrapper: + + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, device: torch.cuda.device): + self.runnable = runnable + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.comm_stream = torch.cuda.Stream(device=device) + # Two ubatch threads plus the main thread + self.ready_barrier = threading.Barrier(3) + + self.cudagraphs: dict[int, CUDAGraphMetaData] = {} + + self.cudagraph_wrapper = None + self.graph_pool = None + if runtime_mode is not CUDAGraphMode.NONE: + self.cudagraph_wrapper = CUDAGraphWrapper( + runnable, vllm_config, runtime_mode=runtime_mode) + self.graph_pool = current_platform.get_global_graph_pool() + + def __getattr__(self, key: str): 
+        # allow accessing the attributes of the runnable.
+        if hasattr(self.runnable, key):
+            return getattr(self.runnable, key)
+        raise AttributeError(f"Attribute {key} does not exist in the "
+                             f"runnable of UBatchWrapper: {self.runnable}")
+
+    def unwrap(self) -> Callable:
+        # in case we need to access the original runnable.
+        return self.runnable
+
+    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
+        """
+        Capture a cudagraph for a microbatched run.
+
+        The logic here is somewhat complicated because we need to make sure
+        that each of the ubatch threads initializes its CUDA context before
+        we start the graph capture.
+
+        The flow is as follows:
+        1. The main thread starts up each ubatch thread. Each thread will
+        initialize its CUDA context (torch.cuda.current_blas_handle())
+        before going to sleep upon entering the ubatch_context.
+
+        2. The main thread starts the graph capture and wakes up the first
+        ubatch thread.
+
+        3. Each ubatch thread runs the model to completion and returns the
+        completed output tensors back to the main thread.
+
+        4. The main thread stores the captured cudagraph along with its
+        metadata and returns.
+        """
+
+        @torch.inference_mode()
+        def _capture_ubatch_thread(results, ubatch_metadata):
+            ubatch_context = ubatch_metadata.context
+            with torch.cuda.stream(ubatch_context.compute_stream):
+                _ = torch.cuda.current_blas_handle()
+            with torch.cuda.stream(ubatch_context.comm_stream):
+                _ = torch.cuda.current_blas_handle()
+            with ubatch_context:
+                model_output = model(
+                    input_ids=ubatch_metadata.input_ids,
+                    positions=ubatch_metadata.positions,
+                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
+                    inputs_embeds=ubatch_metadata.inputs_embeds,
+                )
+
+            results.append((ubatch_metadata.context.id, model_output))
+
+        results: list[tuple[int, torch.Tensor]] = []
+        compute_stream = ubatch_metadata[0].context.compute_stream
+        num_tokens = ubatch_metadata[0].num_tokens + \
+            ubatch_metadata[1].num_tokens
+
+        # Ubatches will manually manage the forward context, so we override
+        # it to None here so we can have it restored correctly later
+        with override_forward_context(None):
+            ubatch_threads = []
+            for metadata in ubatch_metadata:
+                thread = threading.Thread(target=_capture_ubatch_thread,
+                                          args=(
+                                              results,
+                                              metadata,
+                                          ))
+                ubatch_threads.append(thread)
+                thread.start()
+            self.ready_barrier.wait()  # Wait for both threads to be ready
+
+            # Capture the cudagraph
+            cudagraph_metadata = \
+                CUDAGraphMetaData(
+                    cudagraph=torch.cuda.CUDAGraph(),
+                    ubatch_metadata=ubatch_metadata,
+                )
+            with torch.cuda.graph(cudagraph_metadata.cudagraph,
+                                  stream=compute_stream,
+                                  pool=self.graph_pool):
+                ubatch_metadata[0].context.cpu_wait_event.set()
+                for thread in ubatch_threads:
+                    thread.join()
+        sorted_results = [value for position, value in sorted(results)]
+        result = torch.cat(sorted_results, dim=0)
+        cudagraph_metadata.outputs = result
+        self.cudagraphs[num_tokens] = cudagraph_metadata
+        return cudagraph_metadata.outputs
+
+    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
+
+        @torch.inference_mode()
+        def _ubatch_thread(results, model, ubatch_metadata):
+            with ubatch_metadata.context:
+                model_output = model(
+                    input_ids=ubatch_metadata.input_ids,
+                    positions=ubatch_metadata.positions,
+                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
+                    inputs_embeds=ubatch_metadata.inputs_embeds,
+                )
+            results.append((ubatch_metadata.context.id, model_output))
+
+        results: list[tuple[int, torch.Tensor]] = []
+
+        # Ubatch threads will manually manage
the forward context, so we + # override it to None here so we can have it restored correctly + # after both threads have finished + with override_forward_context(None): + ubatch_threads = [] + for metadata in ubatch_metadata: + thread = threading.Thread(target=_ubatch_thread, + args=( + results, + model, + metadata, + )) + ubatch_threads.append(thread) + thread.start() + self.ready_barrier.wait() # Wait for both threads to be ready + ubatch_metadata[0].context.cpu_wait_event.set() + for thread in ubatch_threads: + thread.join() + sorted_results = [value for position, value in sorted(results)] + result = torch.cat(sorted_results, dim=0) + return result + + def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, + positions, inputs_embeds, intermediate_tensors, + compute_stream, dp_metadata, batch_descriptor, + cudagraph_runtime_mode) -> list[UbatchMetadata]: + + # Create one forward context per ubatch + forward_contexts = [] + for i, ubatch_slice in enumerate(ubatch_slices): + forward_contexts.append( + create_forward_context( + attn_metadata[i] if attn_metadata is not None else None, + self.vllm_config, + dp_metadata=dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=cudagraph_runtime_mode)) + + ubatch_ctxs = make_ubatch_contexts( + num_micro_batches=len(ubatch_slices), + comm_stream=self.comm_stream, + compute_stream=compute_stream, + forward_contexts=forward_contexts, + ready_barrier=self.ready_barrier) + + ubatch_metadata: list[UbatchMetadata] = [] + for i, ubatch_slice in enumerate(ubatch_slices): + sliced_input_ids, sliced_positions, sliced_inputs_embeds, \ + sliced_intermediate_tensors = \ + self._slice_model_inputs( + ubatch_slice.token_slice, input_ids, positions, + inputs_embeds, intermediate_tensors) + ubatch_metadata.append( + UbatchMetadata( + context=ubatch_ctxs[i], + input_ids=sliced_input_ids, + positions=sliced_positions, + inputs_embeds=sliced_inputs_embeds, + intermediate_tensors=sliced_intermediate_tensors, + num_tokens=ubatch_slice.token_slice.stop - + ubatch_slice.token_slice.start)) + + return ubatch_metadata + + def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, + inputs_embeds, intermediate_tensors): + sliced_input_ids = input_ids[tokens_slice] + # if we are using mrope. 
Mrope adds an additional dimension to the
+        # positions tensor.
+        if positions.ndim == 2:
+            sliced_positions = positions[:, tokens_slice]
+        else:
+            sliced_positions = positions[tokens_slice]
+        sliced_inputs_embeds = inputs_embeds[
+            tokens_slice] if inputs_embeds is not None else None
+        sliced_intermediate_tensors = intermediate_tensors[
+            tokens_slice] if intermediate_tensors is not None else None
+
+        return (sliced_input_ids, sliced_positions, sliced_inputs_embeds,
+                sliced_intermediate_tensors)
+
+    def __call__(self, *args, **kwargs):
+        forward_context = get_forward_context()
+        batch_descriptor = forward_context.batch_descriptor
+        ubatch_slices = forward_context.ubatch_slices
+        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
+
+        # If there's no ubatching, just run the runnable object
+        if ubatch_slices is None:
+            if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
+                                          CUDAGraphMode.PIECEWISE):
+                return self.runnable(*args, **kwargs)
+            else:
+                assert self.cudagraph_wrapper is not None
+                return self.cudagraph_wrapper(*args, **kwargs)
+
+        attn_metadata = forward_context.attn_metadata
+        num_tokens = (ubatch_slices[0].token_slice.stop -
+                      ubatch_slices[0].token_slice.start) * 2
+        input_ids = kwargs['input_ids']
+        positions = kwargs['positions']
+        intermediate_tensors = kwargs['intermediate_tensors']
+        inputs_embeds = kwargs['inputs_embeds']
+        compute_stream = torch.cuda.current_stream()
+
+        dp_metadata = forward_context.dp_metadata
+
+        # We shouldn't be here unless we are running with multiple DP ranks
+        assert dp_metadata is not None
+
+        if num_tokens not in self.cudagraphs \
+                and cudagraph_runtime_mode is CUDAGraphMode.FULL:
+            ubatch_metadata = self._make_ubatch_metadata(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                compute_stream=compute_stream,
+                dp_metadata=dp_metadata,
+                batch_descriptor=batch_descriptor,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE)
+
+            return self._capture_ubatches(ubatch_metadata, self.model)
+        elif num_tokens in self.cudagraphs:
+            cudagraph_metadata = self.cudagraphs[num_tokens]
+            cudagraph_metadata.cudagraph.replay()
+            return cudagraph_metadata.outputs
+        else:
+            ubatch_metadata = self._make_ubatch_metadata(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                compute_stream=compute_stream,
+                dp_metadata=dp_metadata,
+                batch_descriptor=batch_descriptor,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE)
+            return self._run_ubatches(ubatch_metadata, self.model)
diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py
new file mode 100644
index 000000000000..650f0ec5138d
--- /dev/null
+++ b/vllm/v1/worker/ubatch_splitting.py
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.forward_context import DPMetadata
+from vllm.logger import init_logger
+from vllm.utils import round_up
+from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices,
+                                         is_second_ubatch_empty)
+
+logger = init_logger(__name__)
+
+
+def should_ubatch_with_num_tokens(
+    should_ubatch: bool,
+    orig_num_tokens_per_ubatch: int,
+    padded_num_tokens_per_ubatch: int,
+    vllm_config: VllmConfig,
+) -> tuple[bool, Optional[torch.Tensor]]:
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    dp_rank = vllm_config.parallel_config.data_parallel_rank
+    return DPMetadata.should_ubatch_across_dp(should_ubatch,
+                                              orig_num_tokens_per_ubatch,
+                                              padded_num_tokens_per_ubatch,
+                                              dp_size, dp_rank)
+
+
+def get_dp_padding_ubatch(
+        num_tokens_unpadded: int, num_tokens_padded: int,
+        should_attempt_ubatching: bool,
+        vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]:
+    """
+    1. Decides if each DP rank is going to microbatch. Either all ranks
+    run with microbatching or none of them do. If this function decides
+    not to run with microbatching, it "aborts": no padding information is
+    returned to the caller and the function returns (False, None).
+
+    2. Determines the total number of tokens that each rank will run.
+    All ranks will be padded out so that they run with the same number
+    of tokens.
+
+    Returns: tuple[
+        should_ubatch: Are all DP ranks going to microbatch
+        num_tokens_after_padding: A tensor containing the total number of
+        tokens per-microbatch for each DP rank including padding. Will be
+        None if should_ubatch is False
+    ]
+    """
+    assert num_tokens_padded >= num_tokens_unpadded
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    if dp_size == 1:
+        # Early exit.
+        return False, None
+
+    # If this DP rank doesn't want to attempt microbatching
+    if not should_attempt_ubatching:
+        (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
+            False, 0, 0, vllm_config)
+        assert should_ubatch is False
+        assert num_tokens_across_dp is None
+        return should_ubatch, num_tokens_across_dp
+
+    # Round up to the next multiple of two for even divisibility
+    num_tokens_padded = round_up(num_tokens_padded, 2)
+    num_tokens_per_ubatch = num_tokens_padded // 2
+    should_ubatch = True
+
+    # Sanity check that the existing padding isn't giving us an empty second
+    # ubatch. Abort if so.
+    if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
+        logger.debug(
+            "Empty second µbatch detected: unpadded tokens: %s, padded "
+            "tokens: %s", num_tokens_unpadded, num_tokens_padded)
+        should_ubatch = False
+
+    # Note that we compute the number of padded tokens per ubatch
+    (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens(
+        should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch,
+        vllm_config)
+    if not should_ubatch:
+        assert num_tokens_across_dp is None
+        return should_ubatch, num_tokens_across_dp
+
+    assert num_tokens_across_dp is not None
+
+    max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
+    num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
+                                            dp_size,
+                                            device="cpu",
+                                            dtype=torch.int32)
+    return should_ubatch, num_tokens_after_padding
+
+
+def ubatch_split(
+    max_num_scheduled_tokens: int,
+    num_tokens_unpadded: int,
+    num_tokens_padded: int,
+    vllm_config: VllmConfig,
+) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
+    """
+    Coordinates amongst all DP ranks to determine if and how the full batch
+    should be split into microbatches.
+
+    Returns: tuple[
+        ubatch_slices: if this is set then all DP ranks have agreed to
+        microbatch
+        num_tokens_after_padding: A tensor containing the total number of
+        tokens per-microbatch for each DP rank including padding.
Will be + None if ubatch_slices is None + ] + + """ + parallel_config = vllm_config.parallel_config + # Don't bother with the should_ubatch handshaking unless microbatching + # is enabled + if not parallel_config.enable_dbo: + return (None, None) + + # Check preconditions for microbatching + should_attempt_ubatching = \ + parallel_config.enable_dbo and \ + num_tokens_unpadded >= \ + parallel_config.dbo_decode_token_threshold \ + and max_num_scheduled_tokens == 1 + + # Don't microbatch unless every other DP worker is also microbatching + num_tokens_after_padding = None + (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch( + num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, + vllm_config) + if not should_ubatch: + return (None, None) + + # This doesn't actually pad the ubatch slices. It just initializes the + # split point to the padded value so that padding can be applied + # to the second ubatch in pad_out_ubatch_slice after attention + # metadata creation + assert num_tokens_after_padding is not None + total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item()) + padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch) + padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, + num_tokens_unpadded) + + # Note there's an assumption here that there's 1 token per request + ubatch_slices = [ + UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice), + UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice) + ] + + return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py new file mode 100644 index 000000000000..6716d171cc70 --- /dev/null +++ b/vllm/v1/worker/ubatch_utils.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +from typing_extensions import TypeAlias + + +@dataclass +class UBatchSlice: + request_slice: slice + token_slice: slice + + +UBatchSlices: TypeAlias = list[UBatchSlice] + + +def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int) -> bool: + return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py new file mode 100644 index 000000000000..9aeaa9909dc8 --- /dev/null +++ b/vllm/v1/worker/ubatching.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import Optional + +import torch + +from vllm import forward_context +from vllm.forward_context import ForwardContext +from vllm.utils import current_stream + +_THREAD_ID_TO_CONTEXT: dict = {} +_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None] + + +class UBatchContext: + """ + Context manager for micro-batching synchronization using threading events. 
+ """ + + def __init__(self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + forward_context: ForwardContext, + ready_barrier: threading.Barrier, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default"): + self.id = id + self.comm_stream = comm_stream + self.compute_stream = compute_stream + self.forward_context = forward_context + self.ready_barrier = ready_barrier + self.cpu_wait_event = cpu_wait_event + self.cpu_signal_event = cpu_signal_event + self.current_stream = compute_stream + self.gpu_comm_done_event = gpu_comm_done_event + self.gpu_compute_done_event = gpu_compute_done_event + self.schedule = schedule + self.recv_hook = None + + def __enter__(self): + global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT + _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id + _CURRENT_CONTEXTS[self.id] = self + self.ready_barrier.wait() + + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + # Assume we start on the compute stream + assert current_stream() == self.compute_stream + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT + _CURRENT_CONTEXTS[self.id] = None + del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + self.maybe_run_recv_hook() + self.cpu_signal_event.set() + self.cpu_wait_event.clear() + self.current_stream = self.compute_stream + torch.cuda.set_stream(self.current_stream) + return False + + def _restore_context(self): + forward_context._forward_context = self.forward_context + torch.cuda.set_stream(self.current_stream) + + def update_stream(self, stream): + self.current_stream = stream + torch.cuda.set_stream(self.current_stream) + + def _signal_comm_done(self): + self.gpu_comm_done_event.record(self.comm_stream) + + def _signal_compute_done(self): + self.gpu_compute_done_event.record(self.compute_stream) + + def _wait_compute_done(self): + self.comm_stream.wait_event(self.gpu_compute_done_event) + + def _wait_comm_done(self): + self.compute_stream.wait_event(self.gpu_comm_done_event) + + def _cpu_yield(self): + # It is critical for correctness that only one thread is running + # at a time. 
These asserts just make sure that this is the only + # thread running before waking the other one up and going to sleep + assert forward_context._forward_context == self.forward_context + assert current_stream() == self.current_stream + assert not self.cpu_wait_event.is_set() + + self.cpu_signal_event.set() + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + + def switch_to_comm_sync(self): + self._signal_compute_done() + self.update_stream(self.comm_stream) + self._wait_comm_done() + + def maybe_run_recv_hook(self): + if self.recv_hook is not None: + self.recv_hook() + self.recv_hook = None + + def yield_(self): + self.current_stream = current_stream() + self._cpu_yield() + if self.current_stream != current_stream(): + self.update_stream(self.current_stream) + + def yield_and_switch_from_compute_to_comm(self): + assert current_stream() == self.compute_stream + self._signal_compute_done() + self._cpu_yield() + assert self.current_stream == self.compute_stream + self.update_stream(self.comm_stream) + self._wait_compute_done() + + def yield_and_switch_from_comm_to_compute(self): + assert current_stream() == self.comm_stream + self._signal_comm_done() + self._cpu_yield() + assert self.current_stream == self.comm_stream + self.update_stream(self.compute_stream) + self._wait_comm_done() + + +def dbo_enabled() -> bool: + return len(_THREAD_ID_TO_CONTEXT) > 0 + + +def dbo_current_ubatch_id() -> int: + if len(_THREAD_ID_TO_CONTEXT) == 0: + return 0 + return _THREAD_ID_TO_CONTEXT[threading.get_ident()] + + +def _register_ubatch_function(func): + + def wrapper(*args, **kwargs): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + ctx = _CURRENT_CONTEXTS[ctx_idx] + func(ctx, *args, **kwargs) + + return wrapper + + +dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( + UBatchContext.yield_and_switch_from_compute_to_comm) +dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( + UBatchContext.yield_and_switch_from_comm_to_compute) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) +dbo_maybe_run_recv_hook = _register_ubatch_function( + UBatchContext.maybe_run_recv_hook) +dbo_switch_to_comm_sync = _register_ubatch_function( + UBatchContext.switch_to_comm_sync) + + +def dbo_register_recv_hook(recv_hook): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx.recv_hook = recv_hook + + +def make_ubatch_contexts( + num_micro_batches: int, + compute_stream: torch.cuda.Stream, + comm_stream: torch.cuda.Stream, + forward_contexts: list[ForwardContext], + ready_barrier: threading.Barrier, + schedule: str = "default", +) -> list[UBatchContext]: + assert num_micro_batches == 2, "only been tested with 2 micro-batches" + """ + Create a context manager for micro-batching synchronization. 
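+
+    The CPU events are wired in a ring: context i's cpu_signal_event is
+    context (i + 1) % num_micro_batches's cpu_wait_event, so a yield in one
+    ubatch thread wakes the other. Both contexts share the same compute and
+    comm streams. A simplified usage, mirroring UBatchWrapper._run_ubatches:
+    each ubatch thread enters its context with `with ctxs[i]: ...` (the
+    __enter__ waits on the shared ready_barrier), and the main thread then
+    wakes ubatch 0 via `ctxs[0].cpu_wait_event.set()`.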
+ """ + cpu_events = [threading.Event() for _ in range(num_micro_batches)] + gpu_comm_done_events = [ + torch.cuda.Event() for _ in range(num_micro_batches) + ] + gpu_compute_done_events = [ + torch.cuda.Event() for _ in range(num_micro_batches) + ] + + assert len(forward_contexts) == 2 + + ctxs = [] + for i in range(num_micro_batches): + ctx = UBatchContext(id=i, + compute_stream=compute_stream, + comm_stream=comm_stream, + forward_context=forward_contexts[i], + ready_barrier=ready_barrier, + cpu_wait_event=cpu_events[i], + cpu_signal_event=cpu_events[(i + 1) % + num_micro_batches], + gpu_comm_done_event=gpu_comm_done_events[i], + gpu_compute_done_event=gpu_compute_done_events[i], + schedule=schedule) + ctxs.append(ctx) + + return ctxs diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 5ac7470c1ac9..fc831a73a75e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -130,9 +130,17 @@ def get_max_items( @dataclass class AttentionGroup: backend: type[AttentionBackend] - metadata_builder: AttentionMetadataBuilder + metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] + def get_metadata_builder(self, + ubatch_id: Optional[int] = None + ) -> AttentionMetadataBuilder: + if ubatch_id is None: + return self.metadata_builders[0] + assert len(self.metadata_builders) > ubatch_id + return self.metadata_builders[ubatch_id] + def sanity_check_mm_encoder_outputs( mm_embeddings: MultiModalEmbeddings,