diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py
new file mode 100644
index 000000000000..e47f13f05316
--- /dev/null
+++ b/tests/v1/sample/test_sampling_params_e2e.py
@@ -0,0 +1,151 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+
+import pytest
+
+from vllm import LLM, SamplingParams
+
+if os.getenv("VLLM_USE_V1", "0") != "1":
+    pytest.skip("Test package requires V1", allow_module_level=True)
+
+MODEL = "meta-llama/Llama-3.2-1B"
+PROMPT = "Hello my name is Robert and I"
+
+
+@pytest.fixture(scope="module")
+def model() -> LLM:
+    return LLM(MODEL, enforce_eager=True)
+
+
+def test_n_gt_1(model):
+    """Parallel sampling is supported."""
+
+    params = SamplingParams(n=3)
+    outputs = model.generate(PROMPT, params)
+    assert len(outputs[0].outputs) == 3
+
+
+def test_best_of(model):
+    """Raise a ValueError since best_of is deprecated."""
+
+    params = SamplingParams(n=2, best_of=3)
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT, params)
+
+
+def test_penalties(model):
+    """Check that we do not get errors if applied."""
+
+    params = SamplingParams(
+        temperature=1.2,
+        presence_penalty=1.2,
+        frequency_penalty=1.2,
+        repetition_penalty=1.2,
+        min_p=0.5,
+        top_p=0.5,
+        top_k=3,
+    )
+    _ = model.generate(PROMPT, params)
+
+
+def test_stop(model):
+    """Check that we respect the stop words."""
+
+    output = model.generate(PROMPT, SamplingParams(temperature=0))
+    split_text = output[0].outputs[0].text.split()
+
+    STOP_IDX = 5
+    params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
+    output = model.generate(PROMPT, params)
+    new_split_text = output[0].outputs[0].text.split()
+
+    # Output should not contain the stop word.
+    assert len(new_split_text) == STOP_IDX
+
+    params = SamplingParams(temperature=0,
+                            stop=split_text[STOP_IDX],
+                            include_stop_str_in_output=True)
+    output = model.generate(PROMPT, params)
+    new_split_text = output[0].outputs[0].text.split()
+
+    # Output should contain the stop word.
+    assert len(new_split_text) == STOP_IDX + 1
+
+
+def test_stop_token_ids(model):
+    """Check that we respect the stop token ids."""
+
+    output = model.generate(PROMPT, SamplingParams(temperature=0))
+
+    stop_token_id_0 = output[0].outputs[0].token_ids[5]
+    stop_token_id_1 = output[0].outputs[0].token_ids[6]
+
+    stop_token_ids = [stop_token_id_1, stop_token_id_0]
+    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
+    output = model.generate(PROMPT, params)
+    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
+
+    stop_token_ids = [stop_token_id_0, stop_token_id_1]
+    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
+    output = model.generate(PROMPT, params)
+    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
+
+
+def test_bad_words(model):
+    """Check that we reject bad_words, which are not yet supported on V1."""
+
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
+
+
+def test_logits_processor(model):
+    """Check that we reject logits processors."""
+
+    # This sample logits processor gives an infinite score to the i-th token,
+    # where i is the number of tokens generated so far, so it would force the
+    # output token sequence to be [0, 1, 2, ...] if it were applied.
+    def pick_ith(token_ids, logits):
+        logits[len(token_ids)] = float("inf")
+        return logits
+
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT,
+                           SamplingParams(logits_processors=[pick_ith]))
+
+
+def test_allowed_token_ids(model):
+    """Check that we can use allowed_token_ids."""
+
+    TOKEN_ID = 10
+    allowed_token_ids = [TOKEN_ID]
+    output = model.generate(
+        PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
+    assert output[0].outputs[0].token_ids[-1] == TOKEN_ID
+
+    # Reject negative token id.
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))
+
+    # Reject out of vocabulary.
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT,
+                           SamplingParams(allowed_token_ids=[10000000]))
+
+
+def test_priority(model):
+    """Check that we reject requests with priority."""
+
+    # Reject a request with non-zero priority.
+    with pytest.raises(ValueError):
+        _ = model.generate(PROMPT, priority=[1])
+
+
+def test_seed(model):
+    """Check that seed impacts randomness."""
+
+    out_1 = model.generate(PROMPT, SamplingParams(seed=42))
+    out_2 = model.generate(PROMPT, SamplingParams(seed=42))
+    out_3 = model.generate(PROMPT, SamplingParams(seed=43))
+
+    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
+    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 97c1ef5e9e52..713a5d38dfdd 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -55,11 +55,8 @@ def __init__(
 
     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
@@ -79,17 +76,10 @@ def _validate_logprobs(
             raise ValueError("Prefix caching with prompt logprobs not yet "
                              "supported on VLLM V1.")
 
-    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
-        if lora_request is not None and not self.lora_config:
-            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
-                             "not enabled!")
-
-    def _validate_allowed_token_ids(
+    def _validate_sampling_params(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
         if params.allowed_token_ids is None:
             return
         if not params.allowed_token_ids:
@@ -99,6 +89,42 @@ def _validate_allowed_token_ids(
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
 
+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not support per request "
+                             "user provided logits processors.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParams.
+        Should raise ValueError if unsupported for the API server.
+        """
+
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
@@ -114,14 +140,17 @@ def process_inputs(
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
+        if prompt_adapter_request is not None:
+            raise ValueError("V1 does not support prompt_adapter_request.")
 
         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 2fe177ea4e12..c0e9ff0286d6 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -298,6 +298,11 @@ def add_request(
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
+        # FIXME: this implementation is incorrect. We build a mask of the
+        # allowed tokens and then apply -inf to exactly those tokens, so the
+        # allowed tokens can never be selected. We cannot simply invert the
+        # mask, since that would affect requests without allowed_token_ids.
+        # This feature is currently disabled on V1 (we reject in Processor).
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
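
Note (reviewer sketch, not part of the patch): the FIXME above describes why the current allowed_token_ids masking is wrong. Below is a minimal Python sketch of the intended semantics, assuming a per-request list of allowed ids; the helper name mask_disallowed_tokens and its signature are hypothetical and do not correspond to the vLLM sampler API. The idea is to mark the whole vocabulary as disallowed only for rows that set allowed_token_ids, then clear the mask at the allowed positions, so rows without the feature keep an all-False mask and are untouched.

import torch


def mask_disallowed_tokens(
        logits: torch.Tensor,  # [num_reqs, vocab_size]
        allowed_ids_per_req: list) -> torch.Tensor:
    """Apply -inf to every token NOT in a request's allowed set.

    Entries that are None leave the corresponding row untouched, so
    requests without allowed_token_ids are unaffected.
    (Hypothetical standalone helper, for illustration only.)
    """
    num_reqs, vocab_size = logits.shape
    # Start with "nothing disallowed" for every request.
    disallowed = torch.zeros(num_reqs, vocab_size, dtype=torch.bool,
                             device=logits.device)
    for i, allowed in enumerate(allowed_ids_per_req):
        if allowed is not None:
            # Disallow the whole vocab for this request...
            disallowed[i] = True
            # ...then re-enable only the allowed token ids.
            disallowed[i, allowed] = False
    return logits.masked_fill(disallowed, float("-inf"))


# Example: request 0 may only emit tokens {2, 5}; request 1 is unconstrained.
# masked = mask_disallowed_tokens(torch.randn(2, 10), [[2, 5], None])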
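
Note (reviewer sketch, not part of the patch): given the module-level skip in the new test file, running it locally should look something like VLLM_USE_V1=1 pytest tests/v1/sample/test_sampling_params_e2e.py; this assumes a pytest environment with a GPU and access to meta-llama/Llama-3.2-1B.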