90 changes: 89 additions & 1 deletion tests/v1/worker/test_gpu_input_batch.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Optional

import numpy as np
@@ -9,7 +10,8 @@
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
+                                            InputBatch)

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
@@ -20,6 +22,34 @@
MAX_NUM_PROMPT_TOKENS = 64


def _compare_objs(obj1, obj2):
    """Assert that the public, non-callable attributes of two objects match."""
    attrs = inspect.getmembers(obj1, lambda a: not inspect.isroutine(a))
    attr_names = {
        a[0]
        for a in attrs if not (a[0].startswith('__') and a[0].endswith('__'))
    }
for attr_name in attr_names:
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)

is_same = False
if isinstance(a, torch.Tensor):
if (a.numel() == 0 or b.numel() == 0):
is_same = (a.numel() == 0 and b.numel() == 0)
elif torch.allclose(a, b):
is_same = True
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
_compare_objs(a, b)
            is_same = True  # the recursive call above asserts equality itself
elif a == b:
is_same = True
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"


def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
@@ -254,3 +284,61 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
swap_list: list):
"""
Tests the logic for managing sampling metadata in the InputBatch.

This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)

reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids

reordered_reqs = reqs.copy()
for swap_pair in swap_list:
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
input_batch.swap_states(swap_pair[0], swap_pair[1])

for req_index in range(batch_size):
req = reordered_reqs[req_index]
ref_input_batch.add_request(req, req_index)

input_batch.refresh_sampling_metadata()
ref_input_batch.refresh_sampling_metadata()

_compare_objs(input_batch, ref_input_batch)
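
For readers unfamiliar with the pattern this test exercises, here is a minimal, self-contained sketch of why `swap_states` must swap every per-request row consistently. `ToyBatch` is an assumed stand-in for illustration, not vLLM's actual `InputBatch`:

# Toy illustration (assumed stand-in for vLLM's InputBatch): swapping two
# requests must permute every per-request array the same way, so the
# result matches a batch built in the swapped order from the start.
import numpy as np

class ToyBatch:
    def __init__(self, max_num_reqs: int, max_model_len: int):
        self.req_ids: list = [None] * max_num_reqs
        self.token_ids = np.zeros((max_num_reqs, max_model_len), dtype=np.int64)
        self.temperature = np.ones(max_num_reqs, dtype=np.float32)

    def add_request(self, req_id: str, tokens: list, temp: float, idx: int) -> None:
        self.req_ids[idx] = req_id
        self.token_ids[idx, :len(tokens)] = tokens
        self.temperature[idx] = temp

    def swap_states(self, i: int, j: int) -> None:
        # Every per-request field participates in the swap.
        self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]
        self.token_ids[[i, j]] = self.token_ids[[j, i]]
        self.temperature[[i, j]] = self.temperature[[j, i]]

batch, ref = ToyBatch(2, 8), ToyBatch(2, 8)
batch.add_request("a", [1, 2], 0.5, 0)
batch.add_request("b", [3], 1.0, 1)
batch.swap_states(0, 1)
ref.add_request("b", [3], 1.0, 0)  # reference built in swapped order
ref.add_request("a", [1, 2], 0.5, 1)
assert batch.req_ids == ref.req_ids
assert np.array_equal(batch.token_ids, ref.token_ids)
assert np.array_equal(batch.temperature, ref.temperature)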
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -100,8 +100,8 @@ def __init__(self, runner: "GPUModelRunner"):
self.runner = runner

 def reorder_batch(self, input_batch: "InputBatch",
-                  scheduler_output: "SchedulerOutput"):
-    pass
+                  scheduler_output: "SchedulerOutput") -> bool:
+    return False

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
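
The change above gives `reorder_batch` a boolean return value instead of an implicit `None`; the FlashAttention backend performs no reordering, so it now reports `False`. As a hedged sketch of how a backend that does reorder might use the new return value, here is one plausible policy. The `is_decode` heuristic and the exact `InputBatch`/`SchedulerOutput` attribute shapes are assumptions for illustration, not vLLM's actual implementation:

# Hypothetical backend sketch: partition decode requests ahead of prefill
# requests, returning True only if the batch order actually changed.
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:

    def is_decode(idx: int) -> bool:
        # Assumption: a request scheduled for exactly one token is a decode.
        req_id = input_batch.req_ids[idx]
        return scheduler_output.num_scheduled_tokens[req_id] == 1

    modified = False
    left, right = 0, input_batch.num_reqs - 1
    # Two-pointer partition: walk inward, swapping a prefill on the left
    # with a decode on the right until the pointers meet.
    while left < right:
        if is_decode(left):
            left += 1
        elif not is_decode(right):
            right -= 1
        else:
            input_batch.swap_states(left, right)
            modified = True
            left += 1
            right -= 1
    return modified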