Skip to content

Commit 1af87ef

Browse files
committed
Add streaming response process.
1 parent 7857fa6 commit 1af87ef

File tree

7 files changed

+368
-60
lines changed

7 files changed

+368
-60
lines changed

pkg/epp/handlers/response.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@ limitations under the License.
1717
package handlers
1818

1919
import (
20+
"bytes"
2021
"context"
2122
"encoding/json"
22-
"fmt"
2323
"strings"
2424

2525
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2626
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2727
"sigs.k8s.io/controller-runtime/pkg/log"
2828

2929
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
30+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3031
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3132
)
3233

@@ -36,49 +37,50 @@ const (
3637
)
3738

3839
// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
39-
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
40+
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, body []byte) (*RequestContext, error) {
4041
logger := log.FromContext(ctx)
41-
responseBytes, err := json.Marshal(response)
42+
llmResponse, err := types.NewLLMResponseFromBytes(body)
4243
if err != nil {
43-
return reqCtx, fmt.Errorf("error marshalling responseBody - %w", err)
44+
logger.Error(err, "failed to create LLMResponse from bytes")
45+
return reqCtx, err
4446
}
45-
if response["usage"] != nil {
46-
usg := response["usage"].(map[string]any)
47-
usage := Usage{
48-
PromptTokens: int(usg["prompt_tokens"].(float64)),
49-
CompletionTokens: int(usg["completion_tokens"].(float64)),
50-
TotalTokens: int(usg["total_tokens"].(float64)),
51-
}
47+
reqCtx.SchedulingResponse = llmResponse
48+
if usage := reqCtx.SchedulingResponse.Usage(); usage != nil {
5249
reqCtx.Usage = usage
53-
logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage)
50+
logger.V(logutil.VERBOSE).Info("Response generated", "usage", usage)
5451
}
55-
reqCtx.ResponseSize = len(responseBytes)
52+
reqCtx.ResponseSize = len(body)
5653
// ResponseComplete indicates that the response is complete. In the non-streaming
5754
// case, it is set to true once the response is processed; in the
5855
// streaming case, it is set to true once the last chunk is processed.
5956
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178)
6057
// will add the processing for streaming case.
6158
reqCtx.ResponseComplete = true
6259

63-
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
60+
reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
6461

6562
return s.director.HandleResponseBodyComplete(ctx, reqCtx)
6663
}
6764

6865
// HandleResponseBodyModelStreaming handles the response body when the model server is streaming its response.
69-
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
66+
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, streamBody []byte) {
7067
logger := log.FromContext(ctx)
7168
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
7269
if err != nil {
7370
logger.Error(err, "error in HandleResponseBodyStreaming")
7471
}
75-
if strings.Contains(responseText, streamingEndMsg) {
72+
if bytes.Contains(streamBody, []byte(streamingEndMsg)) {
7673
reqCtx.ResponseComplete = true
77-
resp := parseRespForUsage(ctx, responseText)
78-
reqCtx.Usage = resp.Usage
79-
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
80-
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
81-
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
74+
resp, err := types.NewLLMResponseFromStream(streamBody)
75+
if err != nil {
76+
logger.Error(err, "error in converting stream response to LLMResponse.")
77+
}
78+
if usage := resp.Usage(); usage != nil {
79+
reqCtx.Usage = usage
80+
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens)
81+
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens)
82+
}
83+
_, err = s.director.HandleResponseBodyComplete(ctx, reqCtx)
8284
if err != nil {
8385
logger.Error(err, "error in HandleResponseBodyComplete")
8486
}

pkg/epp/handlers/response_test.go

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ package handlers
1818

1919
import (
2020
"context"
21-
"encoding/json"
2221
"testing"
2322

2423
"github.com/google/go-cmp/cmp"
2524

2625
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2727
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2828
)
2929

@@ -52,12 +52,33 @@ const (
5252
}
5353
`
5454

55-
streamingBodyWithoutUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":null}
56-
`
55+
streamingBodyWithoutUsage = `
56+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
5757
58-
streamingBodyWithUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
59-
data: [DONE]
60-
`
58+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}
59+
60+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}
61+
62+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
63+
64+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":null}
65+
66+
data: [DONE]
67+
`
68+
69+
streamingBodyWithUsage = `
70+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
71+
72+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}
73+
74+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}
75+
76+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
77+
78+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}
79+
80+
data: [DONE]
81+
`
6182
)
6283

6384
type mockDirector struct{}
@@ -88,13 +109,13 @@ func TestHandleResponseBody(t *testing.T) {
88109
name string
89110
body []byte
90111
reqCtx *RequestContext
91-
want Usage
112+
want *types.Usage
92113
wantErr bool
93114
}{
94115
{
95116
name: "success",
96117
body: []byte(body),
97-
want: Usage{
118+
want: &types.Usage{
98119
PromptTokens: 11,
99120
TotalTokens: 111,
100121
CompletionTokens: 100,
@@ -110,12 +131,7 @@ func TestHandleResponseBody(t *testing.T) {
110131
if reqCtx == nil {
111132
reqCtx = &RequestContext{}
112133
}
113-
var responseMap map[string]any
114-
marshalErr := json.Unmarshal(test.body, &responseMap)
115-
if marshalErr != nil {
116-
t.Error(marshalErr, "Error unmarshaling request body")
117-
}
118-
_, err := server.HandleResponseBody(ctx, reqCtx, responseMap)
134+
_, err := server.HandleResponseBody(ctx, reqCtx, test.body)
119135
if err != nil {
120136
if !test.wantErr {
121137
t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr)
@@ -136,7 +152,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
136152
name string
137153
body string
138154
reqCtx *RequestContext
139-
want Usage
155+
want *types.Usage
140156
wantErr bool
141157
}{
142158
{
@@ -155,10 +171,10 @@ func TestHandleStreamedResponseBody(t *testing.T) {
155171
modelServerStreaming: true,
156172
},
157173
wantErr: false,
158-
want: Usage{
159-
PromptTokens: 7,
160-
TotalTokens: 17,
161-
CompletionTokens: 10,
174+
want: &types.Usage{
175+
PromptTokens: 5,
176+
TotalTokens: 12,
177+
CompletionTokens: 7,
162178
},
163179
},
164180
}
@@ -171,7 +187,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
171187
if reqCtx == nil {
172188
reqCtx = &RequestContext{}
173189
}
174-
server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body)
190+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, []byte(test.body))
175191

176192
if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
177193
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)

pkg/epp/handlers/server.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,15 @@ type RequestContext struct {
8585
RequestReceivedTimestamp time.Time
8686
ResponseCompleteTimestamp time.Time
8787
RequestSize int
88-
Usage Usage
88+
Usage *schedulingtypes.Usage
8989
ResponseSize int
9090
ResponseComplete bool
9191
ResponseStatusCode string
9292
RequestRunning bool
9393
Request *Request
9494

95-
SchedulingRequest *schedulingtypes.LLMRequest
95+
SchedulingRequest *schedulingtypes.LLMRequest
96+
SchedulingResponse *schedulingtypes.LLMResponse
9697

9798
RequestState StreamRequestState
9899
modelServerStreaming bool
@@ -115,7 +116,6 @@ type Request struct {
115116
}
116117
type Response struct {
117118
Headers map[string]string
118-
Body []byte
119119
}
120120
type StreamRequestState int
121121

@@ -268,11 +268,10 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
268268
reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx)
269269

270270
case *extProcPb.ProcessingRequest_ResponseBody:
271+
body = append(body, v.ResponseBody.Body...)
271272
if reqCtx.modelServerStreaming {
272273
// Currently we punt on response parsing if the modelServer is streaming, and just pass the body through.
273-
274-
responseText := string(v.ResponseBody.Body)
275-
s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText)
274+
s.HandleResponseBodyModelStreaming(ctx, reqCtx, body)
276275
if v.ResponseBody.EndOfStream {
277276
loggerTrace.Info("stream completed")
278277

@@ -283,8 +282,6 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
283282

284283
reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream)
285284
} else {
286-
body = append(body, v.ResponseBody.Body...)
287-
288285
// Message is buffered, we can read and decode.
289286
if v.ResponseBody.EndOfStream {
290287
loggerTrace.Info("stream completed")
@@ -303,8 +300,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
303300
break
304301
}
305302

306-
reqCtx.Response.Body = body
307-
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
303+
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, body)
308304
if responseErr != nil {
309305
if logger.V(logutil.DEBUG).Enabled() {
310306
logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)

pkg/epp/requestcontrol/director.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,11 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
292292
requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
293293
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
294294
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
295-
llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
296-
if err != nil {
297-
logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
295+
if reqCtx.SchedulingResponse == nil {
296+
err := fmt.Errorf("nil scheduling reponse from reqCtx")
298297
return reqCtx, err
299298
}
300-
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
299+
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, reqCtx.SchedulingResponse, reqCtx.TargetPod)
301300

302301
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
303302
return reqCtx, nil

pkg/epp/requestcontrol/director_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,9 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
690690
},
691691
Response: &handlers.Response{
692692
Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
693-
Body: []byte(chatCompletionJSON),
694693
},
695-
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
694+
SchedulingResponse: wantLLMResponse,
695+
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
696696
}
697697

698698
_, err = director.HandleResponseBodyComplete(ctx, reqCtx)

0 commit comments

Comments
 (0)