4 changes: 2 additions & 2 deletions src/engine.py
@@ -107,7 +107,7 @@ async def generate(self, job_input: JobInput):
yield {"error": create_error_response(str(e)).model_dump()}

async def _generate_vllm(self, llm_input, validated_sampling_params, batch_size, stream, apply_chat_template, request_id, batch_size_growth_factor, min_batch_size: str) -> AsyncGenerator[dict, None]:
if apply_chat_template or isinstance(llm_input, list):
if apply_chat_template:
tokenizer_wrapper = self._get_tokenizer_for_chat_template()
llm_input = tokenizer_wrapper.apply_chat_template(llm_input)
results_generator = self.llm.generate(llm_input, validated_sampling_params, request_id)
@@ -299,4 +299,4 @@ async def _handle_chat_or_completion_request(self, openai_request: JobInput):
if self.raw_openai_output:
batch = "".join(batch)
yield batch


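Note: a minimal sketch, not part of the diff, of why dropping the isinstance check above is safe given the JobInput change in src/utils.py below. When a job carries "messages", apply_chat_template now defaults to True, so chat requests still get the template, while plain prompt or token lists skip it:

    # Hypothetical job dicts for illustration only.
    chat_job = {"messages": [{"role": "user", "content": "Hi"}]}
    chat_job.get("apply_chat_template", chat_job.get("messages") is not None)    # True: template applied

    batch_job = {"prompt": ["first prompt", "second prompt"]}
    batch_job.get("apply_chat_template", batch_job.get("messages") is not None)  # False: passed straight to llm.generate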
33 changes: 30 additions & 3 deletions src/utils.py
@@ -42,12 +42,39 @@ def count_physical_cores():
return len(cores)


# Helpers to support sending multiple prompts or token arrays in a single request
def prompt_to_vllm_prompt(prompt):
    if len(prompt) == 0:  # Empty input: wrap as-is
        return vllm.TextPrompt(prompt=prompt)
    elif isinstance(prompt, list):  # Multiple prompts in one entry
        return [vllm.TextPrompt(prompt=p) for p in prompt]
    else:  # Single prompt string
        return vllm.TextPrompt(prompt=prompt)

def tokens_to_vllm_prompt(tokens):
    if len(tokens) == 0:  # Empty input: wrap as-is
        return vllm.TokensPrompt(prompt_token_ids=tokens)
    elif isinstance(tokens[0], list):  # Multiple prompts (list of token-id lists) in one entry
        return [vllm.TokensPrompt(prompt_token_ids=toks) for toks in tokens]
    else:  # Single list of token ids
        return vllm.TokensPrompt(prompt_token_ids=tokens)

def get_llm_input(job):
for k, fn in [
("messages", lambda messages: messages),
("prompt", prompt_to_vllm_prompt),
("tokens", tokens_to_vllm_prompt)]:
value = job.get(k)
if value:
return fn(value)
return None

class JobInput:
def __init__(self, job):
self.llm_input = job.get("messages", job.get("prompt"))
self.llm_input = get_llm_input(job)
self.stream = job.get("stream", False)
self.max_batch_size = job.get("max_batch_size")
self.apply_chat_template = job.get("apply_chat_template", False)
self.apply_chat_template = job.get("apply_chat_template", job.get("messages") is not None)
self.use_openai_format = job.get("use_openai_format", False)
samp_param = job.get("sampling_params", {})
if "max_tokens" not in samp_param:
@@ -103,4 +130,4 @@ def wrapper(*args, **kwargs):
end = time()
logging.info(f"{func.__name__} completed in {end - start:.2f} seconds")
return result
return wrapper
return wrapper
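
A minimal usage sketch of the payload shapes the new get_llm_input helper is meant to accept; the values below are illustrative, and only the keys "messages", "prompt", "tokens", "sampling_params", and "apply_chat_template" come from the diff:

    # Single prompt string -> one vllm.TextPrompt.
    single_job = {"prompt": "Write a haiku about GPUs.", "sampling_params": {"max_tokens": 64}}

    # Multiple prompts in one request -> a list of vllm.TextPrompt objects.
    multi_prompt_job = {"prompt": ["First prompt", "Second prompt"]}

    # Pre-tokenized input; a list of token-id lists yields one vllm.TokensPrompt per entry.
    token_job = {"tokens": [[101, 2009, 2003], [101, 2129, 2079]]}

    # Chat-style input still uses "messages"; apply_chat_template now defaults to True for it.
    chat_job = {"messages": [{"role": "user", "content": "Hello!"}]}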