diff --git a/tests/test_utils.py b/tests/test_utils.py index 8b67e92fca68..49fb02fd0403 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,7 +14,7 @@ from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, memory_profiling, - merge_async_iterators, supports_kw) + merge_async_iterators, supports_kw, swap_dict_values) from .utils import error_on_warning, fork_new_process_for_each_test @@ -449,3 +449,26 @@ def build_ctx(): with build_ctx(): # Test conflict with internal __module attribute _ = placeholder_attr.module + + +@pytest.mark.parametrize( + "obj,key1,key2", + [ + # Tests for both keys exist + ({1: "a", 2: "b"}, 1, 2), + # Tests for one key does not exist + ({1: "a", 2: "b"}, 1, 3), + # Tests for both keys do not exist + ({1: "a", 2: "b"}, 3, 4), + ]) +def test_swap_dict_values(obj, key1, key2): + original_obj = obj.copy() + swap_dict_values(obj, key1, key2) + if key1 in original_obj: + assert obj[key2] == original_obj[key1] + else: + assert key2 not in obj + if key2 in original_obj: + assert obj[key1] == original_obj[key2] + else: + assert key1 not in obj diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index b1862455d0ec..190927745f1f 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -42,6 +42,7 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata: min_tokens={}, logit_bias=[None] * batch_size, allowed_token_ids_mask=None, + bad_words_token_ids={}, ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index b702d9ed7f83..5f041b448937 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -77,6 +77,49 @@ def _create_allowed_token_ids( return mask +def _create_bad_words_token_ids( + batch_size: int, vocab_size: int, + bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]: + bad_words_token_ids = {} + for batch_idx in range(batch_size): + token_ids_single_batch = [] + for bad_words_length in bad_words_lengths: + token_ids = np.random.choice(vocab_size, + size=bad_words_length, + replace=True).tolist() + token_ids_single_batch.append(token_ids) + bad_words_token_ids[batch_idx] = token_ids_single_batch + if batch_size >= 2: + # Test no bad_words for some batch + no_bad_words_batch_idx = np.random.choice(batch_size) + bad_words_token_ids.pop(no_bad_words_batch_idx, None) + return bad_words_token_ids + + +def _update_output_token_ids_for_bad_words( + metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]: + bad_words_last_tokens = {} + for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items(): + output_token_ids = metadata.output_token_ids[batch_idx] + bad_words_last_token: list[int] = [] + for i, bad_word_token_ids in enumerate(bad_words_token_ids): + if len(bad_word_token_ids) == 1: + # Single token id always affects logits + bad_words_last_token.append(bad_word_token_ids[0]) + else: + prefix_length = len(bad_word_token_ids) - 1 + has_bad_words = np.random.choice([True, False]) + if has_bad_words: + output_token_ids[-prefix_length:] = bad_word_token_ids[:-1] + bad_words_last_token.append(bad_word_token_ids[-1]) + break # Maximum one update to output_token_ids + else: # Make sure no accidental match to bad words + output_token_ids[-1] = (bad_word_token_ids[-2] + + 1) % vocab_size + bad_words_last_tokens[batch_idx] = bad_words_last_token + return 
bad_words_last_tokens + + def _create_default_sampling_metadata( num_output_tokens: int, batch_size: int, @@ -112,6 +155,7 @@ def _create_default_sampling_metadata( min_tokens={}, logit_bias=[None] * batch_size, allowed_token_ids_mask=None, + bad_words_token_ids={}, ) return fake_sampling_metadata @@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int, "inf"), f"{batch_idx}, {token_id}" else: assert logits_for_req[token_id] != -float("inf") + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)]) +def test_sampler_bad_words(device: str, batch_size: int, + bad_words_lengths: list[tuple[int]]): + """ + Test to verify that when the bad words restriction is present, tokens + are penalized based on their match with the bad words. + """ + torch.set_default_device(device) + # Create fake logits where each token is assigned the same + # logit value. + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids( + batch_size, VOCAB_SIZE, bad_words_lengths) + bad_words_last_tokens = _update_output_token_ids_for_bad_words( + sampling_metadata, VOCAB_SIZE) + sampler = Sampler() + logits = sampler.apply_bad_words(fake_logits, sampling_metadata) + logits = logits.cpu() + for batch_idx in range(batch_size): + logits_for_req = logits[batch_idx] + for token_id in range(VOCAB_SIZE): + if (batch_idx in bad_words_last_tokens + and token_id in bad_words_last_tokens[batch_idx]): + assert logits_for_req[token_id] == -float("inf") + else: + assert logits_for_req[token_id] != -float("inf") diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index dcb0fa20b1a0..0512a1e02660 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -120,8 +120,22 @@ def test_detokenize_false(model): def test_bad_words(model): """Check that we respect bad words.""" - with pytest.raises(ValueError): - _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"])) + output = model.generate(PROMPT, SamplingParams(temperature=0)) + split_text = output[0].outputs[0].text.split() + + bad_words_1 = " ".join(split_text[:2]) + params = SamplingParams(temperature=0, bad_words=[bad_words_1]) + output = model.generate(PROMPT, params) + new_text = output[0].outputs[0].text + assert bad_words_1 not in new_text + + bad_words_2 = new_text.split()[-1] + params = SamplingParams(temperature=0, + bad_words=[bad_words_1, bad_words_2]) + output = model.generate(PROMPT, params) + new_text = output[0].outputs[0].text + assert bad_words_1 not in new_text + assert bad_words_2 not in new_text def test_logits_processor(model): diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 5f0cb1d3d3b3..192ddefe102d 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata( VOCAB_SIZE, dtype=torch.bool, device=device) + bad_words_token_ids = {} for req in reqs: if req.req_id not in req_ids_retained: continue @@ -123,6 +124,8 @@ def _construct_expected_sampling_metadata( if req.sampling_params.allowed_token_ids: allowed_token_ids_mask[index_in_input_batch][ 
req.sampling_params.allowed_token_ids] = True + bad_words_token_ids[ + index_in_input_batch] = req.sampling_params.bad_words_token_ids return SamplingMetadata( temperature=torch.tensor(temperature, dtype=torch.float, @@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata( and all(x == 1 for x in repetition_penalties)), logit_bias=logit_bias, allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids=bad_words_token_ids, ) @@ -284,6 +288,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: assert torch.allclose( expected_sampling_metadata.allowed_token_ids_mask, sampling_metadata.allowed_token_ids_mask) + assert expected_sampling_metadata.bad_words_token_ids == \ + sampling_metadata.bad_words_token_ids @pytest.mark.parametrize("device", CUDA_DEVICES) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index ca577a6721fe..110efa229822 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -11,6 +11,8 @@ from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) @@ -202,7 +204,6 @@ class SamplingParams( seed: Optional[int] = None stop: Optional[Union[str, list[str]]] = None stop_token_ids: Optional[list[int]] = None - bad_words: Optional[list[str]] = None ignore_eos: bool = False max_tokens: Optional[int] = 16 min_tokens: int = 0 @@ -232,6 +233,10 @@ class SamplingParams( allowed_token_ids: Optional[list[int]] = None extra_args: Optional[dict[str, Any]] = None + # Fields used for bad words + bad_words: Optional[list[str]] = None + _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list) + @staticmethod def from_optional( n: Optional[int] = 1, @@ -464,6 +469,46 @@ def update_from_generation_config( eos_ids.update(self.stop_token_ids) self.stop_token_ids = list(eos_ids) + def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: + if self.bad_words is None: + return + for bad_word in self.bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + + if isinstance(tokenizer, MistralTokenizer): + # Mistral tokenizers should not add special tokens + prompt_token_ids = tokenizer.encode(text=prompt) + else: + prompt_token_ids = tokenizer.encode( + text=prompt, add_special_tokens=False) + + # If no space at the beginning + # or if prefix space produces a new word token + if (not add_prefix_space) or ( + add_prefix_space and prompt_token_ids[0] + != self._bad_words_token_ids[-1][0] + and len(prompt_token_ids) == len( + self._bad_words_token_ids[-1])): + self._bad_words_token_ids.append(prompt_token_ids) + + invalid_token_ids = [ + token_id for bad_words_token_ids in self._bad_words_token_ids + for token_id in bad_words_token_ids + if token_id < 0 or token_id > tokenizer.max_token_id + ] + if len(invalid_token_ids) > 0: + raise ValueError( + f"The model vocabulary size is {tokenizer.max_token_id+1}," + f" but the following tokens" + f" were specified as bad: {invalid_token_ids}." 
+ f" All token id values should be integers satisfying:" + f" 0 <= token_id <= {tokenizer.max_token_id}.") + @cached_property def sampling_type(self) -> SamplingType: if self.temperature < _SAMPLING_EPS: @@ -476,6 +521,11 @@ def sampling_type(self) -> SamplingType: def all_stop_token_ids(self) -> set[int]: return self._all_stop_token_ids + @property + def bad_words_token_ids(self) -> list[list[int]]: + # For internal use only. Backward compatibility not guaranteed + return self._bad_words_token_ids + def clone(self) -> "SamplingParams": """Deep copy, but maybe not the LogitsProcessor objects. diff --git a/vllm/utils.py b/vllm/utils.py index 883a9e504065..9cad2b8854a2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2361,3 +2361,19 @@ def __dir__(self) -> list[str]: if self._module is None: self._module = self._load() return dir(self._module) + + +def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: + """ + Helper function to swap values for two keys + """ + v1 = obj.get(key1) + v2 = obj.get(key2) + if v1 is not None: + obj[key2] = v1 + else: + obj.pop(key2, None) + if v2 is not None: + obj[key1] = v2 + else: + obj.pop(key1, None) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 247fb046e81a..38638c1ee361 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -94,9 +94,6 @@ def _validate_supported_sampling_params( # Best of not yet supported. if params.best_of is not None and params.best_of > 1: raise ValueError("VLLM V1 does not yet support best_of.") - # Bad words not yet supported. - if params.bad_words: - raise ValueError("VLLM V1 does not yet support bad_words.") # Logits processors not supported. if params.logits_processors: raise ValueError("VLLM V1 does not support per request " @@ -203,6 +200,8 @@ def process_inputs( sampling_params = params.clone() sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) + sampling_params.update_from_tokenizer( + self.tokenizer.get_lora_tokenizer(lora_request)) # Multimodal related. # Compute MM hashes (if enabled) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 55d9739b8007..e97e1235fb36 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -38,3 +38,6 @@ class SamplingMetadata: # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # vocab size). 
allowed_token_ids_mask: Optional[torch.Tensor] + + # req_index -> bad_words_token_ids + bad_words_token_ids: dict[int, list[list[int]]] diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py new file mode 100644 index 000000000000..2984d4e4806f --- /dev/null +++ b/vllm/v1/sample/ops/bad_words.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +_SMALLEST_LOGIT = float("-inf") + + +def _apply_bad_words_single_batch( + logits: torch.Tensor, + bad_words_token_ids: list[list[int]], + past_tokens_ids: list[int], +) -> None: + for bad_word_ids in bad_words_token_ids: + if len(bad_word_ids) > len(past_tokens_ids) + 1: + continue + + prefix_length = len(bad_word_ids) - 1 + last_token_id = bad_word_ids[-1] + if prefix_length > 0: + actual_prefix = past_tokens_ids[-prefix_length:] + else: + actual_prefix = [] + expected_prefix = bad_word_ids[:prefix_length] + + assert len(actual_prefix) == len(expected_prefix) + + if actual_prefix == expected_prefix: + logits[last_token_id] = _SMALLEST_LOGIT + + +def apply_bad_words( + logits: torch.Tensor, + bad_words_token_ids: dict[int, list[list[int]]], + past_tokens_ids: list[list[int]], +) -> None: + for i, bad_words_ids in bad_words_token_ids.items(): + _apply_bad_words_single_batch(logits[i], bad_words_ids, + past_tokens_ids[i]) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index b0eb533ae2e5..96f6d807b10c 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -6,6 +6,7 @@ from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.bad_words import apply_bad_words from vllm.v1.sample.ops.penalties import (apply_all_penalties, apply_min_token_penalties) from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler @@ -38,6 +39,8 @@ def forward( logits = logits.to(torch.float32) # Apply allowed token ids. logits = self.apply_allowed_token_ids(logits, sampling_metadata) + # Apply bad words exclusion. + logits = self.apply_bad_words(logits, sampling_metadata) # Apply logits bias. logits = self.apply_logits_bias(logits, sampling_metadata) # Apply penalties (e.g., min_tokens, freq_penalties). 
@@ -237,3 +240,16 @@ def apply_allowed_token_ids( logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) return logits + + def apply_bad_words( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + if sampling_metadata.bad_words_token_ids: + apply_bad_words( + logits, + sampling_metadata.bad_words_token_ids, + sampling_metadata.output_token_ids, + ) + return logits diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 6239a182e31e..9707cb5774cd 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,6 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import swap_dict_values from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import BlockTable @@ -204,6 +205,9 @@ def __init__( self.allowed_token_ids_mask: Optional[torch.Tensor] = None self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + # req_index -> bad_words_token_ids + self.bad_words_token_ids: dict[int, list[list[int]]] = {} + self.req_output_token_ids: list[Optional[list[int]]] = [] # This is updated each time the batch constituents change. @@ -320,6 +324,9 @@ def add_request( self.allowed_token_ids_mask_cpu_tensor[req_index][ sampling_params.allowed_token_ids] = False + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -369,6 +376,7 @@ def remove_request(self, req_id: str) -> Optional[int]: if self.allowed_token_ids_mask_cpu_tensor is not None: # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) + self.bad_words_token_ids.pop(req_index, None) return req_index def swap_states(self, i1: int, i2: int) -> None: @@ -413,27 +421,9 @@ def swap_states(self, i1: int, i2: int) -> None: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp - g1 = self.generators.get(i1) - g2 = self.generators.get(i2) - if g1 is not None: - self.generators[i2] = g1 - else: - self.generators.pop(i2, None) - if g2 is not None: - self.generators[i1] = g2 - else: - self.generators.pop(i1, None) - - t1 = self.min_tokens.get(i1) - t2 = self.min_tokens.get(i2) - if t1 is not None: - self.min_tokens[i2] = t1 - else: - self.min_tokens.pop(i2, None) - if t2 is not None: - self.min_tokens[i1] = t2 - else: - self.min_tokens.pop(i1, None) + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.min_tokens, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ self.request_lora_mapping[i2], self.request_lora_mapping[i1] @@ -518,6 +508,10 @@ def condense(self, empty_req_indices: list[int]) -> None: empty_index] = self.allowed_token_ids_mask_cpu_tensor[ last_req_index] + bad_words_token_ids = self.bad_words_token_ids.pop( + last_req_index, None) + if bad_words_token_ids is not None: + self.bad_words_token_ids[empty_index] = bad_words_token_ids # Decrement last_req_index since it is now empty. 
last_req_index -= 1 @@ -585,6 +579,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: no_penalties=self.no_penalties, logit_bias=self.logit_bias[:num_reqs], allowed_token_ids_mask=allowed_token_ids_mask, + bad_words_token_ids=self.bad_words_token_ids, ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 81dec429b425..e41427f6c452 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1268,6 +1268,7 @@ def _dummy_sampler_run( min_tokens={}, logit_bias=[None for _ in range(num_reqs)], allowed_token_ids_mask=None, + bad_words_token_ids={}, ) sampler_output = self.model.sample(logits=logits, sampling_metadata=dummy_metadata)
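
Usage note (not part of the diff): a minimal sketch of how the new bad_words parameter is exercised, mirroring test_bad_words in tests/v1/sample/test_sampling_params_e2e.py. The model name and prompt below are placeholders, not taken from this change.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model, for illustration only
prompt = "The capital of France is"   # placeholder prompt

# Greedy baseline to see what the model would normally produce.
baseline = llm.generate(prompt, SamplingParams(temperature=0))
first_words = " ".join(baseline[0].outputs[0].text.split()[:2])

# Ban that phrase and regenerate; the banned phrase must not reappear.
restricted = llm.generate(
    prompt, SamplingParams(temperature=0, bad_words=[first_words]))
assert first_words not in restricted[0].outputs[0].text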
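
Likewise, a small sketch (with made-up token ids, not from the diff) of the in-place masking performed by the new vllm/v1/sample/ops/bad_words.py helper: the final token of a bad word is masked only when the request's recent output tokens match the bad word's prefix, and single-token bad words are always masked.

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

vocab_size = 8
logits = torch.zeros(2, vocab_size)   # batch of 2 requests, uniform logits
bad_words_token_ids = {
    0: [[3]],             # request 0: single-token bad word, always masked
    1: [[5, 6], [2, 7]],  # request 1: two multi-token bad words
}
past_tokens_ids = [
    [],      # request 0 has produced no output tokens yet
    [1, 5],  # request 1's last output token matches the prefix of [5, 6]
]

apply_bad_words(logits, bad_words_token_ids, past_tokens_ids)

assert logits[0, 3] == float("-inf")  # single-token bad word masked
assert logits[1, 6] == float("-inf")  # prefix [5] matched, so 6 is masked
assert logits[1, 7] != float("-inf")  # prefix [2] not matched, so 7 survives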
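
Finally, a tiny sketch of the new swap_dict_values helper now used by InputBatch.swap_states; the dictionary contents are made up. A missing key (or a None value) on one side simply removes the other side's entry instead of inserting None.

from vllm.utils import swap_dict_values

d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 3)     # key 3 is absent, so key 1 is removed and
assert d == {2: "b", 3: "a"}  # its value is re-homed under key 3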