1 change: 0 additions & 1 deletion .golangci.yml
@@ -25,7 +25,6 @@ linters:
- makezero
- errcheck
- goconst
- gocyclo
- gofmt
- goimports
- gosimple
2 changes: 1 addition & 1 deletion Makefile
@@ -123,7 +123,7 @@ test: manifests generate fmt vet envtest ## Run tests.

.PHONY: test-integration
test-integration: manifests generate fmt vet envtest ## Run tests.
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration -coverprofile cover.out
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration/epp/... -race -coverprofile cover.out

.PHONY: test-e2e
test-e2e: ## Run end-to-end tests against an existing Kubernetes cluster with at least 3 available GPUs.
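The integration target now runs only the EPP suite under ./test/integration/epp/... and enables the Go race detector, which is relevant because the streaming handler below writes response-body chunks from a separate goroutine. As a reminder of the kind of bug -race surfaces, here is a minimal, repository-unrelated sketch:

package main

import (
	"fmt"
	"sync"
)

func main() {
	counter := 0
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			counter++ // unsynchronized write from two goroutines; `go test -race` reports this
		}()
	}
	wg.Wait()
	fmt.Println(counter)
}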
98 changes: 66 additions & 32 deletions pkg/epp/handlers/streamingserver.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"strconv"
"strings"
"time"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -131,9 +132,14 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
case *extProcPb.ProcessingRequest_ResponseHeaders:
loggerVerbose.Info("got response headers", "headers", v.ResponseHeaders.Headers.GetHeaders())
for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
code := header.RawValue[0]
if header.Key == "status" && string(code) != "200" {
value := string(header.RawValue)
logger.Error(nil, "header", "key", header.Key, "value", value)
if header.Key == "status" && value != "200" {
reqCtx.ResponseStatusCode = errutil.ModelServerError
} else if header.Key == "content-type" && strings.Contains(value, "text/event-stream") {
reqCtx.modelServerStreaming = true
loggerVerbose.Info("model server is streaming response")
logger.Error(nil, "made it here")
}
}
reqCtx.RequestState = ResponseRecieved
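A minimal, self-contained sketch of the content-type check introduced above: any response whose content-type contains "text/event-stream" is treated as a streaming model-server response. The helper name and the case-insensitive key match are my own additions, not the handler's code.

package main

import (
	"fmt"
	"strings"

	configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
)

// isStreamingResponse mirrors the header check in the diff: a response is
// treated as streaming when its content-type contains "text/event-stream".
func isStreamingResponse(headers []*configPb.HeaderValue) bool {
	for _, h := range headers {
		if strings.EqualFold(h.Key, "content-type") &&
			strings.Contains(string(h.RawValue), "text/event-stream") {
			return true
		}
	}
	return false
}

func main() {
	hdrs := []*configPb.HeaderValue{
		{Key: "content-type", RawValue: []byte("text/event-stream; charset=utf-8")},
	}
	fmt.Println(isStreamingResponse(hdrs)) // true
}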
@@ -158,36 +164,57 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
}

case *extProcPb.ProcessingRequest_ResponseBody:
go func() {
_, err := writer.Write(v.ResponseBody.Body)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
}
}()

// Message is buffered, we can read and decode.
if v.ResponseBody.EndOfStream {
err = decoder.Decode(&responseBody)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
if reqCtx.modelServerStreaming {
// Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
reqCtx.respBodyResp = &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseBody{
ResponseBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{
BodyMutation: &extProcPb.BodyMutation{
Mutation: &extProcPb.BodyMutation_StreamedResponse{
StreamedResponse: &extProcPb.StreamedBodyResponse{
Body: v.ResponseBody.Body,
EndOfStream: v.ResponseBody.EndOfStream,
},
},
},
},
},
},
}
// Body stream complete. Close the reader pipe.
reader.Close()

reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
if err == nil && reqCtx.ResponseComplete {
reqCtx.ResponseCompleteTimestamp = time.Now()
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
} else {
go func() {
_, err := writer.Write(v.ResponseBody.Body)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
}
}()

// Message is buffered, we can read and decode.
if v.ResponseBody.EndOfStream {
err = decoder.Decode(&responseBody)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
}
// Body stream complete. Close the reader pipe.
reader.Close()

reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
if err == nil && reqCtx.ResponseComplete {
reqCtx.ResponseCompleteTimestamp = time.Now()
metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
}
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
}
loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
}
case *extProcPb.ProcessingRequest_ResponseTrailers:
// This is currently unused.
}

// Handle the err and fire an immediate response.
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
resp, err := BuildErrResponse(err)
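For the non-streaming path above, body chunks are written into an io.Pipe from a goroutine while a json.Decoder drains the other end once EndOfStream arrives. A minimal sketch of that buffering pattern, with illustrative data rather than the handler's real types:

package main

import (
	"encoding/json"
	"fmt"
	"io"
)

func main() {
	reader, writer := io.Pipe()
	decoder := json.NewDecoder(reader)

	// Two chunks of one JSON document, standing in for streamed body frames.
	chunks := [][]byte{
		[]byte(`{"usage":{"prompt_tokens":3,`),
		[]byte(`"completion_tokens":7}}`),
	}

	go func() {
		defer writer.Close()
		for _, c := range chunks {
			if _, err := writer.Write(c); err != nil {
				return
			}
		}
	}()

	// Decode blocks until the pipe delivers a complete JSON value.
	var body map[string]interface{}
	if err := decoder.Decode(&body); err != nil {
		fmt.Println("error unmarshaling response body:", err)
		return
	}
	fmt.Println(body["usage"])
}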
@@ -246,7 +273,11 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
if err := srv.Send(r.respBodyResp); err != nil {
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
}
r.RequestState = BodyResponseResponsesComplete

body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody)
if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() {
r.RequestState = BodyResponseResponsesComplete
}
// Dump the response so a new stream message can begin
r.reqBodyResp = nil
}
@@ -273,6 +304,8 @@ type StreamingRequestContext struct {
ResponseComplete bool
ResponseStatusCode string

modelServerStreaming bool

reqHeaderResp *extProcPb.ProcessingResponse
reqBodyResp *extProcPb.ProcessingResponse
reqTrailerResp *extProcPb.ProcessingResponse
@@ -339,14 +372,15 @@ func (s *StreamingServer) HandleRequestBody(
// Update target models in the body.
if llmReq.Model != llmReq.ResolvedTargetModel {
requestBodyMap["model"] = llmReq.ResolvedTargetModel
requestBodyBytes, err = json.Marshal(requestBodyMap)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
}
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
}

requestBodyBytes, err = json.Marshal(requestBodyMap)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
}
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))

target, err := s.scheduler.Schedule(ctx, llmReq)
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
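In HandleRequestBody, the re-marshal of the request map now happens unconditionally rather than only when the model name is rewritten. A minimal sketch of that rewrite step, under my own function and variable names:

package main

import (
	"encoding/json"
	"fmt"
)

// rewriteModel decodes a JSON request body, overwrites the "model" field with
// the resolved target model, and re-serializes the map.
func rewriteModel(body []byte, resolvedModel string) ([]byte, error) {
	var m map[string]interface{}
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, fmt.Errorf("error unmarshaling request body: %w", err)
	}
	m["model"] = resolvedModel
	// Marshal unconditionally, mirroring the diff, so the forwarded body is
	// always the re-serialized map even when the model name is unchanged.
	return json.Marshal(m)
}

func main() {
	out, err := rewriteModel([]byte(`{"model":"my-model","prompt":"hi"}`), "my-model-v2")
	if err != nil {
		fmt.Println("rewrite failed:", err)
		return
	}
	fmt.Println(string(out))
}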
11 changes: 9 additions & 2 deletions pkg/epp/server/controller_manager.go
@@ -28,6 +28,7 @@ import (
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/manager"
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
)

@@ -40,7 +41,7 @@ func init() {

// NewDefaultManager creates a new controller manager with default configuration.
func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Manager, error) {
manager, err := ctrl.NewManager(restConfig, ctrl.Options{
defaultOpts := ctrl.Options{
Scheme: scheme,
Cache: cache.Options{
ByObject: map[client.Object]cache.ByObject{
@@ -65,7 +66,13 @@ func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Ma
},
},
},
})
}
return NewManagerWithOptions(restConfig, defaultOpts)
}

// NewManagerWithOptions creates a new controller manager with injectable options.
func NewManagerWithOptions(restConfig *rest.Config, opts manager.Options) (ctrl.Manager, error) {
manager, err := ctrl.NewManager(restConfig, opts)
if err != nil {
return nil, fmt.Errorf("failed to create controller manager: %v", err)
}
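NewManagerWithOptions lets callers, such as tests, inject their own controller-runtime options instead of the defaults baked into NewDefaultManager. A hypothetical usage sketch; the package alias and the empty options struct are assumptions, not from the diff:

package main

import (
	"fmt"

	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/manager"

	eppserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
)

func main() {
	// Reads kubeconfig or in-cluster config; exits if neither is available.
	restConfig := ctrl.GetConfigOrDie()

	// An empty manager.Options falls back to controller-runtime defaults
	// rather than the EPP-specific cache configuration.
	mgr, err := eppserver.NewManagerWithOptions(restConfig, manager.Options{})
	if err != nil {
		fmt.Println("failed to create controller manager:", err)
		return
	}
	_ = mgr
}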
24 changes: 23 additions & 1 deletion pkg/epp/util/testing/request.go
@@ -19,6 +19,7 @@ package testing
import (
"encoding/json"

envoyCorev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/go-logr/logr"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -38,8 +39,29 @@ func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.Proces
}
req := &extProcPb.ProcessingRequest{
Request: &extProcPb.ProcessingRequest_RequestBody{
RequestBody: &extProcPb.HttpBody{Body: llmReq},
RequestBody: &extProcPb.HttpBody{Body: llmReq, EndOfStream: true},
},
}
return req
}

func GenerateStreamedRequestSet(logger logr.Logger, prompt, model string) []*extProcPb.ProcessingRequest {
requests := []*extProcPb.ProcessingRequest{}
headerReq := &extProcPb.ProcessingRequest{
Request: &extProcPb.ProcessingRequest_RequestHeaders{
RequestHeaders: &extProcPb.HttpHeaders{
Headers: &envoyCorev3.HeaderMap{
Headers: []*envoyCorev3.HeaderValue{
{
Key: "hi",
Value: "mom",
},
},
},
},
},
}
requests = append(requests, headerReq)
requests = append(requests, GenerateRequest(logger, prompt, model))
return requests
}
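A hypothetical usage sketch for the new helper: it returns a request-headers message followed by the body message from GenerateRequest, which now carries EndOfStream. The package alias and the discarded logger are my choices, not part of the diff:

package main

import (
	"fmt"

	"github.com/go-logr/logr"

	utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
)

func main() {
	reqs := utiltesting.GenerateStreamedRequestSet(logr.Discard(), "hello", "my-model")
	// Expect two messages: request headers, then a request body with EndOfStream=true.
	fmt.Println(len(reqs))
}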