-
Notifications
You must be signed in to change notification settings - Fork 180
support extracting prompt from chat completions API #798
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* | ||
Copyright 2025 The Kubernetes Authors. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package request | ||
delavet marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import ( | ||
"fmt" | ||
|
||
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" | ||
) | ||
|
||
func ExtractPromptFromRequestBody(body map[string]interface{}) (string, error) { | ||
if _, ok := body["messages"]; ok { | ||
return extractPromptFromMessagesField(body) | ||
} | ||
return extractPromptField(body) | ||
} | ||
|
||
func extractPromptField(body map[string]interface{}) (string, error) { | ||
prompt, ok := body["prompt"] | ||
if !ok { | ||
return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"} | ||
} | ||
promptStr, ok := prompt.(string) | ||
if !ok { | ||
return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt is not a string"} | ||
} | ||
return promptStr, nil | ||
} | ||
|
||
func extractPromptFromMessagesField(body map[string]interface{}) (string, error) { | ||
messages, ok := body["messages"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, since prefix-aware routing is an attempt at estimating the locations of KVCache, this may be sufficient to some degree, but a chat-completions request is more complex. Two chat-completion requests can have the same messages but lead to entirely different KV blocks. See this struct for example: https://github.com/sashabaranov/go-openai/blob/6181facea7e6e5525b6b8da42205d7cce822c249/chat.go#L95 And an example to how a chat-completions request is templated before tokenization in vLLM: https://github.com/vllm-project/vllm/blob/main/examples/tool_chat_template_llama3.2_json.jinja There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for your valuable suggestions. I indeed designed this function with relatively simple logic because, in my understanding, EPP should be as model-agnostic as possible: for chat completions requests sent to the same model, we only need to ensure that requests with the same message prefix receive the same prompt, which should suffice for Prefix Cache Aware Routing. Based on this, I referenced the simple template from https://github.com/vllm-project/vllm/blob/main/examples/template_chatml.jinja to perform basic processing on the message list. I do notice that the template you provided includes some more complex details, and I plan to implement some improvements to better handle these cases:
This should cover most of the common scenarios. If any others have any further comments or suggestions, please help to point it out. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's best to keep this PR small and focused since it resolves a known bug. We can iterate, if needed, to resolve any potential kv cache inefficiencies. |
||
if !ok { | ||
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages not found in request"} | ||
} | ||
messageList, ok := messages.([]interface{}) | ||
if !ok { | ||
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is not a list"} | ||
} | ||
if len(messageList) == 0 { | ||
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is empty"} | ||
} | ||
|
||
prompt := "" | ||
for _, msg := range messageList { | ||
msgMap, ok := msg.(map[string]interface{}) | ||
if !ok { | ||
continue | ||
} | ||
content, ok := msgMap["content"] | ||
if !ok { | ||
continue | ||
} | ||
contentStr, ok := content.(string) | ||
if !ok { | ||
continue | ||
} | ||
role, ok := msgMap["role"] | ||
if !ok { | ||
continue | ||
} | ||
roleStr, ok := role.(string) | ||
if !ok { | ||
continue | ||
} | ||
prompt += constructChatMessage(roleStr, contentStr) | ||
} | ||
return prompt, nil | ||
} | ||
|
||
// constructChatMessage renders one chat message as
// "<|im_start|>" + role + "\n" + content + "<|im_end|>\n".
func constructChatMessage(role string, content string) string {
	const messageFormat = "<|im_start|>%s\n%s<|im_end|>\n"
	return fmt.Sprintf(messageFormat, role, content)
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
package request | ||
|
||
import ( | ||
"testing" | ||
) | ||
|
||
func TestExtractPromptFromRequestBody(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
body map[string]interface{} | ||
want string | ||
wantErr bool | ||
errType error | ||
}{ | ||
{ | ||
name: "chat completions request body", | ||
body: map[string]interface{}{ | ||
"model": "test", | ||
"messages": []interface{}{ | ||
map[string]interface{}{ | ||
"role": "system", "content": "this is a system message", | ||
}, | ||
map[string]interface{}{ | ||
"role": "user", "content": "hello", | ||
}, | ||
map[string]interface{}{ | ||
"role": "assistant", "content": "hi, what can I do for you?", | ||
}, | ||
}, | ||
}, | ||
want: "<|im_start|>system\nthis is a system message<|im_end|>\n" + | ||
"<|im_start|>user\nhello<|im_end|>\n" + | ||
"<|im_start|>assistant\nhi, what can I do for you?<|im_end|>\n", | ||
}, | ||
{ | ||
name: "completions request body", | ||
body: map[string]interface{}{ | ||
"model": "test", | ||
"prompt": "test prompt", | ||
}, | ||
want: "test prompt", | ||
}, | ||
{ | ||
name: "invalid prompt format", | ||
body: map[string]interface{}{ | ||
"model": "test", | ||
"prompt": []interface{}{ | ||
map[string]interface{}{ | ||
"role": "system", "content": "this is a system message", | ||
}, | ||
map[string]interface{}{ | ||
"role": "user", "content": "hello", | ||
}, | ||
map[string]interface{}{ | ||
"role": "assistant", "content": "hi, what can I", | ||
}, | ||
}, | ||
}, | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "invalid messaged format", | ||
body: map[string]interface{}{ | ||
"model": "test", | ||
"messages": map[string]interface{}{ | ||
"role": "system", "content": "this is a system message", | ||
}, | ||
}, | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "prompt does not exist", | ||
body: map[string]interface{}{ | ||
"model": "test", | ||
}, | ||
wantErr: true, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := ExtractPromptFromRequestBody(tt.body) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("ExtractPromptFromRequestBody() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if got != tt.want { | ||
t.Errorf("ExtractPromptFromRequestBody() got = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestExtractPromptField(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
body map[string]interface{} | ||
want string | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "valid prompt", | ||
body: map[string]interface{}{ | ||
"prompt": "test prompt", | ||
}, | ||
want: "test prompt", | ||
}, | ||
{ | ||
name: "prompt not found", | ||
body: map[string]interface{}{}, | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "non-string prompt", | ||
body: map[string]interface{}{ | ||
"prompt": 123, | ||
}, | ||
wantErr: true, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := extractPromptField(tt.body) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("extractPromptField() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if got != tt.want { | ||
t.Errorf("extractPromptField() got = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestExtractPromptFromMessagesField(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
body map[string]interface{} | ||
want string | ||
wantErr bool | ||
}{ | ||
{ | ||
name: "valid messages", | ||
body: map[string]interface{}{ | ||
"messages": []interface{}{ | ||
map[string]interface{}{"role": "user", "content": "test1"}, | ||
map[string]interface{}{"role": "assistant", "content": "test2"}, | ||
}, | ||
}, | ||
want: "<|im_start|>user\ntest1<|im_end|>\n<|im_start|>assistant\ntest2<|im_end|>\n", | ||
}, | ||
{ | ||
name: "invalid messages format", | ||
body: map[string]interface{}{ | ||
"messages": "invalid", | ||
}, | ||
wantErr: true, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
got, err := extractPromptFromMessagesField(tt.body) | ||
if (err != nil) != tt.wantErr { | ||
t.Errorf("extractPromptFromMessagesField() error = %v, wantErr %v", err, tt.wantErr) | ||
return | ||
} | ||
if got != tt.want { | ||
t.Errorf("extractPromptFromMessagesField() got = %v, want %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestConstructChatMessage(t *testing.T) { | ||
tests := []struct { | ||
role string | ||
content string | ||
want string | ||
}{ | ||
{"user", "hello", "<|im_start|>user\nhello<|im_end|>\n"}, | ||
{"assistant", "hi", "<|im_start|>assistant\nhi<|im_end|>\n"}, | ||
} | ||
|
||
for _, tt := range tests { | ||
if got := constructChatMessage(tt.role, tt.content); got != tt.want { | ||
t.Errorf("constructChatMessage() = %v, want %v", got, tt.want) | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.