diff --git a/.golangci.yml b/.golangci.yml
index 2ad3b93da..d1b1e112a 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -25,7 +25,6 @@ linters:
     - makezero
     - errcheck
     - goconst
-    - gocyclo
     - gofmt
     - goimports
     - gosimple
diff --git a/Makefile b/Makefile
index 257d2cbb9..40cb0b751 100644
--- a/Makefile
+++ b/Makefile
@@ -123,7 +123,7 @@ test: manifests generate fmt vet envtest ## Run tests.
 .PHONY: test-integration
 test-integration: manifests generate fmt vet envtest ## Run tests.
-	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration -coverprofile cover.out
+	KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration/epp/... -race -coverprofile cover.out

 .PHONY: test-e2e
 test-e2e: ## Run end-to-end tests against an existing Kubernetes cluster with at least 3 available GPUs.
diff --git a/pkg/epp/handlers/streamingserver.go b/pkg/epp/handlers/streamingserver.go
index c8de7bb73..2aaca7f35 100644
--- a/pkg/epp/handlers/streamingserver.go
+++ b/pkg/epp/handlers/streamingserver.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"strconv"
+	"strings"
 	"time"

 	configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -131,9 +132,14 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 	case *extProcPb.ProcessingRequest_ResponseHeaders:
 		loggerVerbose.Info("got response headers", "headers", v.ResponseHeaders.Headers.GetHeaders())
 		for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
-			code := header.RawValue[0]
-			if header.Key == "status" && string(code) != "200" {
+			value := string(header.RawValue)
+			logger.Error(nil, "header", "key", header.Key, "value", value)
+			if header.Key == "status" && value != "200" {
 				reqCtx.ResponseStatusCode = errutil.ModelServerError
+			} else if header.Key == "content-type" && strings.Contains(value, "text/event-stream") {
+				reqCtx.modelServerStreaming = true
+				loggerVerbose.Info("model server is streaming response")
+				logger.Error(nil, "made it here")
 			}
 		}
 		reqCtx.RequestState = ResponseRecieved
@@ -158,36 +164,57 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 		}

 	case *extProcPb.ProcessingRequest_ResponseBody:
-		go func() {
-			_, err := writer.Write(v.ResponseBody.Body)
-			if err != nil {
-				logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
-			}
-		}()
-
-		// Message is buffered, we can read and decode.
-		if v.ResponseBody.EndOfStream {
-			err = decoder.Decode(&responseBody)
-			if err != nil {
-				logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
+		if reqCtx.modelServerStreaming {
+			// Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
+			reqCtx.respBodyResp = &extProcPb.ProcessingResponse{
+				Response: &extProcPb.ProcessingResponse_ResponseBody{
+					ResponseBody: &extProcPb.BodyResponse{
+						Response: &extProcPb.CommonResponse{
+							BodyMutation: &extProcPb.BodyMutation{
+								Mutation: &extProcPb.BodyMutation_StreamedResponse{
+									StreamedResponse: &extProcPb.StreamedBodyResponse{
+										Body: v.ResponseBody.Body,
+										EndOfStream: v.ResponseBody.EndOfStream,
+									},
+								},
+							},
+						},
+					},
+				},
 			}
-			// Body stream complete. Close the reader pipe.
-			reader.Close()
-
-			reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
-			if err == nil && reqCtx.ResponseComplete {
-				reqCtx.ResponseCompleteTimestamp = time.Now()
-				metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
-				metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
-				metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
-				metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
+		} else {
+			go func() {
+				_, err := writer.Write(v.ResponseBody.Body)
+				if err != nil {
+					logger.V(logutil.DEFAULT).Error(err, "Error populating writer")
+				}
+			}()
+
+			// Message is buffered, we can read and decode.
+			if v.ResponseBody.EndOfStream {
+				err = decoder.Decode(&responseBody)
+				if err != nil {
+					logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
+				}
+				// Body stream complete. Close the reader pipe.
+				reader.Close()
+
+				reqCtx, err = s.HandleResponseBody(ctx, reqCtx, responseBody)
+				if err == nil && reqCtx.ResponseComplete {
+					reqCtx.ResponseCompleteTimestamp = time.Now()
+					metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
+					metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
+					metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
+					metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
+				}
+				loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
 			}
-			loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
 		}
 	case *extProcPb.ProcessingRequest_ResponseTrailers:
 		// This is currently unused.
 	}
+	// Handle the err and fire an immediate response.
 	if err != nil {
 		logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
 		resp, err := BuildErrResponse(err)
@@ -246,7 +273,11 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
 		if err := srv.Send(r.respBodyResp); err != nil {
 			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
 		}
-		r.RequestState = BodyResponseResponsesComplete
+
+		body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody)
+		if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() {
+			r.RequestState = BodyResponseResponsesComplete
+		}
 		// Dump the response so a new stream message can begin
 		r.reqBodyResp = nil
 	}
@@ -273,6 +304,8 @@ type StreamingRequestContext struct {
 	ResponseComplete bool
 	ResponseStatusCode string

+	modelServerStreaming bool
+
 	reqHeaderResp *extProcPb.ProcessingResponse
 	reqBodyResp *extProcPb.ProcessingResponse
 	reqTrailerResp *extProcPb.ProcessingResponse
@@ -339,14 +372,15 @@ func (s *StreamingServer) HandleRequestBody(
 	// Update target models in the body.
 	if llmReq.Model != llmReq.ResolvedTargetModel {
 		requestBodyMap["model"] = llmReq.ResolvedTargetModel
-		requestBodyBytes, err = json.Marshal(requestBodyMap)
-		if err != nil {
-			logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
-			return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
-		}
-		loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
 	}
+	requestBodyBytes, err = json.Marshal(requestBodyMap)
+	if err != nil {
+		logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
+		return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
+	}
+	loggerVerbose.Info("Updated request body marshalled", "body", string(requestBodyBytes))
+
 	target, err := s.scheduler.Schedule(ctx, llmReq)
 	if err != nil {
 		return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
diff --git a/pkg/epp/server/controller_manager.go b/pkg/epp/server/controller_manager.go
index fd505d002..46694f7b9 100644
--- a/pkg/epp/server/controller_manager.go
+++ b/pkg/epp/server/controller_manager.go
@@ -28,6 +28,7 @@ import (
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/cache"
 	"sigs.k8s.io/controller-runtime/pkg/client"
+	"sigs.k8s.io/controller-runtime/pkg/manager"
 	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
 )

@@ -40,7 +41,7 @@ func init() {
 // NewDefaultManager creates a new controller manager with default configuration.
 func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Manager, error) {
-	manager, err := ctrl.NewManager(restConfig, ctrl.Options{
+	defaultOpts := ctrl.Options{
 		Scheme: scheme,
 		Cache: cache.Options{
 			ByObject: map[client.Object]cache.ByObject{
@@ -65,7 +66,13 @@ func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Ma
 			},
 		},
 	},
-	})
+	}
+	return NewManagerWithOptions(restConfig, defaultOpts)
+}
+
+// NewManagerWithOptions creates a new controller manager with injectable options.
+func NewManagerWithOptions(restConfig *rest.Config, opts manager.Options) (ctrl.Manager, error) { + manager, err := ctrl.NewManager(restConfig, opts) if err != nil { return nil, fmt.Errorf("failed to create controller manager: %v", err) } diff --git a/pkg/epp/util/testing/request.go b/pkg/epp/util/testing/request.go index fe9a0d089..30772ad54 100644 --- a/pkg/epp/util/testing/request.go +++ b/pkg/epp/util/testing/request.go @@ -19,6 +19,7 @@ package testing import ( "encoding/json" + envoyCorev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/go-logr/logr" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -38,8 +39,29 @@ func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.Proces } req := &extProcPb.ProcessingRequest{ Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: llmReq}, + RequestBody: &extProcPb.HttpBody{Body: llmReq, EndOfStream: true}, }, } return req } + +func GenerateStreamedRequestSet(logger logr.Logger, prompt, model string) []*extProcPb.ProcessingRequest { + requests := []*extProcPb.ProcessingRequest{} + headerReq := &extProcPb.ProcessingRequest{ + Request: &extProcPb.ProcessingRequest_RequestHeaders{ + RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &envoyCorev3.HeaderMap{ + Headers: []*envoyCorev3.HeaderValue{ + { + Key: "hi", + Value: "mom", + }, + }, + }, + }, + }, + } + requests = append(requests, headerReq) + requests = append(requests, GenerateRequest(logger, prompt, model)) + return requests +} diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index c5e7c10a3..7dc9bdb8f 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -43,6 +43,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -51,7 +53,10 @@ import ( "k8s.io/component-base/metrics/legacyregistry" metricsutils "k8s.io/component-base/metrics/testutil" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" k8sclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/config" "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" @@ -78,6 +83,13 @@ var ( logger = logutil.NewTestLogger().V(logutil.VERBOSE) ) +func TestMain(m *testing.M) { + cleanup := BeforeSuite() + code := m.Run() + cleanup() + os.Exit(code) +} + func TestKubeInferenceModelRequest(t *testing.T) { tests := []struct { name string @@ -196,57 +208,814 @@ func TestKubeInferenceModelRequest(t *testing.T) { WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + "foo": 1, + "bar": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 200, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 6, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + }, + }, + }, + wantHeaders: []*configPb.HeaderValueOption{ + { + Header: 
&configPb.HeaderValue{ + Key: runserver.DefaultDestinationEndpointHintKey, + RawValue: []byte("192.168.1.3:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte("76"), + }, + }, + }, + wantMetadata: makeMetadata("192.168.1.3:8000"), + wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test3\",\"temperature\":0}"), + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. + # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 + `, + wantErr: false, + }, + { + name: "noncritical and all models past threshold, shed request", + req: utiltesting.GenerateRequest(logger, "test4", "sql-lora-sheddable"), + // no pods will be picked as all models are either above kv threshold, + // queue threshold, or both. + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 6, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + }, + wantHeaders: []*configPb.HeaderValueOption{}, + wantMetadata: &structpb.Struct{}, + wantBody: []byte(""), + wantErr: false, + immediateResponse: &extProcPb.ImmediateResponse{ + Status: &envoyTypePb.HttpStatus{ + Code: envoyTypePb.StatusCode_TooManyRequests, + }, + }, + wantMetrics: "", + }, + { + name: "noncritical, but one server has capacity, do not shed", + req: utiltesting.GenerateRequest(logger, "test5", "sql-lora-sheddable"), + // pod 0 will be picked as all other models are above threshold + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 4, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + }, + wantHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: runserver.DefaultDestinationEndpointHintKey, + RawValue: []byte("192.168.1.1:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte("76"), + }, + }, + }, + wantMetadata: makeMetadata("192.168.1.1:8000"), + wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"test5\",\"temperature\":0}"), + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 + `, + wantErr: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, cleanup := setUpHermeticServer(t, test.pods, false) + t.Cleanup(cleanup) + want := &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: test.wantHeaders, + }, + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_Body{ + Body: test.wantBody, + }, + }, + }, + }, + }, + DynamicMetadata: test.wantMetadata, + } + res, err := sendRequest(t, client, test.req) + + if err != nil && !test.wantErr { + t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) + } + if test.immediateResponse != nil { + want = &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: test.immediateResponse, + }, + } + } + if diff := cmp.Diff(want, res, protocmp.Transform()); diff != "" { + t.Errorf("Unexpected response, (-want +got): %v", diff) + } + + if test.wantMetrics != "" { + if err := metricsutils.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(test.wantMetrics), "inference_model_request_total"); err != nil { + t.Error(err) + } + } + + legacyregistry.Reset() + }) + } +} + +func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { + tests := []struct { + name string + requests []*extProcPb.ProcessingRequest + pods map[backendmetrics.Pod]*backendmetrics.Metrics + wantResponses []*extProcPb.ProcessingResponse + wantMetrics string + wantErr bool + immediateResponse *extProcPb.ImmediateResponse + }{ + // Request flow tests + { + name: "select lower queue and kv cache, no active lora", + requests: utiltesting.GenerateStreamedRequestSet(logger, "test1", "my-model"), + // pod-1 will be picked because it has relatively low queue size and low KV cache. + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.2, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.1, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + }, + }, + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="my-model",target_model_name="my-model-12345"} 1 + `, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-gateway-destination-endpoint", + RawValue: []byte("192.168.1.2:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(76)), + }, + }, + }}, + }, + }, + }, + DynamicMetadata: makeMetadata("192.168.1.2:8000"), + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"my-model-12345\",\"prompt\":\"test1\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "select active lora, low queue", + requests: utiltesting.GenerateStreamedRequestSet(logger, "test2", "sql-lora"), + // pod-1 will be picked because it has relatively low queue size, with the requested + // model being active, and has low KV cache. + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + }, + }, + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 + `, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-gateway-destination-endpoint", + RawValue: []byte("192.168.1.2:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(76)), + }, + }, + }}, + }, + }, + }, + DynamicMetadata: makeMetadata("192.168.1.2:8000"), + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test2\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "select no lora despite active model, avoid excessive queue size", + requests: utiltesting.GenerateStreamedRequestSet(logger, "test3", "sql-lora"), + // pod-2 will be picked despite it NOT having the requested model being active + // as it's above the affinity for queue size. Also is critical, so we should + // still honor request despite all queues > 5 + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 200, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 6, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + }, + }, + }, + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 + `, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-gateway-destination-endpoint", + RawValue: []byte("192.168.1.3:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(76)), + }, + }, + }}, + }, + }, + }, + DynamicMetadata: makeMetadata("192.168.1.3:8000"), + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test3\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "noncritical and all models past threshold, shed request", + requests: utiltesting.GenerateStreamedRequestSet(logger, "test4", "sql-lora-sheddable"), + // no pods will be picked as all models are either above kv threshold, + // queue threshold, or both. + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 6, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + }, + wantErr: false, + wantMetrics: "", + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &extProcPb.ImmediateResponse{ + Status: &envoyTypePb.HttpStatus{ + Code: envoyTypePb.StatusCode_TooManyRequests, + }, + }, + }, + }, + }, + }, + { + name: "noncritical, but one server has capacity, do not shed", + requests: utiltesting.GenerateStreamedRequestSet(logger, "test5", "sql-lora-sheddable"), + // pod 0 will be picked as all other models are above threshold + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 4, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + }, + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 + `, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-gateway-destination-endpoint", + RawValue: []byte("192.168.1.1:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(76)), + }, + }, + }}, + }, + }, + }, + DynamicMetadata: makeMetadata("192.168.1.1:8000"), + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"test5\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "body sent over multiple requests, noncritical, but one server has capacity, do not shed", + requests: []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_RequestHeaders{ + RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{ + Headers: []*configPb.HeaderValue{ + { + Key: "hi", + Value: "mom", + }, + }, + }, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lo"), EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: []byte("ra-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true}, + }, + }, + }, + + // + // pod 0 will be picked as all other models are above threshold + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 4, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(1): { + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + fakePod(2): { + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, + }, + }, + }, + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 + `, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-gateway-destination-endpoint", + RawValue: []byte("192.168.1.1:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(76)), + }, + }, + }}, + }, + }, + }, + DynamicMetadata: makeMetadata("192.168.1.1:8000"), + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"test6\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "inferencemodel's modelName is not translated, passthrough", + requests: []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_RequestHeaders{ + RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{ + Headers: []*configPb.HeaderValue{ + { + Key: "hi", + Value: "mom", + }, + }, + }, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"direct-"), EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: []byte("model\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true}, + }, + }, + }, + + // + // pod 0 will be picked as all other models are above threshold + pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + fakePod(0): { + WaitingQueueSize: 4, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, }, }, fakePod(1): { - WaitingQueueSize: 200, - KVCacheUsagePercent: 0.1, + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, ActiveModels: map[string]int{ "foo": 1, - "sql-lora-1fdg2": 1, + "sql-lora-1fdg3": 1, }, }, fakePod(2): { - WaitingQueueSize: 6, - KVCacheUsagePercent: 0.2, + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, ActiveModels: map[string]int{ - "foo": 1, + "foo": 1, + "sql-lora-1fdg3": 1, }, }, }, - wantHeaders: []*configPb.HeaderValueOption{ + wantMetrics: ` + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
+ # TYPE inference_model_request_total counter + inference_model_request_total{model_name="direct-model",target_model_name="direct-model"} 1 + `, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ { - Header: &configPb.HeaderValue{ - Key: runserver.DefaultDestinationEndpointHintKey, - RawValue: []byte("192.168.1.3:8000"), + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-gateway-destination-endpoint", + RawValue: []byte("192.168.1.2:8000"), + }, + }, + { + Header: &configPb.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(strconv.Itoa(74)), + }, + }, + }}, + }, + }, }, + DynamicMetadata: makeMetadata("192.168.1.2:8000"), }, { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte("76"), + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"direct-model\",\"prompt\":\"test6\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, }, }, }, - wantMetadata: makeMetadata("192.168.1.3:8000"), - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test3\",\"temperature\":0}"), - wantMetrics: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 - `, - wantErr: false, }, + // Response flow tests { - name: "noncritical and all models past threshold, shed request", - req: utiltesting.GenerateRequest(logger, "test4", "sql-lora-sheddable"), - // no pods will be picked as all models are either above kv threshold, - // queue threshold, or both. 
+ name: "responsebody sent over multiple requests, content-type is json, buffer", + requests: []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_ResponseHeaders{ + ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{ + Headers: []*configPb.HeaderValue{ + { + Key: "content-type", + Value: "application/json", + }, + }, + }, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lo"), EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{Body: []byte("ra-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: true}, + }, + }, + }, + + // + // pod 0 will be picked as all other models are above threshold pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ fakePod(0): { - WaitingQueueSize: 6, + WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, ActiveModels: map[string]int{ "foo": 1, @@ -271,20 +1040,74 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, }, }, - wantHeaders: []*configPb.HeaderValueOption{}, - wantMetadata: &structpb.Struct{}, - wantBody: []byte(""), - wantErr: false, - immediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: envoyTypePb.StatusCode_TooManyRequests, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, }, }, - wantMetrics: "", }, { - name: "noncritical, but one server has capacity, do not shed", - req: utiltesting.GenerateRequest(logger, "test5", "sql-lora-sheddable"), + name: "responsebody sent over a single request, but empty body with EndOfStream in the second request(this is how envoy operates); content-type is json, buffer", + requests: []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_ResponseHeaders{ + ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{ + Headers: []*configPb.HeaderValue{ + { + Key: "content-type", + Value: "application/json", + }, + }, + }, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{Body: []byte(""), EndOfStream: true}, + }, + }, + }, + + // // pod 0 will be picked as all other models are above threshold pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ fakePod(0): { @@ -313,69 +1136,261 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, }, }, - wantHeaders: []*configPb.HeaderValueOption{ + 
wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ { - Header: &configPb.HeaderValue{ - Key: runserver.DefaultDestinationEndpointHintKey, - RawValue: []byte("192.168.1.1:8000"), + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + }, + }, + }, + }, }, }, { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte("76"), + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-sheddable\",\"prompt\":\"test6\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, }, }, }, - wantMetadata: makeMetadata("192.168.1.1:8000"), - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"test5\",\"temperature\":0}"), - wantMetrics: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 - `, - wantErr: false, }, - } - - // Set up global k8sclient and extproc server runner with test environment config - cleanup := BeforeSuit(t) - defer cleanup() - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - client, cleanup := setUpHermeticServer(t, test.pods) - t.Cleanup(cleanup) - want := &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: test.wantHeaders, + { + name: "responsebody sent over a single request, but empty body with EndOfStream in the second request(this is how envoy operates); content-type is json, buffer", + requests: []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_ResponseHeaders{ + ResponseHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{ + Headers: []*configPb.HeaderValue{ + { + Key: "content-type", + RawValue: []byte("text/event-stream"), + }, + { + Key: "status", + RawValue: []byte("200"), + }, + }, }, - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: test.wantBody, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"NEVER","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"GONNA","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false}, + }, + }, + { + Request: 
&extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"GIVE","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"YOU","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"UP","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}`), + EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_ResponseBody{ + ResponseBody: &extProcPb.HttpBody{ + Body: []byte("data: [DONE]"), + EndOfStream: true}, + }, + }, + }, + wantErr: false, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + }, }, }, }, }, }, - DynamicMetadata: test.wantMetadata, - } - res, err := sendRequest(t, client, test.req) + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"NEVER","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"GONNA","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: 
&extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"GIVE","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"YOU","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[{"index":0,"text":"UP","logprobs":null,"finish_reason":null,"stop_reason":null}],"usage":null}`), + EndOfStream: false, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}`), + EndOfStream: false, + }, + }, + }, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("data: [DONE]"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, cleanup := setUpHermeticServer(t, test.pods, true) + t.Cleanup(cleanup) + responses, err := streamedRequest(t, client, test.requests, len(test.wantResponses)) if err != nil && !test.wantErr { t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) } - if test.immediateResponse != nil { - want = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: test.immediateResponse, - }, - } - } - if diff := cmp.Diff(want, res, protocmp.Transform()); diff != "" { + if diff := cmp.Diff(test.wantResponses, responses, protocmp.Transform()); diff != "" { t.Errorf("Unexpected response, (-want +got): %v", diff) } @@ -390,13 +1405,14 @@ func TestKubeInferenceModelRequest(t *testing.T) { } } -func setUpHermeticServer(t *testing.T, podAndMetrics 
map[backendmetrics.Pod]*backendmetrics.Metrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
+func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*backendmetrics.Metrics, streamed bool) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) {
 	// Reconfigure the TestPodMetricsClient.
 	res := map[types.NamespacedName]*backendmetrics.Metrics{}
 	for pod, metrics := range podAndMetrics {
 		res[pod.NamespacedName] = metrics
 	}
 	serverRunner.TestPodMetricsClient.SetRes(res)
+	serverRunner.UseStreaming = streamed

 	serverCtx, stopServer := context.WithCancel(context.Background())
@@ -475,7 +1491,7 @@ func fakePod(index int) backendmetrics.Pod {
 }

 // Sets up a test environment and returns the runner struct
-func BeforeSuit(t *testing.T) func() {
+func BeforeSuite() func() {
 	// Set up mock k8s API Client
 	testEnv = &envtest.Environment{
 		CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "config", "crd", "bases")},
@@ -499,7 +1515,7 @@ func BeforeSuit(t *testing.T) func() {
 	// Init runtime.
 	ctrl.SetLogger(logger)

-	mgr, err := server.NewDefaultManager("default", "vllm-llama2-7b-pool", cfg)
+	mgr, err := server.NewManagerWithOptions(cfg, managerTestOptions("default", "vllm-llama2-7b-pool"))
 	if err != nil {
 		logutil.Fatal(logger, err, "Failed to create controller manager")
 	}
@@ -520,7 +1536,7 @@ func BeforeSuit(t *testing.T) func() {
 		logutil.Fatal(logger, err, "Failed to setup server runner")
 	}

-	// Start the controller manager in go routine, not blocking
+	// Start the controller manager in a goroutine, not blocking
 	go func() {
 		if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
 			logutil.Fatal(logger, err, "Failed to start manager")
@@ -561,14 +1577,16 @@ func BeforeSuit(t *testing.T) func() {
 		}
 	}

-	assert.EventuallyWithT(t, func(t *assert.CollectT) {
+	assert.Eventually(nil, func() bool {
 		modelExist := serverRunner.Datastore.ModelGet("my-model")
 		synced := serverRunner.Datastore.PoolHasSynced() && modelExist != nil
-		assert.True(t, synced, "Timeout waiting for the pool and models to sync")
+		return synced
 	}, 10*time.Second, 10*time.Millisecond)

 	return func() {
 		_ = testEnv.Stop()
+		_ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferencePool{})
+		_ = k8sClient.DeleteAllOf(context.Background(), &v1alpha2.InferenceModel{})
 	}
 }

@@ -588,6 +1606,44 @@ func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient,
 	return res, err
 }

+func streamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, requests []*extProcPb.ProcessingRequest, expectedResponses int) ([]*extProcPb.ProcessingResponse, error) {
+	for _, req := range requests {
+		t.Logf("Sending request: %v", req)
+		if err := client.Send(req); err != nil {
+			t.Logf("Failed to send request %+v: %v", req, err)
+			return nil, err
+		}
+		// Brief pause for the goroutines to execute sequentially and populate the internal pipe channels sequentially.
+		// Without the pause there can be a race condition where a goroutine from a subsequent request is able to populate
+		// the pipe writer channel before a previous chunk. This is simply due to everything running in memory; this would
+		// not happen in a real-world environment with non-zero latency.
+		time.Sleep(1 * time.Millisecond)
+	}
+	responses := []*extProcPb.ProcessingResponse{}

+	// Make an incredibly simple timeout func for the case where
+	// there are fewer than the expected number of responses; bail and fail.
+	var simpleTimeout bool
+	go func() {
+		time.Sleep(10 * time.Second)
+		simpleTimeout = true
+	}()
+
+	for range expectedResponses {
+		if simpleTimeout {
+			break
+		}
+		res, err := client.Recv()
+		if err != nil && err != io.EOF {
+			t.Logf("Failed to receive: %v", err)
+			return nil, err
+		}
+		t.Logf("Received request %+v", res)
+		responses = append(responses, res)
+	}
+	return responses, nil
+}
+
 // readDocuments reads documents from file.
 func readDocuments(fp string) ([][]byte, error) {
 	b, err := os.ReadFile(fp)
@@ -658,3 +1714,41 @@ func registerMetricsHandler(mgr manager.Manager, port int) error {
 	}
 	return nil
 }
+
+// inject options that allow multiple test runs to run
+// https://github.com/kubernetes-sigs/controller-runtime/issues/2937
+func managerTestOptions(namespace, name string) ctrl.Options {
+	return ctrl.Options{
+		Scheme: scheme,
+		Cache: cache.Options{
+			ByObject: map[client.Object]cache.ByObject{
+				&corev1.Pod{}: {
+					Namespaces: map[string]cache.Config{
+						namespace: {},
+					},
+				},
+				&v1alpha2.InferencePool{}: {
+					Namespaces: map[string]cache.Config{
+						namespace: {
+							FieldSelector: fields.SelectorFromSet(fields.Set{
+								"metadata.name": name,
+							}),
+						},
+					},
+				},
+				&v1alpha2.InferenceModel{}: {
+					Namespaces: map[string]cache.Config{
+						namespace: {},
+					},
+				},
+			},
+		},
+		Controller: config.Controller{
+			SkipNameValidation: boolPointer(true),
+		},
+	}
+}
+
+func boolPointer(b bool) *bool {
+	return &b
+}
diff --git a/test/testdata/inferencepool-with-model-hermetic.yaml b/test/testdata/inferencepool-with-model-hermetic.yaml
index c9ca763e1..36b6e539d 100644
--- a/test/testdata/inferencepool-with-model-hermetic.yaml
+++ b/test/testdata/inferencepool-with-model-hermetic.yaml
@@ -50,3 +50,14 @@ spec:
   targetModels:
   - name: my-model-12345
     weight: 100
+---
+apiVersion: inference.networking.x-k8s.io/v1alpha2
+kind: InferenceModel
+metadata:
+  name: inferencemodel-direct-model-name
+  namespace: default
+spec:
+  modelName: direct-model
+  criticality: Critical
+  poolRef:
+    name: vllm-llama2-7b-pool
\ No newline at end of file