@@ -17,16 +17,17 @@ limitations under the License.
1717package handlers
1818
1919import (
20+ "bytes"
2021 "context"
2122 "encoding/json"
22- "fmt"
2323 "strings"
2424
2525 configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2626 extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2727 "sigs.k8s.io/controller-runtime/pkg/log"
2828
2929 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
30+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3031 logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3132)
3233
@@ -36,49 +37,50 @@ const (
3637)
3738
3839// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
39- func (s * StreamingServer ) HandleResponseBody (ctx context.Context , reqCtx * RequestContext , response map [ string ] any ) (* RequestContext , error ) {
40+ func (s * StreamingServer ) HandleResponseBody (ctx context.Context , reqCtx * RequestContext , body [] byte ) (* RequestContext , error ) {
4041 logger := log .FromContext (ctx )
41- responseBytes , err := json . Marshal ( response )
42+ llmResponse , err := types . NewLLMResponseFromBytes ( body )
4243 if err != nil {
43- return reqCtx , fmt .Errorf ("error marshalling responseBody - %w" , err )
44+ logger .Error (err , "failed to create LLMResponse from bytes" )
45+ return reqCtx , err
4446 }
45- if response ["usage" ] != nil {
46- usg := response ["usage" ].(map [string ]any )
47- usage := Usage {
48- PromptTokens : int (usg ["prompt_tokens" ].(float64 )),
49- CompletionTokens : int (usg ["completion_tokens" ].(float64 )),
50- TotalTokens : int (usg ["total_tokens" ].(float64 )),
51- }
47+ reqCtx .SchedulingResponse = llmResponse
48+ if usage := reqCtx .SchedulingResponse .Usage (); usage != nil {
5249 reqCtx .Usage = usage
53- logger .V (logutil .VERBOSE ).Info ("Response generated" , "usage" , reqCtx . Usage )
50+ logger .V (logutil .VERBOSE ).Info ("Response generated" , "usage" , usage )
5451 }
55- reqCtx .ResponseSize = len (responseBytes )
52+ reqCtx .ResponseSize = len (body )
5653 // ResponseComplete is to indicate the response is complete. In non-streaming
5754 // case, it will be set to be true once the response is processed; in
5855 // streaming case, it will be set to be true once the last chunk is processed.
5956 // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178)
6057 // will add the processing for streaming case.
6158 reqCtx .ResponseComplete = true
6259
63- reqCtx .respBodyResp = generateResponseBodyResponses (responseBytes , true )
60+ reqCtx .respBodyResp = generateResponseBodyResponses (body , true )
6461
6562 return s .director .HandleResponseBodyComplete (ctx , reqCtx )
6663}
6764
6865// HandleResponseBodyModelStreaming handles a response body chunk when the model server is streaming.
69- func (s * StreamingServer ) HandleResponseBodyModelStreaming (ctx context.Context , reqCtx * RequestContext , responseText string ) {
66+ func (s * StreamingServer ) HandleResponseBodyModelStreaming (ctx context.Context , reqCtx * RequestContext , streamBody [] byte ) {
7067 logger := log .FromContext (ctx )
7168 _ , err := s .director .HandleResponseBodyStreaming (ctx , reqCtx )
7269 if err != nil {
7370 logger .Error (err , "error in HandleResponseBodyStreaming" )
7471 }
75- if strings .Contains (responseText , streamingEndMsg ) {
72+ if bytes .Contains (streamBody , [] byte ( streamingEndMsg ) ) {
7673 reqCtx .ResponseComplete = true
77- resp := parseRespForUsage (ctx , responseText )
78- reqCtx .Usage = resp .Usage
79- metrics .RecordInputTokens (reqCtx .IncomingModelName , reqCtx .TargetModelName , resp .Usage .PromptTokens )
80- metrics .RecordOutputTokens (reqCtx .IncomingModelName , reqCtx .TargetModelName , resp .Usage .CompletionTokens )
81- _ , err := s .director .HandleResponseBodyComplete (ctx , reqCtx )
74+ resp , err := types .NewLLMResponseFromStream (streamBody )
75+ if err != nil {
76+ logger .Error (err , "error in converting stream response to LLMResponse." )
77+ }
78+ if usage := resp .Usage (); usage != nil {
79+ reqCtx .Usage = usage
80+ metrics .RecordInputTokens (reqCtx .IncomingModelName , reqCtx .TargetModelName , usage .PromptTokens )
81+ metrics .RecordOutputTokens (reqCtx .IncomingModelName , reqCtx .TargetModelName , usage .CompletionTokens )
82+ }
83+ _ , err = s .director .HandleResponseBodyComplete (ctx , reqCtx )
8284 if err != nil {
8385 logger .Error (err , "error in HandleResponseBodyComplete" )
8486 }
0 commit comments