From f32557909a1fe43ad69c1f5afbf157579ebed7a9 Mon Sep 17 00:00:00 2001 From: Kellen Swain Date: Tue, 20 May 2025 23:18:29 +0000 Subject: [PATCH] wiring up chunked response logic --- pkg/epp/handlers/request.go | 25 ++++----- pkg/epp/handlers/response.go | 36 +++++++------ pkg/epp/handlers/server.go | 89 +++++++++++++++------------------ pkg/epp/handlers/server_test.go | 7 ++- 4 files changed, 74 insertions(+), 83 deletions(-) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 1df05ce5d..ab93e023a 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -60,23 +60,20 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ return nil } -func (s *StreamingServer) generateRequestBodyResponse(requestBodyBytes []byte) *extProcPb.ProcessingResponse { - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: requestBodyBytes, - EndOfStream: true, - }, - }, - }, +func (s *StreamingServer) generateRequestBodyResponses(requestBodyBytes []byte) []*extProcPb.ProcessingResponse { + commonResponses := buildCommonResponses(requestBodyBytes, bodyByteLimit, true) + responses := []*extProcPb.ProcessingResponse{} + for _, commonResp := range commonResponses { + resp := &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: commonResp, }, }, - }, + } + responses = append(responses, resp) } + return responses } func (s *StreamingServer) generateRequestHeaderResponse(reqCtx *RequestContext) *extProcPb.ProcessingResponse { diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index bbc46c930..7284628cd 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -63,25 +63,7 @@ func (s *StreamingServer) HandleResponseBody( // will add the processing for streaming case. reqCtx.ResponseComplete = true - reqCtx.respBodyResp = &extProcPb.ProcessingResponse{ - // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header - // and as an unstructure ext-proc response metadata key/value pair. This enables different integration - // options for gateway providers. - Response: &extProcPb.ProcessingResponse_ResponseBody{ - ResponseBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: responseBytes, - EndOfStream: true, - }, - }, - }, - }, - }, - }, - } + reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true) return reqCtx, nil } @@ -127,6 +109,22 @@ func (s *StreamingServer) generateResponseHeaderResponse(reqCtx *RequestContext) } } +func generateResponseBodyResponses(responseBodyBytes []byte, setEoS bool) []*extProcPb.ProcessingResponse { + commonResponses := buildCommonResponses(responseBodyBytes, bodyByteLimit, setEoS) + responses := []*extProcPb.ProcessingResponse{} + for _, commonResp := range commonResponses { + resp := &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: commonResp, + }, + }, + } + responses = append(responses, resp) + } + return responses +} + func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*configPb.HeaderValueOption { // can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic. headers := []*configPb.HeaderValueOption{ diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 4b849c8aa..debf23964 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -99,11 +99,11 @@ type RequestContext struct { Response *Response reqHeaderResp *extProcPb.ProcessingResponse - reqBodyResp *extProcPb.ProcessingResponse + reqBodyResp []*extProcPb.ProcessingResponse reqTrailerResp *extProcPb.ProcessingResponse respHeaderResp *extProcPb.ProcessingResponse - respBodyResp *extProcPb.ProcessingResponse + respBodyResp []*extProcPb.ProcessingResponse respTrailerResp *extProcPb.ProcessingResponse } @@ -222,7 +222,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) } reqCtx.RequestSize = len(requestBodyBytes) reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx) - reqCtx.reqBodyResp = s.generateRequestBodyResponse(requestBodyBytes) + reqCtx.reqBodyResp = s.generateRequestBodyResponses(requestBodyBytes) metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel) metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize) @@ -264,22 +264,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize) } - reqCtx.respBodyResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseBody{ - ResponseBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: v.ResponseBody.Body, - EndOfStream: v.ResponseBody.EndOfStream, - }, - }, - }, - }, - }, - }, - } + reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream) } else { body = append(body, v.ResponseBody.Body...) @@ -293,22 +278,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) responseErr = json.Unmarshal(body, &responseBody) if responseErr != nil { logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshaling request body", "body", string(body)) - reqCtx.respBodyResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseBody{ - ResponseBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: body, - EndOfStream: true, - }, - }, - }, - }, - }, - }, - } + reqCtx.respBodyResp = generateResponseBodyResponses(body, true) break } @@ -361,10 +331,13 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces } r.RequestState = HeaderRequestResponseComplete } - if r.RequestState == HeaderRequestResponseComplete && r.reqBodyResp != nil { - loggerTrace.Info("Sending request body response") - if err := srv.Send(r.reqBodyResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) + if r.RequestState == HeaderRequestResponseComplete && r.reqBodyResp != nil && len(r.reqBodyResp) > 0 { + loggerTrace.Info("Sending request body response(s)") + + for _, response := range r.reqBodyResp { + if err := srv.Send(response); err != nil { + return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) + } } r.RequestState = BodyRequestResponsesComplete metrics.IncRunningRequests(r.Model) @@ -385,15 +358,17 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces } r.RequestState = HeaderResponseResponseComplete } - if r.RequestState == HeaderResponseResponseComplete && r.respBodyResp != nil { - loggerTrace.Info("Sending response body response") - if err := srv.Send(r.respBodyResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } + if r.RequestState == HeaderResponseResponseComplete && r.respBodyResp != nil && len(r.respBodyResp) > 0 { + loggerTrace.Info("Sending response body response(s)") + for _, response := range r.respBodyResp { + if err := srv.Send(response); err != nil { + return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) + } - body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody) - if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() { - r.RequestState = BodyResponseResponsesComplete + body := response.Response.(*extProcPb.ProcessingResponse_ResponseBody) + if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() { + r.RequestState = BodyResponseResponsesComplete + } } // Dump the response so a new stream message can begin r.respBodyResp = nil @@ -466,16 +441,31 @@ func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) { return resp, nil } -func buildCommonResponses(bodyBytes []byte, byteLimit int) []*extProcPb.CommonResponse { +func buildCommonResponses(bodyBytes []byte, byteLimit int, setEos bool) []*extProcPb.CommonResponse { responses := []*extProcPb.CommonResponse{} startingIndex := 0 bodyLen := len(bodyBytes) + if bodyLen == 0 { + return []*extProcPb.CommonResponse{ + { + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: bodyBytes, + EndOfStream: setEos, + }, + }, + }, + }, + } + } + for startingIndex < bodyLen { eos := false len := min(bodyLen-startingIndex, byteLimit) chunk := bodyBytes[startingIndex : len+startingIndex] - if len+startingIndex == bodyLen { + if setEos && len+startingIndex >= bodyLen { eos = true } @@ -492,5 +482,6 @@ func buildCommonResponses(bodyBytes []byte, byteLimit int) []*extProcPb.CommonRe responses = append(responses, commonResp) startingIndex += len } + return responses } diff --git a/pkg/epp/handlers/server_test.go b/pkg/epp/handlers/server_test.go index cc99a517b..e836bf510 100644 --- a/pkg/epp/handlers/server_test.go +++ b/pkg/epp/handlers/server_test.go @@ -11,6 +11,11 @@ func TestBuildCommonResponses(t *testing.T) { count int expectedMessageCount int }{ + { + name: "zero case", + count: 0, + expectedMessageCount: 1, + }, { name: "below limit", count: bodyByteLimit - 1000, @@ -40,7 +45,7 @@ func TestBuildCommonResponses(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { arr := generateBytes(test.count) - responses := buildCommonResponses(arr, bodyByteLimit) + responses := buildCommonResponses(arr, bodyByteLimit, true) for i, response := range responses { eos := response.BodyMutation.GetStreamedResponse().GetEndOfStream() if eos == true && i+1 != len(responses) {