diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27e44463a30a..695290ed74ab 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -29,6 +29,8 @@ steps: - pytest -v -s test_pynccl.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 91510dafc57a..aadbc441713f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None): help='block size of key/value cache') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, + action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e71338273d1e..6df1e1d628e6 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -74,25 +74,31 @@ def run_vllm( quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) # Add the requests to the engine. 
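The switch from `type=bool` to `action='store_true'` in benchmark_latency.py matters because argparse's `type=bool` simply calls `bool()` on the raw string, so `--enable-chunked-prefill False` would still enable the feature. A minimal standalone sketch of the difference (not part of the diff; the flag names are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--buggy-flag', type=bool, default=False)
parser.add_argument('--fixed-flag', action='store_true')

args = parser.parse_args(['--buggy-flag', 'False'])
assert args.buggy_flag is True   # surprising: bool("False") == True
assert args.fixed_flag is False  # flag absent, stays False

args = parser.parse_args(['--fixed-flag'])
assert args.fixed_flag is True   # flag present -> True
```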
for prompt, _, output_len in requests: @@ -213,15 +219,15 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -335,6 +341,14 @@ def main(args: argparse.Namespace): "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') parser.add_argument('--download-dir', type=str, default=None, diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py new file mode 100644 index 000000000000..9ff07b3c0902 --- /dev/null +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -0,0 +1,70 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +It tests chunked prefill. Chunked prefill can be enabled by +enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, +prefill requests are chunked. + +Run `pytest tests/models/test_chunked_prefill.py`. +""" +import pytest + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. 
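For orientation, a rough sketch (not vLLM code) of the arithmetic this new test exercises: with enable_chunked_prefill=True, a prompt longer than max_num_batched_tokens (set here from chunked_prefill_token_size) is prefilled over several scheduler steps. The helper name is illustrative.

```python
def prefill_chunks(prompt_len: int, max_num_batched_tokens: int) -> list:
    """Return the token-chunk sizes a prompt of `prompt_len` would need."""
    chunks = []
    remaining = prompt_len
    while remaining > 0:
        chunk = min(remaining, max_num_batched_tokens)
        chunks.append(chunk)
        remaining -= chunk
    return chunks

# A 30-token prompt with chunked_prefill_token_size=16 needs two prefill steps.
assert prefill_chunks(30, 16) == [16, 14]
# With chunked_prefill_token_size=4 it needs ceil(30 / 4) = 8 steps.
assert len(prefill_chunks(30, 4)) == 8
```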
+@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + tensor_parallel_size: int, +) -> None: + if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 + and not enforce_eager): + pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " + "for high TP to save testing time.") + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + print(vllm_outputs[0]) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 05e62ced5898..cce396bf4953 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -104,10 +104,10 @@ def test_chunk(): # One chunked prefill, and one decoding. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 + assert seq_group_meta[1].token_chunk_size == 1 assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 57 @@ -157,12 +157,12 @@ def test_complex(): # Decoding & chunked prefill & first chunk of 3rd request is scheduled. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(get_sequence_groups(out)) == 3 - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 - # The second one is a chunked prefill. + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. assert seq_group_meta[1].token_chunk_size == 56 - # The third one is also chunked. - assert seq_group_meta[2].token_chunk_size == 7 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 # Two of them are in chunked prefill. 
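The updated assertions in these scheduler tests encode two guarantees: prefill groups are returned before decode groups, and the per-step batched token count is the sum of every group's token_chunk_size (a decode always contributes 1). A toy sketch of that ordering (the Group tuple is a stand-in for SequenceGroupMetadata, not vLLM code):

```python
from typing import List, NamedTuple


class Group(NamedTuple):
    is_prompt: bool
    token_chunk_size: int


def order_for_step(groups: List[Group]) -> List[Group]:
    # Prefills first, then decodes, preserving relative order within each kind.
    return ([g for g in groups if g.is_prompt] +
            [g for g in groups if not g.is_prompt])


step = order_for_step([Group(False, 1), Group(True, 56), Group(True, 7)])
assert [g.token_chunk_size for g in step] == [56, 7, 1]
assert sum(g.token_chunk_size for g in step) == 64  # num_batched_tokens
```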
assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 64 diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1eba14d7a642..77aa90b12bf8 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -33,11 +33,16 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py new file mode 100644 index 000000000000..737b1f316951 --- /dev/null +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -0,0 +1,66 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. + +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_chunked_prefill_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_chunked_prefill_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + # Add a chunked prefill config. 
+ max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 442f8bdf3b4b..6f2086c4dd26 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -141,7 +141,7 @@ def server(zephyr_lora_files): "--max-cpu-loras", "2", "--max-num-seqs", - "128" + "128", ]) ray.get(server_runner.ready.remote()) yield server_runner diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 53a80d461964..cfe2539e3a05 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -12,7 +12,7 @@ "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/pythia-70m", - "bigscience/bloom-560m", + "bigscience/bloom-560m", # Testing alibi slopes. "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", # "allenai/OLMo-1B", # Broken diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5b6f001f62fa..dcaae4af4a6f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,14 +1,18 @@ import pytest import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, SchedulerConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - model_runner = ModelRunner(None, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(None, None, scheduler_config, None, None) model_runner.set_block_size(16) prompt_lens = [] @@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size): prompt_len - 1) selected_token_start_idx += prompt_len (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device @@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.num_prompt_tokens == sum(prompt_lens) - assert attn_metadata.num_generation_tokens == 0 assert attn_metadata.max_prompt_len == max(prompt_lens) # Test subquery start locs. 
@@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - torch.testing.assert_close(input_tokens, input_positions) + assert input_tokens == input_positions actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size): revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(model_config, None, scheduler_config, None, + None) model_runner.set_block_size(16) prompt_lens = [] @@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _ = ( + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False assert attn_metadata.prompt_lens is None - assert attn_metadata.num_prompt_tokens == 0 - assert attn_metadata.num_generation_tokens == expected_bs assert attn_metadata.max_prompt_len is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None @@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size): model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. 
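The expected_selected_token_indices bookkeeping in these model-runner tests reduces to: in a flattened prompt batch, sampling happens only at each sequence's last prompt token, i.e. at cumulative length so far plus prompt_len minus one. A small illustrative helper (not vLLM code):

```python
from itertools import accumulate
from typing import List


def last_token_indices(prompt_lens: List[int]) -> List[int]:
    # Index of each prompt's final token in the flattened token stream.
    return [end - 1 for end in accumulate(prompt_lens)]


# Three prompts of lengths 3, 5 and 2 flatten into 10 tokens;
# sampling selects the last token of each prompt.
assert last_token_indices([3, 5, 2]) == [2, 7, 9]
```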
assert attn_metadata.use_cuda_graph is True - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (expected_bs, ) - assert input_positions.shape == (expected_bs, ) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_tokens) == expected_bs + assert len(input_positions) == expected_bs + assert input_tokens == input_positions # Verify Sampling expected_selected_token_indices = [] @@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output.""" + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + seq_group_metadata_list = [] + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( + model_runner._prepare_decode(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_prompt_lens) == 0 + + +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): + + def get_world_size(group=None): + return 1 + + def mock_get_process_group_ranks(group=None): + return [0] + + monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) + monkeypatch.setattr(torch.distributed, "get_process_group_ranks", + mock_get_process_group_ranks) + + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=enforce_eager, + ) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=True) + model_runner = ModelRunner(model_config, + None, + scheduler_config, + None, + None, + is_driver_worker=True) + model_runner.set_block_size(16) + + # Add prefill requests. 
+ prompt_lens = [] + seq_group_metadata_list = [] + prefill_metadata_list = [] + decode_metadata_list = [] + block_tables = {0: [1]} + prefill_batch_size = batch_size // 2 + decode_batch_size = batch_size - prefill_batch_size + for i in range(prefill_batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + prefill_metadata_list.append(seq_group_metadata) + + # Add decode requests + for i in range(prefill_batch_size, batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(prompt_len)) + seq_data = SequenceData(prompt_toks) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + decode_metadata_list.append(seq_group_metadata) + + (input_tokens, input_positions, attn_metadata, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + + prefill_meta_actual = attn_metadata.prefill_metadata + decode_meta_actual = attn_metadata.decode_metadata + + assert len(attn_metadata.slot_mapping) == len(input_tokens) + assert len(input_positions) == len(input_tokens) + assert attn_metadata.kv_cache_dtype == "auto" + assert attn_metadata.num_prefills == prefill_batch_size + if enforce_eager: + assert attn_metadata.num_decode_tokens == decode_batch_size + else: + assert attn_metadata.num_decode_tokens == _get_graph_batch_size( + decode_batch_size) + assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + + # Verify attn metadata is consistent. We don't need to test individual + # values here because they are tested above. 
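For reference, the counters test_hybrid_batches asserts on the combined AttentionMetadata can be reproduced with a toy mixed batch. The Request tuple is a stand-in, and the cuda-graph padding branch (num_decode_tokens rounded up by _get_graph_batch_size when enforce_eager is False) is deliberately left out.

```python
from typing import List, NamedTuple


class Request(NamedTuple):
    is_prompt: bool
    num_tokens: int  # chunk size for a prefill, always 1 for a decode


def batch_counters(batch: List[Request]):
    num_prefills = sum(1 for r in batch if r.is_prompt)
    num_prefill_tokens = sum(r.num_tokens for r in batch if r.is_prompt)
    num_decode_tokens = sum(1 for r in batch if not r.is_prompt)
    return num_prefills, num_prefill_tokens, num_decode_tokens


batch = [Request(True, 5), Request(True, 9), Request(False, 1), Request(False, 1)]
assert batch_counters(batch) == (2, 14, 2)
```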
+ prefill_meta = model_runner._prepare_prompt( + prefill_metadata_list).attn_metadata + decode_meta = model_runner._prepare_decode( + decode_metadata_list).attn_metadata + + for attr_expected, attr_actual in zip(vars(prefill_meta), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip(vars(decode_meta), + vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 9acb82c0df2c..7636b34a16fe 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,4 +9,5 @@ "AttentionMetadata", "Attention", "get_attn_backend", + "AttentionMetadataPerStage", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a03cf2dd7a6f..7a4ccecf702f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import torch @@ -47,7 +47,8 @@ def copy_blocks( @dataclass -class AttentionMetadata: +class AttentionMetadataPerStage: + """Attention metadata for a specific stage. I.e., prefill or decode.""" def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" @@ -59,6 +60,41 @@ def asdict_zerocopy(self) -> Dict[str, Any]: } +T = TypeVar("T", bound=AttentionMetadataPerStage) + + +@dataclass +class AttentionMetadata(Generic[T]): + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # The attention metadata for prefill requests in a batch. + # None if there's no prefill requests in a batch. + prefill_metadata: Optional[T] + # The attention metadata for decode requests in a batch. + # None if there's no decode requests in a batch. + decode_metadata: Optional[T] + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # The kv cache's data type. 
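The slot_mapping comment above encodes slot = block_number * block_size + offset. A tiny sketch reproducing its example (illustrative helper, not vLLM code):

```python
def decode_slot(slot: int, block_size: int = 16):
    # Returns (block_number, offset within the block).
    return divmod(slot, block_size)


# Reproducing the comment's example with block_size = 16:
assert decode_slot(35) == (2, 3)  # block 2, offset 3
assert decode_slot(2) == (0, 2)   # block 0, offset 2
assert decode_slot(17) == (1, 1)  # block 1, offset 1
```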
+ kv_cache_dtype: str + + def __post_init__(self): + if self.num_prefill_tokens > 0: + assert self.num_prefills > 0 + assert self.prefill_metadata is not None + if self.num_decode_tokens > 0: + assert self.decode_metadata is not None + + class AttentionImpl(ABC): @abstractmethod @@ -80,7 +116,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4e0d9d1418b3..12e8c4404b94 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,7 +11,8 @@ from flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -53,7 +54,8 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -155,7 +162,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
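The layout described in this docstring is what every backend in the rest of the diff relies on: prefill tokens sit at the front of the flattened query, decode tokens at the back, and each stage writes into its own slice of a preallocated output. A self-contained sketch of that slicing pattern (shapes are arbitrary and the identity "attention" is a placeholder, not a real kernel):

```python
import torch

num_prefill_tokens, num_decode_tokens = 14, 2
num_heads, head_size = 4, 8
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

output = torch.empty_like(query)
prefill_query = query[:num_prefill_tokens]
decode_query = query[num_prefill_tokens:]

# Each stage computes attention independently (stand-in computation here)
# and writes its slice of the shared output tensor.
output[:num_prefill_tokens] = prefill_query * 1.0
output[num_prefill_tokens:] = decode_query * 1.0

assert output.shape[0] == num_prefill_tokens + num_decode_tokens
```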
@@ -188,52 +195,70 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - output = flash_attn_varlen_func( + out = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6019d917b449..e55435cd2c94 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,8 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -51,7 +52,8 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -181,7 +188,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, + attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -218,9 +225,25 @@ def forward( kv_scale, ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
- if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -230,63 +253,69 @@ def forward( key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - output = self.attn_fuc( + out = self.attn_fuc( query, key, value, - attn_metadata.prompt_lens, + prefill_meta.prompt_lens, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output, _ = self.attn_func( + out, _ = self.attn_func( query, key, value, None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, True, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output = self.attn_func( + out = self.attn_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, ) - + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9706e1910cb7..63904ea92987 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,8 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -49,17 +50,14 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. 
is_prompt: bool - slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] prompt_lens_tensor: Optional[torch.Tensor] - num_prompt_tokens: int - num_generation_tokens: int max_subquery_len: Optional[int] = None max_prompt_len: Optional[int] = None @@ -113,7 +111,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, + attn_metadata: AttentionMetadata[TorchSDPAMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -142,36 +140,51 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + prefill_meta.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + prefill_meta.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) - attn_metadata.attn_bias = att_masks + att_masks = [None] * len(prefill_meta.prompt_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): + out = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(prefill_meta.prompt_lens, + prefill_meta.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -181,28 +194,32 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out + out[start:end, :, :] = sub_out start = end + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- output = PagedAttention.forward_decode( - query, + out = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) + assert out.shape == output[num_prefill_tokens:].shape + output[num_prefill_tokens:] = out # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 05b68bba5e6e..b745a04a143b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,7 +9,8 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -54,7 +55,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -123,18 +115,27 @@ def __post_init__(self): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens --------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query.
+ + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -170,7 +171,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: XFormersMetadata, + attn_metadata: AttentionMetadata[XFormersMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -202,59 +203,61 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - output = self._run_memory_efficient_xformers_forward( - query, key, value, attn_metadata) + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + out = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: - # Decoding run. 
- output = PagedAttention.forward_decode( - query, + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -275,13 +278,30 @@ def _run_memory_efficient_xformers_forward( """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -302,6 +322,7 @@ def _run_memory_efficient_xformers_forward( # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. if self.alibi_slopes is None: + # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) @@ -312,14 +333,13 @@ def _run_memory_efficient_xformers_forward( attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale) - - return out.view_as(query) + return out.view_as(original_query) # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - output = torch.empty_like(query) + output = torch.empty_like(original_query) start = 0 for i, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len @@ -331,7 +351,7 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. 
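The GQA/MQA branch added to _run_memory_efficient_xformers_forward above views the query as [tokens, kv_heads, queries_per_kv, head_size] and expands key/value without copying, so all query heads in a group attend against the same KV head. A standalone sketch with illustrative sizes (not the vLLM function itself):

```python
import torch

num_tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 16
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)

# Group query heads by their shared KV head.
grouped_query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# Broadcast each KV head across its query group without materializing copies.
expanded_key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)
assert grouped_query.shape == expanded_key.shape == (6, 2, 4, 16)
```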
- output[start:end].copy_(out.squeeze(0)) + output[start:end].copy_(out.view_as(original_query[start:end])) start += prompt_len return output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9856654fc5f9..fc65ae108dbb 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import (AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend @@ -41,7 +42,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 256bffdf032e..2d918491d657 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,11 +13,6 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. WARNING: When it is a prefill request, it doesn't include new # tokens. When it is for decoding, it includes a new token. @@ -31,7 +26,6 @@ class PagedAttentionMetadata: # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class PagedAttention: diff --git a/vllm/config.py b/vllm/config.py index bca250e92228..4102edbe01d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -565,9 +565,16 @@ def __init__( if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: - # If max_model_len is too short, use 2048 as the default value for - # higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + # For chunked prefill, choose the well-tuned batch size. + self.max_num_batched_tokens = 768 + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f937496..2942eab735a9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,7 +140,11 @@ def _sort_by_lora_ids(self) -> bool: @property def lora_requests(self) -> Set[LoRARequest]: - return {g.seq_group.lora_request for g in self.scheduled_seq_groups} + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } @dataclass @@ -826,13 +830,12 @@ def _schedule_chunked_prefill(self): # Update swapped requests. 
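The vllm/config.py change above picks max_num_batched_tokens as follows when the user does not set it. A sketch of that resolution logic (the helper name is illustrative, not the vLLM class itself):

```python
from typing import Optional


def resolve_max_num_batched_tokens(max_num_batched_tokens: Optional[int],
                                   max_model_len: int,
                                   enable_chunked_prefill: bool) -> int:
    if max_num_batched_tokens is not None:
        return max_num_batched_tokens
    if enable_chunked_prefill:
        # Well-tuned default chunk budget for chunked prefill.
        return 768
    # Otherwise keep at least 2048 tokens per iteration for throughput.
    return max(max_model_len, 2048)


assert resolve_max_num_batched_tokens(None, 4096, True) == 768
assert resolve_max_num_batched_tokens(None, 512, False) == 2048
assert resolve_max_num_batched_tokens(1024, 4096, False) == 1024
```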
self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.decode_seq_groups + - swapped_in.prefill_seq_groups), + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), @@ -907,7 +910,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = i < scheduler_outputs.num_prefill_groups + is_prompt = seq_group.is_prefill() seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index cf15db099b30..1004d626b6a4 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -173,10 +173,18 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list([metadata_list], src=src, group=group) + async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src, group=group) + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) + for async_handle in async_handles: + async_handle.wait() + else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4b573992c06..daefddc01b43 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -386,9 +386,8 @@ def add_cli_args( 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, - help='If True, the prefill requests can be chunked based on the ' + action='store_true', + help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c639af69654..ddfdda898a5c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -633,7 +633,10 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self._process_sequence_group_outputs(seq_group, outputs) + # If uncomputed tokens > 0, it means prefill is chunked. + # We don't need to process outputs in that case. + if seq_group.get_num_uncomputed_tokens() == 0: + self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. 
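The broadcast_tensor_dict change above issues every tensor broadcast with async_op=True and waits on all handles afterwards so the transfers can overlap. A minimal sketch of the pattern, guarded because it is only meaningful inside an initialized process group (the helper name is illustrative):

```python
import torch
import torch.distributed as dist


def broadcast_all(tensors, src: int = 0):
    if not (dist.is_available() and dist.is_initialized()):
        return  # nothing to do outside a distributed run
    # Kick off every broadcast without blocking ...
    handles = [dist.broadcast(t, src=src, async_op=True) for t in tensors]
    # ... then wait for all of them to complete.
    for handle in handles:
        handle.wait()
```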
self.scheduler.free_finished_seq_groups() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index dd33868f7630..84a94091486d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -267,12 +267,13 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + embedding_len = self.indices_len[3] + indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + indices = self.embeddings_indices[0][:embedding_len].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c..77029908c221 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -500,7 +500,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 for seq in self.get_seqs(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1de4748b7bcc..47ad8f0c9b78 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,12 +1,14 @@ import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +from enum import IntEnum +from typing import Dict, List, NamedTuple, Optional, Set, Tuple import numpy as np import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, + get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -37,6 +39,66 @@ ] +class PreparePromptMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadataPerStage] + prompt_lens: List[int] + subquery_lens: List[int] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + multi_modal_input: Optional[torch.Tensor] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PreparePromptMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + prompt_lens=[], + subquery_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + ) + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PrepareDecodeMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + ) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. 
+    MIXED = 2
+
+
 class ModelRunner:
 
     def __init__(
@@ -152,10 +214,7 @@ def get_max_block_per_batch(self) -> int:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
-               List[int], List[int], List[int], Set[LoRARequest],
-               torch.Tensor]:
-        assert len(seq_group_metadata_list) > 0
+    ) -> PreparePromptMetadata:
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
@@ -169,6 +228,9 @@ def _prepare_prompt(
         prefix_block_tables: List[List[int]] = []
         multi_modal_input_list: List[torch.Tensor] = []
 
+        if len(seq_group_metadata_list) == 0:
+            return PreparePromptMetadata.empty()
+
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -178,7 +240,8 @@ def _prepare_prompt(
             computed_block_nums = seq_group_metadata.computed_block_nums
             if (self.scheduler_config is not None
                     and self.scheduler_config.chunked_prefill_enabled
-                    and computed_block_nums is not None):
+                    and not (computed_block_nums is None
+                             or computed_block_nums == [])):
                 raise RuntimeError(
                     "chunked prefill cannot be used with prefix caching "
                     "now.")
@@ -190,13 +253,8 @@ def _prepare_prompt(
             # it contains output tokens.
             prefill_end = min(seq_data.get_len(),
                               computed_len + token_chunk_size)
-            # TODO(sang): Rename it after chunked prefill is introduced.
             prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
-            prompt_len = len(prompt_tokens)
-            # Right now, the prefill_end is always same as the length of
-            # sequence. However, once chunked prefill is introduced, this
-            # assumption can be changed.
-            assert prefill_end == seq_data.get_len()
+            prompt_len = prefill_end
             prompt_lens.append(prompt_len)
 
             # NOTE: This only works for oooooooxxx style attention.
@@ -206,6 +264,14 @@ def _prepare_prompt(
                 computed_len = len(computed_block_nums) * self.block_size
                 prompt_tokens = prompt_tokens[computed_len:]
                 prefix_block_tables.append(computed_block_nums)
+            elif self.scheduler_config.chunked_prefill_enabled:
+                if seq_group_metadata.block_tables is not None:
+                    # Prefill has chunked before.
+                    block_table = seq_group_metadata.block_tables[seq_id]
+                    prefix_block_tables.append(block_table)
+                else:
+                    # The first prefill.
+                    prefix_block_tables.append([])
             else:
                 prefix_block_tables.append([])
                 # Right now, prefill start is always 0. However, this
@@ -267,20 +333,8 @@ def _prepare_prompt(
 
         max_subquery_len = max(subquery_lens)
         max_prompt_len = max(prompt_lens)
-        num_prompt_tokens = len(input_tokens)
         assert max_subquery_len > 0
 
-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)
-        input_positions = torch.tensor(input_positions,
-                                       dtype=torch.long,
-                                       device=self.device)
-        slot_mapping = torch.tensor(slot_mapping,
-                                    dtype=torch.long,
-                                    device=self.device)
-        lora_index_mapping = lora_index_mapping
-
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device=self.device)
@@ -332,11 +386,8 @@ def _prepare_prompt(
 
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
-            slot_mapping=slot_mapping,
             prompt_lens=prompt_lens,
             prompt_lens_tensor=prompt_lens_tensor,
-            num_prompt_tokens=num_prompt_tokens,
-            num_generation_tokens=0,
             max_subquery_len=max_subquery_len,
             max_context_len=None,
             max_prompt_len=max_prompt_len,
@@ -345,18 +396,25 @@ def _prepare_prompt(
             context_lens=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=False,
-            kv_cache_dtype=self.kv_cache_dtype,
         )
-        return (input_tokens, input_positions, attn_metadata, prompt_lens,
-                subquery_lens, lora_index_mapping, lora_prompt_mapping,
-                lora_requests, multi_modal_input)
+
+        return PreparePromptMetadata(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            attn_metadata=attn_metadata,
+            prompt_lens=prompt_lens,
+            subquery_lens=subquery_lens,
+            lora_index_mapping=lora_index_mapping,
+            lora_prompt_mapping=lora_prompt_mapping,
+            lora_requests=lora_requests,
+            multi_modal_input=multi_modal_input,
+            slot_mapping=slot_mapping,
+        )
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
-               List[int], Set[LoRARequest]]:
-        assert len(seq_group_metadata_list) > 0
+    ) -> PrepareDecodeMetadata:
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
@@ -366,6 +424,9 @@ def _prepare_decode(
         lora_prompt_mapping: List[int] = []
         lora_requests: Set[LoRARequest] = set()
 
+        if len(seq_group_metadata_list) == 0:
+            return PrepareDecodeMetadata.empty()
+
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
             assert seq_group_metadata.token_chunk_size == 1
@@ -424,15 +485,6 @@ def _prepare_decode(
                 lora_index_mapping.append(0)
             batch_size = graph_batch_size
 
-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)
-        input_positions = torch.tensor(input_positions,
-                                       dtype=torch.long,
-                                       device=self.device)
-        slot_mapping = torch.tensor(slot_mapping,
-                                    dtype=torch.long,
-                                    device=self.device)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
                                     device=self.device)
@@ -440,9 +492,9 @@ def _prepare_decode(
         if use_captured_graph:
             # When using cuda-graph all these tensors should be
             # padded.
-            assert context_lens.shape[0] == input_tokens.shape[0]
-            assert context_lens.shape[0] == input_positions.shape[0]
-            assert context_lens.shape[0] == slot_mapping.shape[0]
+            assert context_lens.shape[0] == len(input_tokens)
+            assert context_lens.shape[0] == len(input_positions)
+            assert context_lens.shape[0] == len(slot_mapping)
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
@@ -464,11 +516,8 @@ def _prepare_decode(
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
-            slot_mapping=slot_mapping,
             prompt_lens=None,
             prompt_lens_tensor=None,
-            num_prompt_tokens=0,
-            num_generation_tokens=len(input_tokens),
             max_subquery_len=None,
             max_context_len=max_context_len,
             max_prompt_len=None,
@@ -477,10 +526,16 @@ def _prepare_decode(
             context_lens=context_lens,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
-            kv_cache_dtype=self.kv_cache_dtype,
         )
-        return (input_tokens, input_positions, attn_metadata,
-                lora_index_mapping, lora_prompt_mapping, lora_requests)
+        return PrepareDecodeMetadata(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            attn_metadata=attn_metadata,
+            lora_index_mapping=lora_index_mapping,
+            lora_prompt_mapping=lora_prompt_mapping,
+            lora_requests=lora_requests,
+            slot_mapping=slot_mapping,
+        )
 
     def _prepare_sample(
         self,
@@ -586,26 +641,66 @@ def prepare_input_tensors(
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
               SamplingMetadata, Set[int], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
-            # NOTE: We assume that all sequences in the group are all prompts or
-            # all decodes.
-            is_prompt = seq_group_metadata_list[0].is_prompt
+            prefill_reqs = []
+            decode_reqs = []
+            for seq_group_meta in seq_group_metadata_list:
+                if seq_group_meta.is_prompt:
+                    prefill_reqs.append(seq_group_meta)
+                else:
+                    decode_reqs.append(seq_group_meta)
+
             # Prepare input tensors.
-            if is_prompt:
-                (input_tokens, input_positions, attn_metadata, prompt_lens,
-                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
-                 lora_requests, multi_modal_input
-                 ) = self._prepare_prompt(seq_group_metadata_list)
-            else:
-                (input_tokens, input_positions, attn_metadata,
-                 lora_index_mapping, lora_prompt_mapping,
-                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
-                prompt_lens = []
-                subquery_lens = None
-                multi_modal_input = None
+            (
+                input_tokens,
+                input_positions,
+                prefill_attn_metadata,
+                prompt_lens,
+                subquery_lens,
+                lora_index_mapping,
+                lora_prompt_mapping,
+                lora_requests,
+                multi_modal_input,
+                slot_mapping,
+            ) = self._prepare_prompt(prefill_reqs)
+            (
+                decode_input_tokens,
+                decode_input_positions,
+                decode_attn_metadata,
+                decode_lora_index_mapping,
+                decode_lora_prompt_mapping,
+                decode_lora_requests,
+                decode_slot_mapping,
+            ) = self._prepare_decode(decode_reqs)
             sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens)
 
+            if not self.scheduler_config.chunked_prefill_enabled:
+                assert (len(prefill_reqs) and len(decode_reqs)) == 0
+
+            num_prefills = len(prompt_lens)
+            num_prefill_tokens = len(input_tokens)
+            num_decode_tokens = len(decode_input_tokens)
+
+            # Coalesce tensors. Note that attn_metadata is currently not
+            # coalesced for simplicity.
+            input_tokens.extend(decode_input_tokens)
+            input_positions.extend(decode_input_positions)
+            slot_mapping.extend(decode_slot_mapping)
+            lora_index_mapping.extend(decode_lora_index_mapping)
+            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
+            lora_requests.update(decode_lora_requests)
+
+            input_tokens = torch.tensor(input_tokens,
+                                        dtype=torch.long,
+                                        device=self.device)
+            input_positions = torch.tensor(input_positions,
+                                           dtype=torch.long,
+                                           device=self.device)
+            slot_mapping = torch.tensor(slot_mapping,
+                                        dtype=torch.long,
+                                        device=self.device)
+
             if self.lora_config:
                 lora_mapping = LoRAMapping(
                     lora_index_mapping,
@@ -615,6 +710,16 @@ def prepare_input_tensors(
                 lora_mapping = None
 
             # Broadcast the metadata.
+            # If batch contains both prefill and decode, it sends 2 broadcasts.
+            # If it only contains 1 type, it triggers a single broadcast.
+            if (prefill_attn_metadata is not None
+                    and decode_attn_metadata is not None):
+                batch_type = BatchType.MIXED
+            elif prefill_attn_metadata is not None:
+                batch_type = BatchType.PREFILL
+            else:
+                batch_type = BatchType.DECODE
+
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
@@ -623,19 +728,49 @@ def prepare_input_tensors(
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
                 "multi_modal_input": multi_modal_input,
+                "num_prefill_tokens": num_prefill_tokens,
+                "num_decode_tokens": num_decode_tokens,
+                "slot_mapping": slot_mapping,
+                "num_prefills": num_prefills,
+                "batch_type": batch_type,
             }
-            metadata_dict.update(attn_metadata.asdict_zerocopy())
+            if prefill_attn_metadata is not None:
+                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
+            else:
+                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
             broadcast_tensor_dict(metadata_dict, src=0)
+
+            # Broadcast decode attn metadata for mixed batch type.
+            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
+            # We can potentially reduce the overhead by coalescing tensors.
+            if batch_type == BatchType.MIXED:
+                assert decode_attn_metadata is not None
+                metadata_dict = decode_attn_metadata.asdict_zerocopy()
+                broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
             input_tokens = metadata_dict.pop("input_tokens")
             input_positions = metadata_dict.pop("input_positions")
+            slot_mapping = metadata_dict.pop("slot_mapping")
+            num_prefills = metadata_dict.pop("num_prefills")
             selected_token_indices = metadata_dict.pop(
                 "selected_token_indices")
             lora_mapping = metadata_dict.pop("lora_mapping")
             lora_requests = metadata_dict.pop("lora_requests")
             multi_modal_input = metadata_dict.pop("multi_modal_input")
-            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
+            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
+            batch_type = metadata_dict.pop("batch_type")
+
+            # Create an attention metadata.
+            prefill_attn_metadata = None
+            decode_attn_metadata = None
+            if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
+                prefill_attn_metadata = self.attn_backend.make_metadata(
+                    **metadata_dict)
+            else:
+                decode_attn_metadata = self.attn_backend.make_metadata(
+                    **metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
@@ -646,6 +781,23 @@ def prepare_input_tensors(
                 perform_sampling=False,
             )
 
+            # if it is a mixed batch, decode attn_metadata is broadcasted
+            # separately.
+            if batch_type == BatchType.MIXED:
+                metadata_dict = broadcast_tensor_dict(src=0)
+                decode_attn_metadata = self.attn_backend.make_metadata(
+                    **metadata_dict)
+
+        attn_metadata = AttentionMetadata(
+            num_prefills=num_prefills,
+            slot_mapping=slot_mapping,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            prefill_metadata=prefill_attn_metadata,
+            decode_metadata=decode_attn_metadata,
+            kv_cache_dtype=self.kv_cache_dtype,
+        )
+
         return (input_tokens, input_positions, attn_metadata,
                 sampling_metadata, lora_requests, lora_mapping,
                 multi_modal_input)
@@ -663,8 +815,10 @@ def execute_model(
         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
 
-        # Execute the model.
-        if attn_metadata.use_cuda_graph:
+        # Currently cuda graph is only supported by the decode phase.
+        prefill_meta = attn_metadata.prefill_metadata
+        decode_meta = attn_metadata.decode_metadata
+        if prefill_meta is None and decode_meta.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
             model_executable = self.graph_runners[graph_batch_size]
         else:
@@ -842,13 +996,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
         # memory usage of CUDA graph.
         for batch_size in reversed(batch_size_capture_list):
             # Create dummy attn_metadata.
-            attn_metadata = self.attn_backend.make_metadata(
+            decode_metadata = self.attn_backend.make_metadata(
                 is_prompt=False,
-                slot_mapping=slot_mapping[:batch_size],
                 prompt_lens=None,
                 prompt_lens_tensor=None,
-                num_prompt_tokens=0,
-                num_generation_tokens=batch_size,
                 max_subquery_len=None,
                 max_context_len=self.max_context_len_to_capture,
                 max_prompt_len=None,
@@ -857,6 +1008,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
                 context_lens=context_lens[:batch_size],
                 block_tables=block_tables[:batch_size],
                 use_cuda_graph=True,
+            )
+            attn_metadata = AttentionMetadata(
+                num_prefills=0,
+                num_prefill_tokens=0,
+                num_decode_tokens=batch_size,
+                slot_mapping=slot_mapping[:batch_size],
+                prefill_metadata=None,
+                decode_metadata=decode_metadata,
                 kv_cache_dtype=self.kv_cache_dtype,
             )
 
@@ -950,8 +1109,8 @@ def capture(
             "positions": positions,
            "kv_caches": kv_caches,
             "slot_mapping": attn_metadata.slot_mapping,
-            "context_lens": attn_metadata.context_lens,
-            "block_tables": attn_metadata.block_tables,
+            "context_lens": attn_metadata.decode_metadata.context_lens,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
         }
         self.output_buffers = {"hidden_states": hidden_states}
         return
@@ -972,10 +1131,10 @@ def forward(
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
         self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                  non_blocking=True)
-        self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
-                                                 non_blocking=True)
-        self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
-                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(
+            attn_metadata.decode_metadata.context_lens, non_blocking=True)
+        self.input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
 
         # Run the graph.
         self.graph.replay()