Skip to content

Commit d559869

Browse files
committed
fix param check
1 parent 88ca0b5 commit d559869

File tree

8 files changed

+53 additions, -43 deletions

fastdeploy/engine/sampling_params.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -209,17 +209,15 @@ def _verify_args(self) -> None:
209209
)
210210

211211
if os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
212-
if self.logprobs is not None and self.logprobs < 0:
213-
raise ValueError(f"logprobs must be greater than 0, got {self.logprobs}.")
214-
if self.logprobs is not None and self.logprobs > 20:
215-
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
212+
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
213+
raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
216214
if self.prompt_logprobs is not None:
217215
raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
218216
elif os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "1":
219217
if self.logprobs is not None and self.logprobs < -1:
220-
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
218+
raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
221219
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
222-
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
220+
raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.")
223221

224222
if not 0 <= self.seed <= 922337203685477580:
225223
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

fastdeploy/entrypoints/engine_client.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -357,11 +357,15 @@ def valid_parameters(self, data):
357357
api_server_logger.error(err_msg)
358358
raise ParameterError("prompt_logprobs", err_msg)
359359

360-
if prompt_logprobs == -1:
361-
prompt_logprobs = self.ori_vocab_size
360+
if prompt_logprobs == -1 and self.ori_vocab_size > max_logprobs:
361+
err_msg = f"The requested value of ({self.ori_vocab_size}) for prompt_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
362+
api_server_logger.error(err_msg)
363+
raise ValueError("prompt_logprobs", err_msg)
362364

363365
if prompt_logprobs < -1:
364-
err_msg = f"Invalid 'prompt_logprobs': must be >= -1, got {prompt_logprobs}."
366+
err_msg = (
367+
f"prompt_logprobs must be a non-negative value or -1; the current value is {prompt_logprobs}."
368+
)
365369
api_server_logger.error(err_msg)
366370
raise ValueError("prompt_logprobs", err_msg)
367371

@@ -384,19 +388,18 @@ def valid_parameters(self, data):
384388
raise ParameterError("top_logprobs", err_msg)
385389

386390
if os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
387-
if top_logprobs < 0:
388-
err_msg = "Invalid value for 'top_logprobs': must be >= 0."
389-
raise ValueError("top_logprobs", err_msg)
390-
391-
if top_logprobs > 20:
392-
err_msg = "Invalid value for 'top_logprobs': must be <= 20."
391+
if top_logprobs < 0 or top_logprobs > 20:
392+
err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
393+
api_server_logger.error(err_msg)
393394
raise ValueError("top_logprobs", err_msg)
394395
else:
395-
if top_logprobs == -1:
396-
top_logprobs = self.ori_vocab_size
396+
if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
397+
err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
398+
api_server_logger.error(err_msg)
399+
raise ValueError("top_logprobs", err_msg)
397400

398401
if top_logprobs < -1:
399-
err_msg = f"Invalid 'top_logprobs': must be >= -1, got {top_logprobs}."
402+
err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
400403
api_server_logger.error(err_msg)
401404
raise ValueError("top_logprobs", err_msg)
402405

fastdeploy/entrypoints/llm.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -351,8 +351,10 @@ def _add_request(
351351

352352
if current_sampling_params.logprobs is not None:
353353
num_logprobs = current_sampling_params.logprobs
354-
if num_logprobs == -1:
355-
num_logprobs = ori_vocab_size
354+
if num_logprobs == -1 and ori_vocab_size > max_logprobs:
355+
raise ValueError(
356+
f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
357+
)
356358
if num_logprobs > max_logprobs:
357359
raise ValueError(
358360
f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
@@ -363,8 +365,10 @@ def _add_request(
363365
if kwargs.get("stream"):
364366
raise ValueError("prompt_logprobs is not supported with streaming.")
365367
num_prompt_logprobs = current_sampling_params.prompt_logprobs
366-
if num_prompt_logprobs == -1:
367-
num_prompt_logprobs = ori_vocab_size
368+
if num_prompt_logprobs == -1 and ori_vocab_size > max_logprobs:
369+
raise ValueError(
370+
f"Number of prompt_logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
371+
)
368372
if num_prompt_logprobs > max_logprobs:
369373
raise ValueError(
370374
f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
@@ -561,7 +565,7 @@ def _run_engine(
561565
result.outputs.logprobs = self._build_sample_logprobs(
562566
result.outputs.top_logprobs, topk_logprobs
563567
)
564-
if result.prompt_logprobs and num_prompt_logprobs:
568+
if result.prompt_logprobs is not None and num_prompt_logprobs is not None:
565569
if num_prompt_logprobs == -1:
566570
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
567571
result.prompt_logprobs = self._build_prompt_logprobs(

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -788,7 +788,7 @@ def check_logprobs(cls, data):
788788
if top_logprobs < -1:
789789
raise ValueError("`top_logprobs` must be a greater than -1.")
790790

791-
if top_logprobs > 0 and not data.get("logprobs"):
791+
if not data.get("logprobs"):
792792
raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
793793

794794
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ async def chat_completion_stream_generator(
299299
for i in range(num_choices):
300300
prompt_logprobs_res: Optional[PromptLogprobs] = None
301301
prompt_logprobs_tensors = res.get("prompt_logprobs", None)
302-
if request.prompt_logprobs and prompt_logprobs_tensors is not None:
302+
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
303303
num_prompt_logprobs = (
304304
request.prompt_logprobs
305305
if request.prompt_logprobs != -1
@@ -583,7 +583,7 @@ async def chat_completion_full_generator(
583583
if draft_logprobs_res and draft_logprobs_res.content is not None:
584584
draft_logprob_contents[idx].extend(draft_logprobs_res.content)
585585
prompt_logprobs_tensors = data.get("prompt_logprobs", None)
586-
if request.prompt_logprobs and prompt_logprobs_tensors is not None:
586+
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
587587
num_prompt_logprobs = (
588588
request.prompt_logprobs
589589
if request.prompt_logprobs != -1

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -445,7 +445,7 @@ async def completion_stream_generator(
445445
prompt_logprobs_res: Optional[PromptLogprobs] = None
446446
if first_iteration[idx]:
447447
prompt_logprobs_tensors = res.get("prompt_logprobs", None)
448-
if request.prompt_logprobs and prompt_logprobs_tensors is not None:
448+
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
449449
num_prompt_logprobs = (
450450
request.prompt_logprobs
451451
if request.prompt_logprobs != -1
@@ -495,7 +495,7 @@ async def completion_stream_generator(
495495
output_draft_top_logprobs = output["draft_top_logprobs"]
496496
logprobs_res: Optional[CompletionLogprobs] = None
497497
draft_logprobs_res: Optional[CompletionLogprobs] = None
498-
if request.logprobs and output_top_logprobs is not None:
498+
if request.logprobs is not None and output_top_logprobs is not None:
499499
num_logprobs = (
500500
request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size
501501
)
@@ -644,7 +644,7 @@ def request_output_to_completion_response(
644644
)
645645
prompt_logprobs_res: Optional[PromptLogprobs] = None
646646
prompt_logprobs_tensors = final_res.get("prompt_logprobs_tensors", None)
647-
if request.prompt_logprobs and prompt_logprobs_tensors is not None:
647+
if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None:
648648
num_prompt_logprobs = (
649649
request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size
650650
)

tests/engine/test_sampling_params.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def test_logprobs_invalid_less_than_minus_one(self):
6161
params = SamplingParams(logprobs=-2)
6262
params._verify_args()
6363

64-
self.assertIn("logprobs must be greater than -1", str(cm.exception))
64+
self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception))
6565
self.assertIn("got -2", str(cm.exception))
6666

6767
def test_logprobs_invalid_less_than_zero(self):
@@ -71,8 +71,7 @@ def test_logprobs_invalid_less_than_zero(self):
7171
params = SamplingParams(logprobs=-1)
7272
params._verify_args()
7373

74-
self.assertIn("logprobs must be greater than 0", str(cm.exception))
75-
self.assertIn("got -1", str(cm.exception))
74+
self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception))
7675

7776
def test_logprobs_greater_than_20_with_v1_disabled(self):
7877
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
@@ -81,7 +80,7 @@ def test_logprobs_greater_than_20_with_v1_disabled(self):
8180
params = SamplingParams(logprobs=21)
8281
params._verify_args()
8382

84-
self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception))
83+
self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception))
8584

8685
def test_logprobs_greater_than_20_with_v1_enabled(self):
8786
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
@@ -127,7 +126,7 @@ def test_prompt_logprobs_invalid_less_than_minus_one(self):
127126
params = SamplingParams(prompt_logprobs=-2)
128127
params._verify_args()
129128

130-
self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception))
129+
self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception))
131130
self.assertIn("got -2", str(cm.exception))
132131

133132
def test_combined_logprobs_and_prompt_logprobs(self):
@@ -234,7 +233,7 @@ def test_error_message_formatting(self):
234233
params._verify_args()
235234

236235
error_msg = str(cm.exception)
237-
self.assertIn("logprobs must be greater than -1", error_msg)
236+
self.assertIn("logprobs must be a non-negative value or -1", error_msg)
238237
self.assertIn("got -5", error_msg)
239238

240239
# Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"
@@ -244,8 +243,7 @@ def test_error_message_formatting(self):
244243
params._verify_args()
245244

246245
error_msg = str(cm.exception)
247-
self.assertIn("logprobs must be greater than 0", error_msg)
248-
self.assertIn("got -1", error_msg)
246+
self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg)
249247

250248
# Test prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1"
251249
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
@@ -254,7 +252,7 @@ def test_error_message_formatting(self):
254252
params._verify_args()
255253

256254
error_msg = str(cm.exception)
257-
self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg)
255+
self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg)
258256
self.assertIn("got -10", error_msg)
259257

260258
# Test prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0"

tests/entrypoints/test_engine_client.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,7 @@ def test_max_logprobs_invalid_values(self):
133133

134134
self.assertIn("max_logprobs", str(context.exception))
135135
self.assertIn("must be >= -1", str(context.exception))
136+
self.assertIn("got -2", str(context.exception))
136137

137138
def test_max_logprobs_exceeds_vocab_size(self):
138139
"""Test max_logprobs exceeding vocab_size"""
@@ -146,7 +147,7 @@ def test_max_logprobs_exceeds_vocab_size(self):
146147
self.assertIn("max_logprobs", str(context.exception))
147148
self.assertIn("must be <= vocab_size", str(context.exception))
148149
self.assertIn("1000", str(context.exception))
149-
self.assertIn("1500", str(context.exception))
150+
self.assertIn("got 1500", str(context.exception))
150151

151152
def test_max_logprobs_unlimited(self):
152153
"""Test max_logprobs = -1 (unlimited) sets to ori_vocab_size"""
@@ -237,7 +238,8 @@ def test_prompt_logprobs_invalid_values(self):
237238
self.engine_client.valid_parameters(data)
238239

239240
self.assertIn("prompt_logprobs", str(context.exception))
240-
self.assertIn("must be >= -1", str(context.exception))
241+
self.assertIn("must be a non-negative value or -1", str(context.exception))
242+
self.assertIn("current value is -2", str(context.exception))
241243

242244
def test_prompt_logprobs_exceeds_max_logprobs(self):
243245
"""Test prompt_logprobs exceeding max_logprobs"""
@@ -252,6 +254,8 @@ def test_prompt_logprobs_exceeds_max_logprobs(self):
252254

253255
self.assertIn("prompt_logprobs", str(context.exception))
254256
self.assertIn("exceeds maximum allowed value", str(context.exception))
257+
self.assertIn("15", str(context.exception))
258+
self.assertIn("10", str(context.exception))
255259

256260
def test_top_logprobs_validation_with_fd_use_get_save_output_v1_enabled(self):
257261
"""Test top_logprobs validation when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
@@ -275,7 +279,8 @@ def test_top_logprobs_validation_with_fd_use_get_save_output_v1_enabled(self):
275279
data = {"logprobs": True, "top_logprobs": -2, "request_id": "test"}
276280
with self.assertRaises(ValueError) as context:
277281
self.engine_client.valid_parameters(data)
278-
self.assertIn("must be >= -1", str(context.exception))
282+
self.assertIn("must be a non-negative value or -1", str(context.exception))
283+
self.assertIn("current value is -2", str(context.exception))
279284

280285
# Test value exceeding max_logprobs - should raise ValueError
281286
data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
@@ -293,13 +298,15 @@ def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self):
293298
data = {"logprobs": True, "top_logprobs": -1, "request_id": "test"}
294299
with self.assertRaises(ValueError) as context:
295300
self.engine_client.valid_parameters(data)
296-
self.assertIn("must be >= 0", str(context.exception))
301+
self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
302+
self.assertIn("current value is -1", str(context.exception))
297303

298304
# Test value > 20 - should raise ValueError
299305
data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
300306
with self.assertRaises(ValueError) as context:
301307
self.engine_client.valid_parameters(data)
302-
self.assertIn("must be <= 20", str(context.exception))
308+
self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
309+
self.assertIn("current value is 25", str(context.exception))
303310

304311
# Test valid value
305312
data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}

0 commit comments

Comments
 (0)