@@ -19,6 +19,7 @@ package request
19
19
import (
20
20
"testing"
21
21
22
+ "github.com/google/go-cmp/cmp"
22
23
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
23
24
)
24
25
@@ -127,26 +128,6 @@ func TestExtractRequestData(t *testing.T) {
127
128
},
128
129
wantErr : true ,
129
130
},
130
- {
131
- name : "message missing role" ,
132
- body : map [string ]any {
133
- "model" : "test" ,
134
- "messages" : []any {
135
- map [string ]any {"content" : "hello" },
136
- },
137
- },
138
- wantErr : true ,
139
- },
140
- {
141
- name : "message missing content" ,
142
- body : map [string ]any {
143
- "model" : "test" ,
144
- "messages" : []any {
145
- map [string ]any {"role" : "user" },
146
- },
147
- },
148
- wantErr : true ,
149
- },
150
131
{
151
132
name : "message with non-string role" ,
152
133
body : map [string ]any {
@@ -257,111 +238,13 @@ func TestExtractRequestData(t *testing.T) {
257
238
return
258
239
}
259
240
260
- // Compare the results
261
- if ! compareResults (got , tt .want , t ) {
262
- t .Errorf ("ExtractRequestData() result mismatch" )
241
+ if diff := cmp .Diff (tt .want , got ); diff != "" {
242
+ t .Errorf ("ExtractRequestData() mismatch (-want +got):\n %s" , diff )
263
243
}
264
244
})
265
245
}
266
246
}
267
247
268
- func compareResults (got , want * types.LLMRequestData , t * testing.T ) bool {
269
- switch {
270
- case got .Completions != nil && want .Completions != nil :
271
- return compareCompletionsRequest (got .Completions , want .Completions , t )
272
- case got .ChatCompletions != nil && want .ChatCompletions != nil :
273
- return compareChatCompletionsRequest (got .ChatCompletions , want .ChatCompletions , t )
274
- case got .Completions == nil && want .Completions == nil && got .ChatCompletions == nil && want .ChatCompletions == nil :
275
- return true
276
- default :
277
- t .Errorf ("Result type mismatch: got completions=%v, chatCompletions=%v; want completions=%v, chatCompletions=%v" ,
278
- got .Completions != nil , got .ChatCompletions != nil , want .Completions != nil , want .ChatCompletions != nil )
279
- return false
280
- }
281
- }
282
-
283
- func compareCompletionsRequest (got , want * types.CompletionsRequest , t * testing.T ) bool {
284
- if got .Prompt != want .Prompt {
285
- t .Errorf ("CompletionsRequest.Prompt = %v, want %v" , got .Prompt , want .Prompt )
286
- return false
287
- }
288
- return true
289
- }
290
-
291
- func compareChatCompletionsRequest (got , want * types.ChatCompletionsRequest , t * testing.T ) bool {
292
- // Compare messages
293
- if len (got .Messages ) != len (want .Messages ) {
294
- t .Errorf ("Messages length = %v, want %v" , len (got .Messages ), len (want .Messages ))
295
- return false
296
- }
297
- for i , msg := range got .Messages {
298
- wantMsg := want .Messages [i ]
299
- if msg .Role != wantMsg .Role || msg .Content != wantMsg .Content {
300
- t .Errorf ("Message[%d] = %v, want %v" , i , msg , wantMsg )
301
- return false
302
- }
303
- }
304
-
305
- // Compare optional fields
306
- if got .ChatTemplate != want .ChatTemplate {
307
- t .Errorf ("ChatTemplate = %v, want %v" , got .ChatTemplate , want .ChatTemplate )
308
- return false
309
- }
310
- if got .ReturnAssistantTokensMask != want .ReturnAssistantTokensMask {
311
- t .Errorf ("ReturnAssistantTokensMask = %v, want %v" , got .ReturnAssistantTokensMask , want .ReturnAssistantTokensMask )
312
- return false
313
- }
314
- if got .ContinueFinalMessage != want .ContinueFinalMessage {
315
- t .Errorf ("ContinueFinalMessage = %v, want %v" , got .ContinueFinalMessage , want .ContinueFinalMessage )
316
- return false
317
- }
318
- if got .AddGenerationPrompt != want .AddGenerationPrompt {
319
- t .Errorf ("AddGenerationPrompt = %v, want %v" , got .AddGenerationPrompt , want .AddGenerationPrompt )
320
- return false
321
- }
322
-
323
- // Compare tools (shallow comparison for test purposes)
324
- if ! compareSliceAny (got .Tools , want .Tools ) {
325
- t .Errorf ("Tools mismatch" )
326
- return false
327
- }
328
-
329
- // Compare documents (shallow comparison for test purposes)
330
- if ! compareSliceAny (got .Documents , want .Documents ) {
331
- t .Errorf ("Documents mismatch" )
332
- return false
333
- }
334
-
335
- // Compare chat template kwargs (shallow comparison for test purposes)
336
- if ! compareMapAny (got .ChatTemplateKWArgs , want .ChatTemplateKWArgs ) {
337
- t .Errorf ("ChatTemplateKWArgs mismatch" )
338
- return false
339
- }
340
-
341
- return true
342
- }
343
-
344
- func compareSliceAny (got , want []any ) bool {
345
- if len (got ) != len (want ) {
346
- return false
347
- }
348
- // For test purposes, we'll do a simple length check and type check
349
- // In practice, you might want deeper comparison depending on your needs
350
- return true
351
- }
352
-
353
- func compareMapAny (got , want map [string ]any ) bool {
354
- if len (got ) != len (want ) {
355
- return false
356
- }
357
- for k , v := range want {
358
- if gotV , exists := got [k ]; ! exists || gotV != v {
359
- return false
360
- }
361
- }
362
- return true
363
- }
364
-
365
248
// Benchmark tests for performance comparison
366
249
func BenchmarkExtractRequestData_Completions (b * testing.B ) {
367
250
body := map [string ]any {
0 commit comments