Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
if !ok {
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
}
prompt, ok := requestBodyMap["prompt"].(string)
if !ok {
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
if err != nil {
return reqCtx, err
}

// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
Expand Down
81 changes: 79 additions & 2 deletions pkg/epp/requestcontrol/director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestHandleRequest(t *testing.T) {
wantRespBody map[string]interface{}
}{
{
name: "successful request",
name: "successful completions request",
reqBodyMap: map[string]interface{}{
"model": tsModel,
"prompt": "test prompt",
Expand All @@ -102,7 +102,69 @@ func TestHandleRequest(t *testing.T) {
},
},
{
name: "successful request with target model",
name: "successful chat completions request",
reqBodyMap: map[string]interface{}{
"model": tsModel,
"messages": []interface{}{
map[string]interface{}{
"role": "user",
"content": "test prompt",
},
},
},
wantReqCtx: &handlers.RequestContext{
Model: tsModel,
ResolvedTargetModel: tsModel,
TargetPod: "/pod1",
TargetEndpoint: "address-1:8000",
},
wantRespBody: map[string]interface{}{
"model": tsModel,
"messages": []interface{}{
map[string]interface{}{
"role": "user",
"content": "test prompt",
},
},
},
},
{
name: "successful chat completions request with multiple messages",
reqBodyMap: map[string]interface{}{
"model": tsModel,
"messages": []interface{}{
map[string]interface{}{
"role": "developer",
"content": "You are a helpful assistant.",
},
map[string]interface{}{
"role": "user",
"content": "Hello!",
},
},
},
wantReqCtx: &handlers.RequestContext{
Model: tsModel,
ResolvedTargetModel: tsModel,
TargetPod: "/pod1",
TargetEndpoint: "address-1:8000",
},
wantRespBody: map[string]interface{}{
"model": tsModel,
"messages": []interface{}{
map[string]interface{}{
"role": "developer",
"content": "You are a helpful assistant.",
},
map[string]interface{}{
"role": "user",
"content": "Hello!",
},
},
},
},
{
name: "successful completions request with target model",
reqBodyMap: map[string]interface{}{
"model": modelWithTarget,
"prompt": "test prompt",
Expand All @@ -122,6 +184,21 @@ func TestHandleRequest(t *testing.T) {
name: "no model defined, expect err",
wantErrCode: errutil.BadRequest,
},
{
name: "prompt or messages not found, expect err",
reqBodyMap: map[string]interface{}{
"model": tsModel,
},
wantErrCode: errutil.BadRequest,
},
{
name: "empty messages, expect err",
reqBodyMap: map[string]interface{}{
"model": tsModel,
"messages": []interface{}{},
},
wantErrCode: errutil.BadRequest,
},
{
name: "invalid model defined, expect err",
reqBodyMap: map[string]interface{}{
Expand Down
86 changes: 86 additions & 0 deletions pkg/epp/util/request/body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package request

import (
"fmt"

errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
)

func ExtractPromptFromRequestBody(body map[string]interface{}) (string, error) {
if _, ok := body["messages"]; ok {
return extractPromptFromMessagesField(body)
}
return extractPromptField(body)
}

// extractPromptField reads the "prompt" entry of body and returns it as a
// string. It returns an errutil.Error with code errutil.BadRequest when the
// key is absent or when the stored value is not a string.
func extractPromptField(body map[string]interface{}) (string, error) {
	raw, present := body["prompt"]
	if !present {
		return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
	}
	if text, isString := raw.(string); isString {
		return text, nil
	}
	return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt is not a string"}
}

func extractPromptFromMessagesField(body map[string]interface{}) (string, error) {
messages, ok := body["messages"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, since prefix-aware routing is an attempt at estimating the locations of KVCache, this may be sufficient to some degree, but a chat-completions request is more complex. Two chat-completion requests can have the same messages but lead to entirely different KV blocks.

See this struct for example: https://github.com/sashabaranov/go-openai/blob/6181facea7e6e5525b6b8da42205d7cce822c249/chat.go#L95

And an example of how a chat-completions request is templated before tokenization in vLLM: https://github.com/vllm-project/vllm/blob/main/examples/tool_chat_template_llama3.2_json.jinja

Copy link
Contributor Author

@delavet delavet May 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your valuable suggestions. I indeed designed this function with relatively simple logic because, in my understanding, EPP should be as model-agnostic as possible: for chat completions requests sent to the same model, we only need to ensure that requests with the same message prefix receive the same prompt, which should suffice for Prefix Cache Aware Routing.

Based on this, I referenced the simple template from https://github.com/vllm-project/vllm/blob/main/examples/template_chatml.jinja to perform basic processing on the message list.

I do notice that the template you provided includes some more complex details, and I plan to implement some improvements to better handle these cases:

  • In the OpenAI schema, content may not always be a string but can also be an array: additional handling can be added in the code to deal with this case.
  • Requests may include multimodal content such as images or videos: since the current purpose of extracting the prompt is mainly for prefix cache aware routing, I assume we can temporarily ignore the multimodal parts, especially since GIE does not currently claim to support multimodal models. I believe these can be addressed later when multimodal support is introduced.
  • The request body may contain different tools, and this might result in different system prompts. I plan to add logic to simulate this behavior, referencing the example provided here.

This should cover most of the common scenarios. If anyone has further comments or suggestions, please point them out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's best to keep this PR small and focused since it resolves a known bug. We can iterate, if needed, to resolve any potential kv cache inefficiencies.

if !ok {
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages not found in request"}
}
messageList, ok := messages.([]interface{})
if !ok {
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is not a list"}
}
if len(messageList) == 0 {
return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is empty"}
}

prompt := ""
for _, msg := range messageList {
msgMap, ok := msg.(map[string]interface{})
if !ok {
continue
}
content, ok := msgMap["content"]
if !ok {
continue
}
contentStr, ok := content.(string)
if !ok {
continue
}
role, ok := msgMap["role"]
if !ok {
continue
}
roleStr, ok := role.(string)
if !ok {
continue
}
prompt += constructChatMessage(roleStr, contentStr)
}
return prompt, nil
}

// constructChatMessage renders a single chat message as one delimited block:
// "<|im_start|>", the role, a newline, the content, "<|im_end|>" and a
// trailing newline. role and content are inserted verbatim.
func constructChatMessage(role string, content string) string {
	// Layout of one rendered message; the two verbs are role and content.
	const messageLayout = "<|im_start|>%s\n%s<|im_end|>\n"
	return fmt.Sprintf(messageLayout, role, content)
}
191 changes: 191 additions & 0 deletions pkg/epp/util/request/body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package request

import (
"testing"
)

// TestExtractPromptFromRequestBody is a table-driven test of
// ExtractPromptFromRequestBody. It covers a chat-completions body (prompt
// rendered from "messages"), a completions body (prompt read from "prompt"),
// and bodies that must be rejected: a non-string prompt, a non-list
// "messages" value, and a body with neither field.
func TestExtractPromptFromRequestBody(t *testing.T) {
	tests := []struct {
		name    string
		body    map[string]interface{}
		want    string // expected prompt; left empty for error cases
		wantErr bool
	}{
		{
			name: "chat completions request body",
			body: map[string]interface{}{
				"model": "test",
				"messages": []interface{}{
					map[string]interface{}{
						"role": "system", "content": "this is a system message",
					},
					map[string]interface{}{
						"role": "user", "content": "hello",
					},
					map[string]interface{}{
						"role": "assistant", "content": "hi, what can I do for you?",
					},
				},
			},
			want: "<|im_start|>system\nthis is a system message<|im_end|>\n" +
				"<|im_start|>user\nhello<|im_end|>\n" +
				"<|im_start|>assistant\nhi, what can I do for you?<|im_end|>\n",
		},
		{
			name: "completions request body",
			body: map[string]interface{}{
				"model":  "test",
				"prompt": "test prompt",
			},
			want: "test prompt",
		},
		{
			// A message list stored under "prompt" is not a string prompt.
			name: "invalid prompt format",
			body: map[string]interface{}{
				"model": "test",
				"prompt": []interface{}{
					map[string]interface{}{
						"role": "system", "content": "this is a system message",
					},
					map[string]interface{}{
						"role": "user", "content": "hello",
					},
					map[string]interface{}{
						"role": "assistant", "content": "hi, what can I",
					},
				},
			},
			wantErr: true,
		},
		{
			// "messages" must be a list; a single message object is rejected.
			name: "invalid messages format",
			body: map[string]interface{}{
				"model": "test",
				"messages": map[string]interface{}{
					"role": "system", "content": "this is a system message",
				},
			},
			wantErr: true,
		},
		{
			name: "prompt does not exist",
			body: map[string]interface{}{
				"model": "test",
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ExtractPromptFromRequestBody(tt.body)
			if (err != nil) != tt.wantErr {
				t.Errorf("ExtractPromptFromRequestBody() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("ExtractPromptFromRequestBody() got = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestExtractPromptField is a table-driven test of extractPromptField. It
// checks that a string prompt is returned as-is and that a missing or
// non-string prompt yields an error.
func TestExtractPromptField(t *testing.T) {
	cases := []struct {
		name    string
		body    map[string]interface{}
		want    string
		wantErr bool
	}{
		{
			name: "valid prompt",
			body: map[string]interface{}{
				"prompt": "test prompt",
			},
			want: "test prompt",
		},
		{
			name:    "prompt not found",
			body:    map[string]interface{}{},
			wantErr: true,
		},
		{
			name: "non-string prompt",
			body: map[string]interface{}{
				"prompt": 123,
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := extractPromptField(tc.body)
			// Compare error presence first; the returned value is only
			// inspected when the error outcome matches the expectation.
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("extractPromptField() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got == tc.want {
				return
			}
			t.Errorf("extractPromptField() got = %v, want %v", got, tc.want)
		})
	}
}

// TestExtractPromptFromMessagesField is a table-driven test of
// extractPromptFromMessagesField. Besides the happy path it exercises every
// rejection branch (missing key, non-list value, empty list) and the
// per-message skip rules: entries that are not objects, or whose "role" or
// "content" is missing or not a string, contribute nothing to the prompt.
func TestExtractPromptFromMessagesField(t *testing.T) {
	tests := []struct {
		name    string
		body    map[string]interface{}
		want    string // expected prompt; left empty for error cases
		wantErr bool
	}{
		{
			name: "valid messages",
			body: map[string]interface{}{
				"messages": []interface{}{
					map[string]interface{}{"role": "user", "content": "test1"},
					map[string]interface{}{"role": "assistant", "content": "test2"},
				},
			},
			want: "<|im_start|>user\ntest1<|im_end|>\n<|im_start|>assistant\ntest2<|im_end|>\n",
		},
		{
			name: "invalid messages format",
			body: map[string]interface{}{
				"messages": "invalid",
			},
			wantErr: true,
		},
		{
			name:    "messages not found",
			body:    map[string]interface{}{},
			wantErr: true,
		},
		{
			name: "empty messages",
			body: map[string]interface{}{
				"messages": []interface{}{},
			},
			wantErr: true,
		},
		{
			// Only the last entry is well-formed; all others are skipped.
			name: "malformed entries are skipped",
			body: map[string]interface{}{
				"messages": []interface{}{
					"not a map",
					map[string]interface{}{"role": "user"},
					map[string]interface{}{"content": "no role"},
					map[string]interface{}{"role": 1, "content": "non-string role"},
					map[string]interface{}{"role": "user", "content": 42},
					map[string]interface{}{"role": "user", "content": "kept"},
				},
			},
			want: "<|im_start|>user\nkept<|im_end|>\n",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := extractPromptFromMessagesField(tt.body)
			if (err != nil) != tt.wantErr {
				t.Errorf("extractPromptFromMessagesField() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("extractPromptFromMessagesField() got = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestConstructChatMessage checks the rendered form of a single chat message
// for a couple of role/content pairs.
func TestConstructChatMessage(t *testing.T) {
	cases := []struct {
		role    string
		content string
		want    string
	}{
		{role: "user", content: "hello", want: "<|im_start|>user\nhello<|im_end|>\n"},
		{role: "assistant", content: "hi", want: "<|im_start|>assistant\nhi<|im_end|>\n"},
	}

	for _, tc := range cases {
		got := constructChatMessage(tc.role, tc.content)
		if got == tc.want {
			continue
		}
		t.Errorf("constructChatMessage() = %v, want %v", got, tc.want)
	}
}