Skip to content

Commit 2623885

Browse files
Break PostResponse requestcontrol plugin into 3 separate plugins to add streamed request functionality (kubernetes-sigs#1661)
* Break out PostResponse plugin into 3 constituent plugins for request recieved, streaming, and complete * Fix typo in variable names * Log typed name in director.go and remove redundant director nil check in response.go * Renamed the post response plugins to not include the word post. * Fix function comment and pass existing logger into HandleResponseBodyStreaming * Update pkg/epp/requestcontrol/plugins.go Co-authored-by: Nir Rozenbaum <[email protected]> * Update pkg/epp/requestcontrol/request_control_config.go Co-authored-by: Nir Rozenbaum <[email protected]> * Update pkg/epp/requestcontrol/director.go Co-authored-by: Nir Rozenbaum <[email protected]> * Fix comments andlogs, simplify Director defintion to take in config * Revert logging parameter addition, keeping consistent with existing format for plugins --------- Co-authored-by: Nir Rozenbaum <[email protected]>
1 parent ec3194b commit 2623885

File tree

9 files changed

+279
-119
lines changed

9 files changed

+279
-119
lines changed

pkg/epp/handlers/response.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,28 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
6363
reqCtx.ResponseComplete = true
6464

6565
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true, reqCtx, logger)
66-
return reqCtx, nil
66+
67+
return s.director.HandleResponseBodyComplete(ctx, reqCtx)
6768
}
6869

6970
// The function is to handle streaming response if the modelServer is streaming.
7071
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
72+
logger := log.FromContext(ctx)
73+
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
74+
if err != nil {
75+
logger.Error(err, "error in HandleResponseBodyStreaming")
76+
}
7177
if strings.Contains(responseText, streamingEndMsg) {
7278
reqCtx.ResponseComplete = true
7379
resp := parseRespForUsage(ctx, responseText)
7480
reqCtx.Usage = resp.Usage
7581
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
7682
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
77-
if s.director != nil {
78-
s.director.HandleResponseBodyComplete(ctx, reqCtx)
83+
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
84+
if err != nil {
85+
logger.Error(err, "error in HandleResponseBodyComplete")
7986
}
8087
}
81-
if s.director != nil {
82-
s.director.HandleResponseBodyChunk(ctx, reqCtx)
83-
}
8488
}
8589

8690
func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *RequestContext, resp *extProcPb.ProcessingRequest_ResponseHeaders) (*RequestContext, error) {
@@ -92,7 +96,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
9296
}
9397
}
9498

95-
reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
99+
reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx)
96100

97101
return reqCtx, err
98102
}

pkg/epp/handlers/response_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/google/go-cmp/cmp"
2525

26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2627
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2728
)
2829

@@ -59,6 +60,27 @@ data: [DONE]
5960
`
6061
)
6162

63+
type mockDirector struct{}
64+
65+
func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
66+
return reqCtx, nil
67+
}
68+
func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
69+
return reqCtx, nil
70+
}
71+
func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
72+
return reqCtx, nil
73+
}
74+
func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
75+
return reqCtx, nil
76+
}
77+
func (m *mockDirector) GetRandomPod() *backend.Pod {
78+
return &backend.Pod{}
79+
}
80+
func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
81+
return reqCtx, nil
82+
}
83+
6284
func TestHandleResponseBody(t *testing.T) {
6385
ctx := logutil.NewTestLoggerIntoContext(context.Background())
6486

@@ -83,6 +105,7 @@ func TestHandleResponseBody(t *testing.T) {
83105
for _, test := range tests {
84106
t.Run(test.name, func(t *testing.T) {
85107
server := &StreamingServer{}
108+
server.director = &mockDirector{}
86109
reqCtx := test.reqCtx
87110
if reqCtx == nil {
88111
reqCtx = &RequestContext{}
@@ -143,6 +166,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
143166
for _, test := range tests {
144167
t.Run(test.name, func(t *testing.T) {
145168
server := &StreamingServer{}
169+
server.director = &mockDirector{}
146170
reqCtx := test.reqCtx
147171
if reqCtx == nil {
148172
reqCtx = &RequestContext{}

pkg/epp/handlers/server.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer
5555

5656
type Director interface {
5757
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
58-
HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
59-
HandleResponseBodyChunk(ctx context.Context, reqCtx *RequestContext) error
60-
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) error
58+
HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
59+
HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
60+
HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
6161
GetRandomPod() *backend.Pod
6262
}
6363

@@ -138,7 +138,7 @@ const (
138138
HeaderRequestResponseComplete StreamRequestState = 1
139139
BodyRequestResponsesComplete StreamRequestState = 2
140140
TrailerRequestResponsesComplete StreamRequestState = 3
141-
ResponseRecieved StreamRequestState = 4
141+
ResponseReceived StreamRequestState = 4
142142
HeaderResponseResponseComplete StreamRequestState = 5
143143
BodyResponseResponsesComplete StreamRequestState = 6
144144
TrailerResponseResponsesComplete StreamRequestState = 7
@@ -269,7 +269,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
269269
loggerTrace.Info("model server is streaming response")
270270
}
271271
}
272-
reqCtx.RequestState = ResponseRecieved
272+
reqCtx.RequestState = ResponseReceived
273273

274274
var responseErr error
275275
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
@@ -396,7 +396,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
396396
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
397397
}
398398
}
399-
if r.RequestState == ResponseRecieved && r.respHeaderResp != nil {
399+
if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
400400
loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
401401
if err := srv.Send(r.respHeaderResp); err != nil {
402402
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)

pkg/epp/requestcontrol/director.go

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,11 @@ func NewDirectorWithConfig(
156156
config *Config,
157157
) *Director {
158158
return &Director{
159-
datastore: datastore,
160-
scheduler: scheduler,
161-
admissionController: admissionController,
162-
preRequestPlugins: config.preRequestPlugins,
163-
postResponsePlugins: config.postResponsePlugins,
164-
postResponseChunkPlugins: config.postResponseChunkPlugins,
165-
postResponseCompletePlugins: config.postResponseCompletePlugins,
166-
defaultPriority: 0, // define default priority explicitly
159+
datastore: datastore,
160+
scheduler: scheduler,
161+
admissionController: admissionController,
162+
requestControlPlugins: *config,
163+
defaultPriority: 0, // define default priority explicitly
167164
}
168165
}
169166

@@ -177,13 +174,10 @@ func NewDirectorWithConfig(
177174
// - Preparing the request context for the Envoy ext_proc filter to route the request.
178175
// - Running PostResponse plugins.
179176
type Director struct {
180-
datastore Datastore
181-
scheduler Scheduler
182-
admissionController AdmissionController
183-
preRequestPlugins []PreRequest
184-
postResponsePlugins []PostResponse
185-
postResponseChunkPlugins []PostResponseChunk
186-
postResponseCompletePlugins []PostResponseComplete
177+
datastore Datastore
178+
scheduler Scheduler
179+
admissionController AdmissionController
180+
requestControlPlugins Config
187181
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
188182
// no need to set this in the constructor, since the value we want is the default int val
189183
// and value types cannot be nil
@@ -391,36 +385,47 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
391385
return pm
392386
}
393387

394-
// HandleResponseHeaders is called when the first chunk of the response arrives.
395-
func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
396-
logger := log.FromContext(ctx).WithValues("stage", "headers")
397-
logger.V(logutil.DEBUG).Info("Entering HandleResponseHeaders")
388+
// HandleResponseReceived is called when the response headers are received.
389+
func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
390+
response := &Response{
391+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
392+
Headers: reqCtx.Response.Headers,
393+
}
398394

399-
d.runPostResponsePlugins(ctx, reqCtx)
395+
// TODO: to extend fallback functionality, handle cases where target pod is unavailable
396+
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
397+
d.runResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
400398

401-
logger.V(logutil.DEBUG).Info("Exiting HandleResponseHeaders")
402399
return reqCtx, nil
403400
}
404401

405-
func (d *Director) HandleResponseBodyChunk(ctx context.Context, reqCtx *handlers.RequestContext) error {
402+
// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
403+
func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
406404
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
407405
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
406+
response := &Response{
407+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
408+
Headers: reqCtx.Response.Headers,
409+
}
408410

409-
d.runPostResponseChunkPlugins(ctx, reqCtx)
411+
d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
410412
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
411-
return nil
413+
return reqCtx, nil
412414
}
413415

414416
// HandleResponseBodyComplete is called when the response body is fully received.
415-
// It runs the PostResponseComplete plugins.
416-
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) error {
417+
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
417418
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
418419
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
420+
response := &Response{
421+
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
422+
Headers: reqCtx.Response.Headers,
423+
}
419424

420-
d.runPostResponseCompletePlugins(ctx, reqCtx)
425+
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
421426

422427
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
423-
return nil
428+
return reqCtx, nil
424429
}
425430

426431
func (d *Director) GetRandomPod() *backend.Pod {
@@ -436,43 +441,44 @@ func (d *Director) GetRandomPod() *backend.Pod {
436441
func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest,
437442
schedulingResult *schedulingtypes.SchedulingResult, targetPort int) {
438443
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
439-
for _, plugin := range d.preRequestPlugins {
440-
loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName())
444+
for _, plugin := range d.requestControlPlugins.preRequestPlugins {
445+
loggerDebug.Info("Running PreRequest plugin", "plugin", plugin.TypedName())
441446
before := time.Now()
442447
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
443448
metrics.RecordPluginProcessingLatency(PreRequestExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
444-
loggerDebug.Info("Completed running pre-request plugin successfully", "plugin", plugin.TypedName())
449+
loggerDebug.Info("Completed running PreRequest plugin successfully", "plugin", plugin.TypedName())
445450
}
446451
}
447452

448-
func (d *Director) runPostResponsePlugins(ctx context.Context, reqCtx *handlers.RequestContext) {
453+
func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
449454
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
450-
for _, plugin := range d.postResponsePlugins {
451-
loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
455+
for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
456+
loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName())
452457
before := time.Now()
453-
plugin.PostResponse(ctx, reqCtx)
454-
metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
455-
loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
458+
plugin.ResponseReceived(ctx, request, response, targetPod)
459+
metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
460+
loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName())
456461
}
457462
}
458463

459-
func (d *Director) runPostResponseChunkPlugins(ctx context.Context, reqCtx *handlers.RequestContext) {
464+
func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
460465
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
461-
for _, plugin := range d.postResponseChunkPlugins {
462-
loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName().Type)
466+
for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
467+
loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName())
463468
before := time.Now()
464-
plugin.PostResponseChunk(ctx, reqCtx)
465-
metrics.RecordPluginProcessingLatency(PostResponseChunkExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
469+
plugin.ResponseStreaming(ctx, request, response, targetPod)
470+
metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
471+
loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName())
466472
}
467473
}
468474

469-
func (d *Director) runPostResponseCompletePlugins(ctx context.Context, reqCtx *handlers.RequestContext) {
475+
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
470476
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
471-
for _, plugin := range d.postResponseCompletePlugins {
472-
loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName().Type)
477+
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
478+
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
473479
before := time.Now()
474-
plugin.PostResponseComplete(ctx, reqCtx)
475-
metrics.RecordPluginProcessingLatency(PostResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
476-
loggerDebug.Info("Completed running post-response complete plugin successfully", "plugin", plugin.TypedName())
480+
plugin.ResponseComplete(ctx, request, response, targetPod)
481+
metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
482+
loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName())
477483
}
478484
}

0 commit comments

Comments
 (0)