diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index bfcf2ec6d..85c8ee34f 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -62,9 +62,9 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo if !ok { return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"} } - prompt, ok := requestBodyMap["prompt"].(string) - if !ok { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"} + prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap) + if err != nil { + return reqCtx, err } // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 05dc1b3b8..e4384a80b 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -85,7 +85,7 @@ func TestHandleRequest(t *testing.T) { wantRespBody map[string]interface{} }{ { - name: "successful request", + name: "successful completions request", reqBodyMap: map[string]interface{}{ "model": tsModel, "prompt": "test prompt", @@ -102,7 +102,69 @@ func TestHandleRequest(t *testing.T) { }, }, { - name: "successful request with target model", + name: "successful chat completions request", + reqBodyMap: map[string]interface{}{ + "model": tsModel, + "messages": []interface{}{ + map[string]interface{}{ + "role": "user", + "content": "test prompt", + }, + }, + }, + wantReqCtx: &handlers.RequestContext{ + Model: tsModel, + ResolvedTargetModel: tsModel, + TargetPod: "/pod1", + TargetEndpoint: "address-1:8000", + }, + wantRespBody: map[string]interface{}{ + "model": tsModel, + "messages": []interface{}{ + map[string]interface{}{ + "role": "user", + "content": "test prompt", + }, + }, + }, + }, + { + name: "successful chat completions request with multiple messages", + reqBodyMap: map[string]interface{}{ + "model": tsModel, + "messages": []interface{}{ + map[string]interface{}{ + "role": "developer", + "content": "You are a helpful assistant.", + }, + map[string]interface{}{ + "role": "user", + "content": "Hello!", + }, + }, + }, + wantReqCtx: &handlers.RequestContext{ + Model: tsModel, + ResolvedTargetModel: tsModel, + TargetPod: "/pod1", + TargetEndpoint: "address-1:8000", + }, + wantRespBody: map[string]interface{}{ + "model": tsModel, + "messages": []interface{}{ + map[string]interface{}{ + "role": "developer", + "content": "You are a helpful assistant.", + }, + map[string]interface{}{ + "role": "user", + "content": "Hello!", + }, + }, + }, + }, + { + name: "successful completions request with target model", reqBodyMap: map[string]interface{}{ "model": modelWithTarget, "prompt": "test prompt", @@ -122,6 +184,21 @@ func TestHandleRequest(t *testing.T) { name: "no model defined, expect err", wantErrCode: errutil.BadRequest, }, + { + name: "prompt or messages not found, expect err", + reqBodyMap: map[string]interface{}{ + "model": tsModel, + }, + wantErrCode: errutil.BadRequest, + }, + { + name: "empty messages, expect err", + reqBodyMap: map[string]interface{}{ + "model": tsModel, + "messages": []interface{}{}, + }, + wantErrCode: errutil.BadRequest, + }, { name: "invalid model defined, expect err", reqBodyMap: map[string]interface{}{ diff --git a/pkg/epp/util/request/body.go b/pkg/epp/util/request/body.go new file mode 100644 index 000000000..83a600f08 --- /dev/null +++ b/pkg/epp/util/request/body.go @@ -0,0 +1,86 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package request + +import ( + "fmt" + + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" +) + +func ExtractPromptFromRequestBody(body map[string]interface{}) (string, error) { + if _, ok := body["messages"]; ok { + return extractPromptFromMessagesField(body) + } + return extractPromptField(body) +} + +func extractPromptField(body map[string]interface{}) (string, error) { + prompt, ok := body["prompt"] + if !ok { + return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"} + } + promptStr, ok := prompt.(string) + if !ok { + return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt is not a string"} + } + return promptStr, nil +} + +func extractPromptFromMessagesField(body map[string]interface{}) (string, error) { + messages, ok := body["messages"] + if !ok { + return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages not found in request"} + } + messageList, ok := messages.([]interface{}) + if !ok { + return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is not a list"} + } + if len(messageList) == 0 { + return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is empty"} + } + + prompt := "" + for _, msg := range messageList { + msgMap, ok := msg.(map[string]interface{}) + if !ok { + continue + } + content, ok := msgMap["content"] + if !ok { + continue + } + contentStr, ok := content.(string) + if !ok { + continue + } + role, ok := msgMap["role"] + if !ok { + continue + } + roleStr, ok := role.(string) + if !ok { + continue + } + prompt += constructChatMessage(roleStr, contentStr) + } + return prompt, nil +} + +func constructChatMessage(role string, content string) string { + return fmt.Sprintf("<|im_start|>%s\n%s<|im_end|>\n", role, content) +} diff --git a/pkg/epp/util/request/body_test.go b/pkg/epp/util/request/body_test.go new file mode 100644 index 000000000..563fc8cf8 --- /dev/null +++ b/pkg/epp/util/request/body_test.go @@ -0,0 +1,191 @@ +package request + +import ( + "testing" +) + +func TestExtractPromptFromRequestBody(t *testing.T) { + tests := []struct { + name string + body map[string]interface{} + want string + wantErr bool + errType error + }{ + { + name: "chat completions request body", + body: map[string]interface{}{ + "model": "test", + "messages": []interface{}{ + map[string]interface{}{ + "role": "system", "content": "this is a system message", + }, + map[string]interface{}{ + "role": "user", "content": "hello", + }, + map[string]interface{}{ + "role": "assistant", "content": "hi, what can I do for you?", + }, + }, + }, + want: "<|im_start|>system\nthis is a system message<|im_end|>\n" + + "<|im_start|>user\nhello<|im_end|>\n" + + "<|im_start|>assistant\nhi, what can I do for you?<|im_end|>\n", + }, + { + name: "completions request body", + body: map[string]interface{}{ + "model": "test", + "prompt": "test prompt", + }, + want: "test prompt", + }, + { + name: "invalid prompt format", + body: map[string]interface{}{ + "model": "test", + "prompt": []interface{}{ + map[string]interface{}{ + "role": "system", "content": "this is a system message", + }, + map[string]interface{}{ + "role": "user", "content": "hello", + }, + map[string]interface{}{ + "role": "assistant", "content": "hi, what can I", + }, + }, + }, + wantErr: true, + }, + { + name: "invalid messaged format", + body: map[string]interface{}{ + "model": "test", + "messages": map[string]interface{}{ + "role": "system", "content": "this is a system message", + }, + }, + wantErr: true, + }, + { + name: "prompt does not exist", + body: map[string]interface{}{ + "model": "test", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ExtractPromptFromRequestBody(tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("ExtractPromptFromRequestBody() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ExtractPromptFromRequestBody() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestExtractPromptField(t *testing.T) { + tests := []struct { + name string + body map[string]interface{} + want string + wantErr bool + }{ + { + name: "valid prompt", + body: map[string]interface{}{ + "prompt": "test prompt", + }, + want: "test prompt", + }, + { + name: "prompt not found", + body: map[string]interface{}{}, + wantErr: true, + }, + { + name: "non-string prompt", + body: map[string]interface{}{ + "prompt": 123, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractPromptField(tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("extractPromptField() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("extractPromptField() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestExtractPromptFromMessagesField(t *testing.T) { + tests := []struct { + name string + body map[string]interface{} + want string + wantErr bool + }{ + { + name: "valid messages", + body: map[string]interface{}{ + "messages": []interface{}{ + map[string]interface{}{"role": "user", "content": "test1"}, + map[string]interface{}{"role": "assistant", "content": "test2"}, + }, + }, + want: "<|im_start|>user\ntest1<|im_end|>\n<|im_start|>assistant\ntest2<|im_end|>\n", + }, + { + name: "invalid messages format", + body: map[string]interface{}{ + "messages": "invalid", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractPromptFromMessagesField(tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("extractPromptFromMessagesField() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("extractPromptFromMessagesField() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConstructChatMessage(t *testing.T) { + tests := []struct { + role string + content string + want string + }{ + {"user", "hello", "<|im_start|>user\nhello<|im_end|>\n"}, + {"assistant", "hi", "<|im_start|>assistant\nhi<|im_end|>\n"}, + } + + for _, tt := range tests { + if got := constructChatMessage(tt.role, tt.content); got != tt.want { + t.Errorf("constructChatMessage() = %v, want %v", got, tt.want) + } + } +}