25 changes: 24 additions & 1 deletion tests/test_utils.py
@@ -14,7 +14,7 @@
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
PlaceholderModule, StoreBoolean, bind_kv_cache,
deprecate_kwargs, get_open_port, memory_profiling,
merge_async_iterators, supports_kw)
merge_async_iterators, supports_kw, swap_dict_values)

from .utils import error_on_warning, fork_new_process_for_each_test

@@ -449,3 +449,26 @@ def build_ctx():
with build_ctx():
# Test conflict with internal __module attribute
_ = placeholder_attr.module


@pytest.mark.parametrize(
"obj,key1,key2",
[
# Both keys exist
({1: "a", 2: "b"}, 1, 2),
# Only one key exists
({1: "a", 2: "b"}, 1, 3),
# Neither key exists
({1: "a", 2: "b"}, 3, 4),
])
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
if key1 in original_obj:
assert obj[key2] == original_obj[key1]
else:
assert key2 not in obj
if key2 in original_obj:
assert obj[key1] == original_obj[key2]
else:
assert key1 not in obj
1 change: 1 addition & 0 deletions tests/v1/sample/test_rejection_sampler.py
@@ -42,6 +42,7 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
bad_words_token_ids={},
)


76 changes: 76 additions & 0 deletions tests/v1/sample/test_sampler.py
@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
return mask


def _create_bad_words_token_ids(
batch_size: int, vocab_size: int,
bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
bad_words_token_ids = {}
for batch_idx in range(batch_size):
token_ids_single_batch = []
for bad_words_length in bad_words_lengths:
token_ids = np.random.choice(vocab_size,
size=bad_words_length,
replace=True).tolist()
token_ids_single_batch.append(token_ids)
bad_words_token_ids[batch_idx] = token_ids_single_batch
if batch_size >= 2:
# Leave one randomly chosen request without bad_words
no_bad_words_batch_idx = np.random.choice(batch_size)
bad_words_token_ids.pop(no_bad_words_batch_idx, None)
return bad_words_token_ids


def _update_output_token_ids_for_bad_words(
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
bad_words_last_tokens = {}
for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
output_token_ids = metadata.output_token_ids[batch_idx]
bad_words_last_token: list[int] = []
for i, bad_word_token_ids in enumerate(bad_words_token_ids):
if len(bad_word_token_ids) == 1:
# Single token id always affects logits
bad_words_last_token.append(bad_word_token_ids[0])
else:
prefix_length = len(bad_word_token_ids) - 1
has_bad_words = np.random.choice([True, False])
if has_bad_words:
output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
bad_words_last_token.append(bad_word_token_ids[-1])
break # Maximum one update to output_token_ids
else: # Make sure no accidental match to bad words
output_token_ids[-1] = (bad_word_token_ids[-2] +
1) % vocab_size
bad_words_last_tokens[batch_idx] = bad_words_last_token
return bad_words_last_tokens


def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
return fake_sampling_metadata

@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
"inf"), f"{batch_idx}, {token_id}"
else:
assert logits_for_req[token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
bad_words_lengths: list[tuple[int]]):
"""
Test to verify that when the bad words restriction is present, tokens
are penalized based on their match with the bad words.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
batch_size, VOCAB_SIZE, bad_words_lengths)
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
sampling_metadata, VOCAB_SIZE)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
for token_id in range(VOCAB_SIZE):
if (batch_idx in bad_words_last_tokens
and token_id in bad_words_last_tokens[batch_idx]):
assert logits_for_req[token_id] == -float("inf")
else:
assert logits_for_req[token_id] != -float("inf")
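
To make the randomized setup in _update_output_token_ids_for_bad_words concrete, here is a small hand-worked sketch; the token values are chosen purely for illustration and do not come from the test:

output_token_ids = [11, 12, 13, 14]
bad_word = [5, 7, 9]
# "match" branch: write the prefix into the tail of the history, so only the
# final token of the bad word is still pending and must be banned.
output_token_ids[-2:] = bad_word[:-1]   # history becomes [11, 12, 5, 7]
pending_ban = bad_word[-1]              # the sampler must set logits[9] to -inf
# "no match" branch (the alternative): perturb the last history token so the
# prefix cannot match and token 9 must keep a finite logit, e.g.
# output_token_ids[-1] = (bad_word[-2] + 1) % vocab_size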
18 changes: 16 additions & 2 deletions tests/v1/sample/test_sampling_params_e2e.py
@@ -120,8 +120,22 @@ def test_detokenize_false(model):
def test_bad_words(model):
"""Check that we respect bad words."""

with pytest.raises(ValueError):
_ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
output = model.generate(PROMPT, SamplingParams(temperature=0))
split_text = output[0].outputs[0].text.split()

bad_words_1 = " ".join(split_text[:2])
params = SamplingParams(temperature=0, bad_words=[bad_words_1])
output = model.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text

bad_words_2 = new_text.split()[-1]
params = SamplingParams(temperature=0,
bad_words=[bad_words_1, bad_words_2])
output = model.generate(PROMPT, params)
new_text = output[0].outputs[0].text
assert bad_words_1 not in new_text
assert bad_words_2 not in new_text


def test_logits_processor(model):
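
The end-to-end test_bad_words above drives the feature through the public SamplingParams API. A minimal standalone usage sketch (the model name and prompt are placeholders, not taken from the test):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")          # placeholder model for illustration
params = SamplingParams(temperature=0, bad_words=["Hello", "world"])
outputs = llm.generate("The capital of France is", params)
# Neither banned word (nor its mid-sentence variant) should appear in the output.
print(outputs[0].outputs[0].text)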
6 changes: 6 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
VOCAB_SIZE,
dtype=torch.bool,
device=device)
bad_words_token_ids = {}
for req in reqs:
if req.req_id not in req_ids_retained:
continue
@@ -123,6 +124,8 @@
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids

return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
)


@@ -284,6 +288,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
assert expected_sampling_metadata.bad_words_token_ids == \
sampling_metadata.bad_words_token_ids


@pytest.mark.parametrize("device", CUDA_DEVICES)
52 changes: 51 additions & 1 deletion vllm/sampling_params.py
@@ -11,6 +11,8 @@

from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

logger = init_logger(__name__)

@@ -202,7 +204,6 @@ class SamplingParams(
seed: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
stop_token_ids: Optional[list[int]] = None
bad_words: Optional[list[str]] = None
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
@@ -232,6 +233,10 @@ class SamplingParams(
allowed_token_ids: Optional[list[int]] = None
extra_args: Optional[dict[str, Any]] = None

# Fields used for bad words
bad_words: Optional[list[str]] = None
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)

@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -464,6 +469,46 @@ def update_from_generation_config(
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)

def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
if self.bad_words is None:
return
for bad_word in self.bad_words:
# Prohibit the word both at the beginning and in the middle of the
# text (related to the add_prefix_space tokenizer parameter).
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()

if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(
text=prompt, add_special_tokens=False)

# Always keep the no-prefix variant; keep the prefix-space variant only
# if it changes the first token without changing the token count.
if (not add_prefix_space) or (
add_prefix_space and prompt_token_ids[0]
!= self._bad_words_token_ids[-1][0]
and len(prompt_token_ids) == len(
self._bad_words_token_ids[-1])):
self._bad_words_token_ids.append(prompt_token_ids)

invalid_token_ids = [
token_id for bad_words_token_ids in self._bad_words_token_ids
for token_id in bad_words_token_ids
if token_id < 0 or token_id > tokenizer.max_token_id
]
if len(invalid_token_ids) > 0:
raise ValueError(
f"The model vocabulary size is {tokenizer.max_token_id+1},"
f" but the following tokens"
f" were specified as bad: {invalid_token_ids}."
f" All token id values should be integers satisfying:"
f" 0 <= token_id <= {tokenizer.max_token_id}.")

@cached_property
def sampling_type(self) -> SamplingType:
if self.temperature < _SAMPLING_EPS:
@@ -476,6 +521,11 @@ def sampling_type(self) -> SamplingType:
def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids

@property
def bad_words_token_ids(self) -> list[list[int]]:
# For internal use only. Backward compatibility not guaranteed
return self._bad_words_token_ids

def clone(self) -> "SamplingParams":
"""Deep copy, but maybe not the LogitsProcessor objects.

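
For context, a short sketch of why both prefix-space variants are encoded, using a Hugging Face tokenizer purely for illustration; the token ids shown are examples, not values asserted by this change:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")              # illustrative tokenizer
print(tok.encode("Hello", add_special_tokens=False))     # e.g. [15496] - word at start of text
print(tok.encode(" Hello", add_special_tokens=False))    # e.g. [18435] - word after a space
# Because the two encodings differ, both token sequences are recorded so the
# bad word is banned whether it starts the text or appears mid-sentence.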
16 changes: 16 additions & 0 deletions vllm/utils.py
@@ -2361,3 +2361,19 @@ def __dir__(self) -> list[str]:
if self._module is None:
self._module = self._load()
return dir(self._module)


def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
"""
Helper function to swap values for two keys
"""
v1 = obj.get(key1)
v2 = obj.get(key2)
if v1 is not None:
obj[key2] = v1
else:
obj.pop(key2, None)
if v2 is not None:
obj[key1] = v2
else:
obj.pop(key1, None)
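
A quick usage sketch (values are illustrative); note that because the helper relies on dict.get, a key explicitly mapped to None behaves the same as a missing key:

d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 2)
assert d == {1: "b", 2: "a"}        # both keys present: plain swap

d = {1: "a"}
swap_dict_values(d, 1, 3)
assert d == {3: "a"}                # missing key: the value moves, the old key is dropped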
5 changes: 2 additions & 3 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ def _validate_supported_sampling_params(
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("VLLM V1 does not yet support best_of.")
# Bad words not yet supported.
if params.bad_words:
raise ValueError("VLLM V1 does not yet support bad_words.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError("VLLM V1 does not support per request "
@@ -203,6 +200,8 @@ def process_inputs(
sampling_params = params.clone()
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))

# Multimodal related.
# Compute MM hashes (if enabled)
3 changes: 3 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -38,3 +38,6 @@ class SamplingMetadata:
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]

# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]
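
The new field maps a request's index in the batch to the token-id sequences of its bad words. An illustrative value (indices and ids invented for the example):

# Request 0 bans the single token 42 and the two-token sequence [7, 13];
# request 1 bans one three-token sequence; requests without bad words are omitted.
bad_words_token_ids = {
    0: [[42], [7, 13]],
    1: [[99, 100, 101]],
}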
38 changes: 38 additions & 0 deletions vllm/v1/sample/ops/bad_words.py
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0

import torch

_SMALLEST_LOGIT = float("-inf")


def _apply_bad_words_single_batch(
logits: torch.Tensor,
bad_words_token_ids: list[list[int]],
past_tokens_ids: list[int],
) -> None:
for bad_word_ids in bad_words_token_ids:
if len(bad_word_ids) > len(past_tokens_ids) + 1:
continue

prefix_length = len(bad_word_ids) - 1
last_token_id = bad_word_ids[-1]
if prefix_length > 0:
actual_prefix = past_tokens_ids[-prefix_length:]
else:
actual_prefix = []
expected_prefix = bad_word_ids[:prefix_length]

assert len(actual_prefix) == len(expected_prefix)

if actual_prefix == expected_prefix:
logits[last_token_id] = _SMALLEST_LOGIT


def apply_bad_words(
logits: torch.Tensor,
bad_words_token_ids: dict[int, list[list[int]]],
past_tokens_ids: list[list[int]],
) -> None:
for i, bad_words_ids in bad_words_token_ids.items():
_apply_bad_words_single_batch(logits[i], bad_words_ids,
past_tokens_ids[i])
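
A minimal sketch of how the new helper is meant to be driven (tensor sizes and token ids are invented for the example):

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

logits = torch.zeros(2, 10)                      # 2 requests, vocab of 10
bad_words = {0: [[3], [5, 7]], 1: [[5, 7]]}      # per-request bad word sequences
past = [[1, 5], [1, 2]]                          # request 0 just produced token 5
apply_bad_words(logits, bad_words, past)
# Request 0: token 3 (single-token bad word) and token 7 (prefix [5] matched)
# are now -inf; request 1's prefix does not match, so nothing is banned there.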