Skip to content

Commit f8bfc3c

Browse files
committed
- revert suggestion on h.Write calls in prefix plugin
- refactoring based on review Signed-off-by: Maroon Ayoub <[email protected]>
1 parent 61e31c0 commit f8bfc3c

File tree

4 files changed

+39
-142
lines changed

4 files changed

+39
-142
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
123123
}
124124

125125
// compile-time type assertion
126-
var _ framework.Scorer = &Plugin{}
127-
var _ requestcontrol.PreRequest = &Plugin{}
126+
var (
127+
_ framework.Scorer = &Plugin{}
128+
_ requestcontrol.PreRequest = &Plugin{}
129+
)
128130

129131
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
130132
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
@@ -238,7 +240,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
238240
for server := range cachedServers {
239241
// Update servers with their longest prefix match.
240242
res[server]++
241-
242243
}
243244
}
244245
}
@@ -269,17 +270,17 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i
269270
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
270271
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
271272
}
272-
// Split the body into blocks of size cacheBlockSize.
273+
// Split the body into blocks of size cacheBlockSize.
273274
// If the last block is smaller than cacheBlockSize, it will be ignored.
274275
res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
275276
// Add the model to the first block hash so that different models have different hashes even with the same body.
276277
h := xxhash.New()
277-
h.Write([]byte(request.TargetModel))
278+
_, _ = h.Write([]byte(request.TargetModel))
278279
prevBlockHash := BlockHash(h.Sum64())
279280
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
280281
h.Reset()
281-
h.Write(userInput[i : i+cacheBlockSize])
282-
h.Write(toBytes(prevBlockHash))
282+
_, _ = h.Write(userInput[i : i+cacheBlockSize])
283+
_, _ = h.Write(toBytes(prevBlockHash))
283284
res = append(res, BlockHash(h.Sum64()))
284285

285286
prevBlockHash = res[len(res)-1]

pkg/epp/scheduling/types/types.go

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2424
)
2525

26+
const nilString = "<nil>"
27+
2628
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
2729
type LLMRequest struct {
2830
// RequestId is the Envoy generated Id for the request being processed
@@ -36,6 +38,10 @@ type LLMRequest struct {
3638
}
3739

3840
func (r *LLMRequest) String() string {
41+
if r == nil {
42+
return nilString
43+
}
44+
3945
return fmt.Sprintf("RequestID: %s, TargetModel: %s, RequestData: %s, Headers: %v",
4046
r.RequestId, r.TargetModel, r.Data, r.Headers)
4147
}
@@ -51,12 +57,19 @@ type LLMRequestData struct {
5157
}
5258

5359
func (r *LLMRequestData) String() string {
60+
if r == nil {
61+
return nilString
62+
}
63+
5464
if r.Completions != nil {
55-
return "Completions: " + r.Completions.String()
65+
return r.Completions.String()
66+
}
67+
68+
if r.ChatCompletions != nil {
69+
return r.ChatCompletions.String()
5670
}
5771

58-
// Must be a ChatCompletionsRequest
59-
return "ChatCompletions: " + r.ChatCompletions.String()
72+
return ""
6073
}
6174

6275
// CompletionsRequest is a structured representation of the fields we parse out of the
@@ -65,10 +78,14 @@ func (r *LLMRequestData) String() string {
6578
// API spec.
6679
type CompletionsRequest struct {
6780
// Prompt is the prompt that was sent in the request body.
68-
Prompt string
81+
Prompt string `json:"prompt,omitempty"`
6982
}
7083

7184
func (r *CompletionsRequest) String() string {
85+
if r == nil {
86+
return nilString
87+
}
88+
7289
return fmt.Sprintf("{PromptLength: %d}", len(r.Prompt))
7390
}
7491

@@ -78,7 +95,7 @@ func (r *CompletionsRequest) String() string {
7895
// API spec.
7996
type ChatCompletionsRequest struct {
8097
/* parameters from the official OpenAI chat-completions API */
81-
Messages []Message
98+
Messages []Message `json:"messages,omitempty"`
8299
Tools []interface{} `json:"tools,omitempty"`
83100
/* parameters from the HuggingFace transformers chat-templates API */
84101
Documents []interface{} `json:"documents,omitempty"`
@@ -90,6 +107,10 @@ type ChatCompletionsRequest struct {
90107
}
91108

92109
func (r *ChatCompletionsRequest) String() string {
110+
if r == nil {
111+
return nilString
112+
}
113+
93114
messagesLen := 0
94115
for _, msg := range r.Messages {
95116
messagesLen += len(msg.Content)
@@ -117,8 +138,9 @@ type ScoredPod struct {
117138

118139
func (pm *PodMetrics) String() string {
119140
if pm == nil {
120-
return ""
141+
return nilString
121142
}
143+
122144
return fmt.Sprintf("%+v", *pm)
123145
}
124146

pkg/epp/util/request/body.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,5 @@ func validateChatCompletionsMessages(messages []types.Message) error {
5555
return errutil.Error{Code: errutil.BadRequest, Msg: "chat-completions request must have at least one message"}
5656
}
5757

58-
for i, msg := range messages {
59-
if msg.Role == "" {
60-
return errutil.Error{Code: errutil.BadRequest, Msg: "message at index " + string(rune(i)) + " is missing role"}
61-
}
62-
if msg.Content == "" {
63-
return errutil.Error{Code: errutil.BadRequest, Msg: "message at index " + string(rune(i)) + " is missing content"}
64-
}
65-
}
66-
6758
return nil
6859
}

pkg/epp/util/request/body_test.go

Lines changed: 3 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package request
1919
import (
2020
"testing"
2121

22+
"github.com/google/go-cmp/cmp"
2223
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2324
)
2425

@@ -127,26 +128,6 @@ func TestExtractRequestData(t *testing.T) {
127128
},
128129
wantErr: true,
129130
},
130-
{
131-
name: "message missing role",
132-
body: map[string]any{
133-
"model": "test",
134-
"messages": []any{
135-
map[string]any{"content": "hello"},
136-
},
137-
},
138-
wantErr: true,
139-
},
140-
{
141-
name: "message missing content",
142-
body: map[string]any{
143-
"model": "test",
144-
"messages": []any{
145-
map[string]any{"role": "user"},
146-
},
147-
},
148-
wantErr: true,
149-
},
150131
{
151132
name: "message with non-string role",
152133
body: map[string]any{
@@ -257,111 +238,13 @@ func TestExtractRequestData(t *testing.T) {
257238
return
258239
}
259240

260-
// Compare the results
261-
if !compareResults(got, tt.want, t) {
262-
t.Errorf("ExtractRequestData() result mismatch")
241+
if diff := cmp.Diff(tt.want, got); diff != "" {
242+
t.Errorf("ExtractRequestData() mismatch (-want +got):\n%s", diff)
263243
}
264244
})
265245
}
266246
}
267247

268-
func compareResults(got, want *types.LLMRequestData, t *testing.T) bool {
269-
switch {
270-
case got.Completions != nil && want.Completions != nil:
271-
return compareCompletionsRequest(got.Completions, want.Completions, t)
272-
case got.ChatCompletions != nil && want.ChatCompletions != nil:
273-
return compareChatCompletionsRequest(got.ChatCompletions, want.ChatCompletions, t)
274-
case got.Completions == nil && want.Completions == nil && got.ChatCompletions == nil && want.ChatCompletions == nil:
275-
return true
276-
default:
277-
t.Errorf("Result type mismatch: got completions=%v, chatCompletions=%v; want completions=%v, chatCompletions=%v",
278-
got.Completions != nil, got.ChatCompletions != nil, want.Completions != nil, want.ChatCompletions != nil)
279-
return false
280-
}
281-
}
282-
283-
func compareCompletionsRequest(got, want *types.CompletionsRequest, t *testing.T) bool {
284-
if got.Prompt != want.Prompt {
285-
t.Errorf("CompletionsRequest.Prompt = %v, want %v", got.Prompt, want.Prompt)
286-
return false
287-
}
288-
return true
289-
}
290-
291-
func compareChatCompletionsRequest(got, want *types.ChatCompletionsRequest, t *testing.T) bool {
292-
// Compare messages
293-
if len(got.Messages) != len(want.Messages) {
294-
t.Errorf("Messages length = %v, want %v", len(got.Messages), len(want.Messages))
295-
return false
296-
}
297-
for i, msg := range got.Messages {
298-
wantMsg := want.Messages[i]
299-
if msg.Role != wantMsg.Role || msg.Content != wantMsg.Content {
300-
t.Errorf("Message[%d] = %v, want %v", i, msg, wantMsg)
301-
return false
302-
}
303-
}
304-
305-
// Compare optional fields
306-
if got.ChatTemplate != want.ChatTemplate {
307-
t.Errorf("ChatTemplate = %v, want %v", got.ChatTemplate, want.ChatTemplate)
308-
return false
309-
}
310-
if got.ReturnAssistantTokensMask != want.ReturnAssistantTokensMask {
311-
t.Errorf("ReturnAssistantTokensMask = %v, want %v", got.ReturnAssistantTokensMask, want.ReturnAssistantTokensMask)
312-
return false
313-
}
314-
if got.ContinueFinalMessage != want.ContinueFinalMessage {
315-
t.Errorf("ContinueFinalMessage = %v, want %v", got.ContinueFinalMessage, want.ContinueFinalMessage)
316-
return false
317-
}
318-
if got.AddGenerationPrompt != want.AddGenerationPrompt {
319-
t.Errorf("AddGenerationPrompt = %v, want %v", got.AddGenerationPrompt, want.AddGenerationPrompt)
320-
return false
321-
}
322-
323-
// Compare tools (shallow comparison for test purposes)
324-
if !compareSliceAny(got.Tools, want.Tools) {
325-
t.Errorf("Tools mismatch")
326-
return false
327-
}
328-
329-
// Compare documents (shallow comparison for test purposes)
330-
if !compareSliceAny(got.Documents, want.Documents) {
331-
t.Errorf("Documents mismatch")
332-
return false
333-
}
334-
335-
// Compare chat template kwargs (shallow comparison for test purposes)
336-
if !compareMapAny(got.ChatTemplateKWArgs, want.ChatTemplateKWArgs) {
337-
t.Errorf("ChatTemplateKWArgs mismatch")
338-
return false
339-
}
340-
341-
return true
342-
}
343-
344-
func compareSliceAny(got, want []any) bool {
345-
if len(got) != len(want) {
346-
return false
347-
}
348-
// For test purposes, we'll do a simple length check and type check
349-
// In practice, you might want deeper comparison depending on your needs
350-
return true
351-
}
352-
353-
func compareMapAny(got, want map[string]any) bool {
354-
if len(got) != len(want) {
355-
return false
356-
}
357-
for k, v := range want {
358-
if gotV, exists := got[k]; !exists || gotV != v {
359-
return false
360-
}
361-
}
362-
return true
363-
}
364-
365248
// Benchmark tests for performance comparison
366249
func BenchmarkExtractRequestData_Completions(b *testing.B) {
367250
body := map[string]any{

0 commit comments

Comments
 (0)