1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -146,6 +146,7 @@ steps:
- pip install awscli tensorizer
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
69 changes: 69 additions & 0 deletions examples/offline_inference_chat.py
@@ -0,0 +1,69 @@
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)


print("=" * 80)

# In this script, we demonstrate two ways to pass input to the chat method:

# Conversation with a list of dictionaries
conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
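# chat() applies the model's chat template to the messages and generates
# the next assistant turn.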
outputs = llm.chat(
    conversation, sampling_params=sampling_params, use_tqdm=False
)
print_outputs(outputs)

# Multiple conversations
conversations = [
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "What is dark matter?"},
    ],
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "How are you?"},
        {
            "role": "assistant",
            "content": "I'm an AI without feelings, but I'm here to help!",
        },
        {"role": "user", "content": "Tell me a joke."},
    ],
]

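# Each conversation in the batch is templated and generated independently,
# producing one output per conversation.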
outputs = llm.chat(
    conversations,
    sampling_params=sampling_params,
    use_tqdm=False,
)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not provided, the model's default chat template is used.

# with open("template_falcon_180b.jinja", "r") as f:
#     chat_template = f.read()

# outputs = llm.chat(
#     conversations,
#     sampling_params=sampling_params,
#     use_tqdm=False,
#     chat_template=chat_template,
# )
41 changes: 41 additions & 0 deletions tests/entrypoints/llm/test_generate.py
@@ -140,3 +140,44 @@ def test_multiple_sampling_params(llm: LLM):
    # sampling_params is None, default params should be applied
    outputs = llm.generate(PROMPTS, sampling_params=None)
    assert len(PROMPTS) == len(outputs)


def test_chat():
    llm = LLM(model=MODEL_NAME)

    prompt1 = "Explain the concept of entropy."
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": prompt1
        },
    ]
    outputs = llm.chat(messages)
    assert len(outputs) == 1

    prompt2 = "Describe Bangkok in 150 words."
    multiple_messages = [messages] + [[
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": prompt2
        },
    ]]
    outputs = llm.chat(multiple_messages)
    assert len(outputs) == len(multiple_messages)

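    # A list of SamplingParams may be supplied, one per conversation in the batch.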
    sampling_params = [
        SamplingParams(temperature=0.01, top_p=0.95),
        SamplingParams(temperature=0.3, top_p=0.95),
    ]

    outputs = llm.chat(multiple_messages, sampling_params=sampling_params)
    assert len(outputs) == len(multiple_messages)