diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 707fba00a..0ba3983e4 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -40,6 +40,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" @@ -227,6 +228,8 @@ func run() error { saturationDetector := saturationdetector.NewDetector(sdConfig, datastore, ctrl.Log) + director := requestcontrol.NewDirector(datastore, scheduler, saturationDetector) // can call "director.WithPostResponsePlugins" to add post response plugins + // --- Setup ExtProc Server Runner --- serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, @@ -237,7 +240,7 @@ func run() error { SecureServing: *secureServing, CertPath: *certPath, RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval, - Scheduler: scheduler, + Director: director, SaturationDetector: saturationDetector, } if err := serverRunner.SetupWithManager(ctx, mgr); err != nil { diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index debf23964..85f77cb91 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -32,6 +32,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" @@ -79,7 +80,7 @@ type StreamingServer struct { // Specifically, there are fields related to the ext-proc protocol, and then fields related to the lifecycle of the request. // We should split these apart as this monolithic object exposes too much data to too many layers. type RequestContext struct { - TargetPod string + TargetPod *backend.Pod TargetEndpoint string Model string ResolvedTargetModel string @@ -93,6 +94,8 @@ type RequestContext struct { RequestRunning bool Request *Request + SchedulingRequest *schedulingtypes.LLMRequest + RequestState StreamRequestState modelServerStreaming bool diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index c77e0f05c..50f637478 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -202,6 +202,18 @@ var ( []string{"plugin_type", "plugin_name"}, ) + RequestControlPluginProcessingLatencies = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: InferenceExtension, + Name: "request_control_plugin_duration_seconds", + Help: metricsutil.HelpMsgWithStability("RequestControl plugin processing latency distribution in seconds for each plugin type and plugin name.", compbasemetrics.ALPHA), + Buckets: []float64{ + 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, + }, + }, + []string{"plugin_type", "plugin_name"}, + ) + // Prefix indexer Metrics PrefixCacheSize = prometheus.NewGaugeVec( prometheus.GaugeOpts{ @@ -263,6 +275,7 @@ func Register(customCollectors ...prometheus.Collector) { metrics.Registry.MustRegister(inferencePoolReadyPods) metrics.Registry.MustRegister(SchedulerPluginProcessingLatencies) metrics.Registry.MustRegister(SchedulerE2ELatency) + metrics.Registry.MustRegister(RequestControlPluginProcessingLatencies) metrics.Registry.MustRegister(InferenceExtensionInfo) metrics.Registry.MustRegister(PrefixCacheSize) metrics.Registry.MustRegister(PrefixCacheHitRatio) @@ -289,6 +302,7 @@ func Reset() { inferencePoolReadyPods.Reset() SchedulerPluginProcessingLatencies.Reset() SchedulerE2ELatency.Reset() + RequestControlPluginProcessingLatencies.Reset() InferenceExtensionInfo.Reset() PrefixCacheSize.Reset() PrefixCacheHitRatio.Reset() @@ -400,6 +414,11 @@ func RecordSchedulerE2ELatency(duration time.Duration) { SchedulerE2ELatency.WithLabelValues().Observe(duration.Seconds()) } +// RecordRequestControlPluginProcessingLatency records the processing latency for a request-control plugin. +func RecordRequestControlPluginProcessingLatency(pluginType, pluginName string, duration time.Duration) { + RequestControlPluginProcessingLatencies.WithLabelValues(pluginType, pluginName).Observe(duration.Seconds()) +} + // RecordPrefixCacheSize records the size of the prefix indexer in megabytes. func RecordPrefixCacheSize(size int64) { PrefixCacheSize.WithLabelValues().Set(float64(size)) diff --git a/pkg/epp/plugins/plugins.go b/pkg/epp/plugins/plugins.go new file mode 100644 index 000000000..5dd8d87bf --- /dev/null +++ b/pkg/epp/plugins/plugins.go @@ -0,0 +1,24 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +// Plugin defines the interface for a plugin. +// This interface should be embedded in all plugins across the code. +type Plugin interface { + // Name returns the name of the plugin. + Name() string +} diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 2bb92e5ba..089bbab5e 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -23,6 +23,7 @@ import ( "fmt" "math/rand" "strconv" + "time" "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" @@ -30,6 +31,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -39,7 +41,6 @@ import ( // Scheduler defines the interface required by the Director for scheduling. type Scheduler interface { Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result map[string]*schedulingtypes.Result, err error) - OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string) } // SaturationDetector provides a signal indicating whether the backends are considered saturated. @@ -47,16 +48,25 @@ type SaturationDetector interface { IsSaturated(ctx context.Context) bool } +// NewDirector creates a new Director instance with all dependencies. +// postResponsePlugins remains nil as this is an optional field that can be set using the "WithPostResponsePlugins" function. +func NewDirector(datastore datastore.Datastore, scheduler Scheduler, saturationDetector SaturationDetector) *Director { + return &Director{datastore: datastore, scheduler: scheduler, saturationDetector: saturationDetector} +} + // Director orchestrates the request handling flow, including scheduling. type Director struct { - datastore datastore.Datastore - scheduler Scheduler - saturationDetector SaturationDetector + datastore datastore.Datastore + scheduler Scheduler + saturationDetector SaturationDetector + postResponsePlugins []PostResponsePlugin } -// NewDirector creates a new Director instance with all dependencies. -func NewDirector(datastore datastore.Datastore, scheduler Scheduler, saturationDetector SaturationDetector) *Director { - return &Director{datastore, scheduler, saturationDetector} +// WithPostResponsePlugins sets the given plugins as the PostResponse plugins. +// If the Director has PostResponse plugins already, this call replaces the existing plugins with the given ones. +func (d *Director) WithPostResponsePlugins(plugins ...PostResponsePlugin) *Director { + d.postResponsePlugins = plugins + return d } // HandleRequest orchestrates the request lifecycle: @@ -104,7 +114,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } // Prepare LLMRequest (needed for both saturation detection and Scheduler) - llmReq := &schedulingtypes.LLMRequest{ + reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{ TargetModel: reqCtx.ResolvedTargetModel, RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], Critical: requestCriticality == v1alpha2.Critical, @@ -113,7 +123,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } logger = logger.WithValues( "model", reqCtx.Model, - "resolvedTargetModel", llmReq.TargetModel, + "resolvedTargetModel", reqCtx.ResolvedTargetModel, "criticality", requestCriticality, ) ctx = log.IntoContext(ctx, logger) @@ -126,7 +136,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } // --- 3. Dispatch (Calls Scheduler) --- - results, dispatchErr := d.Dispatch(ctx, llmReq) + results, dispatchErr := d.Dispatch(ctx, reqCtx.SchedulingRequest) if dispatchErr != nil { return reqCtx, dispatchErr } @@ -193,22 +203,19 @@ func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestCon endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) logger.V(logutil.DEFAULT).Info("Request handled", "model", reqCtx.Model, "targetModel", reqCtx.ResolvedTargetModel, "endpoint", targetPod) - reqCtx.TargetPod = targetPod.NamespacedName.String() + reqCtx.TargetPod = targetPod reqCtx.TargetEndpoint = endpoint return reqCtx, nil } func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx) - - llmResp := &schedulingtypes.LLMResponse{ + response := &Response{ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], Headers: reqCtx.Response.Headers, } - logger.V(logutil.DEBUG).Info("LLM response assembled", "response", llmResp) - d.scheduler.OnResponse(ctx, llmResp, reqCtx.TargetPod) + d.runPostResponsePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) return reqCtx, nil } @@ -253,3 +260,12 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed } return "" } + +func (d *Director) runPostResponsePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { + for _, plugin := range d.postResponsePlugins { + log.FromContext(ctx).V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostResponse(ctx, request, response, targetPod) + metrics.RecordRequestControlPluginProcessingLatency(PostResponsePluginType, plugin.Name(), time.Since(before)) + } +} diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index a56115818..a7ca55a93 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -27,6 +27,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" k8stypes "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client/fake" @@ -53,24 +54,14 @@ func (m *mockSaturationDetector) IsSaturated(_ context.Context) bool { } type mockScheduler struct { - scheduleResults map[string]*schedulingtypes.Result - scheduleErr error - lastRespOnResponse *schedulingtypes.LLMResponse - lastTargetPodOnResponse string + scheduleResults map[string]*schedulingtypes.Result + scheduleErr error } -func (m *mockScheduler) Schedule( - ctx context.Context, - req *schedulingtypes.LLMRequest, -) (map[string]*schedulingtypes.Result, error) { +func (m *mockScheduler) Schedule(ctx context.Context, req *schedulingtypes.LLMRequest) (map[string]*schedulingtypes.Result, error) { return m.scheduleResults, m.scheduleErr } -func (m *mockScheduler) OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string) { - m.lastRespOnResponse = resp - m.lastTargetPodOnResponse = targetPodName -} - func TestDirector_HandleRequest(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) @@ -170,8 +161,11 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ Model: model, ResolvedTargetModel: model, - TargetPod: "default/pod1", - TargetEndpoint: "192.168.1.100:8000", + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + }, + TargetEndpoint: "192.168.1.100:8000", }, wantMutatedBodyModel: model, }, @@ -192,8 +186,11 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ Model: model, ResolvedTargetModel: model, - TargetPod: "default/pod1", - TargetEndpoint: "192.168.1.100:8000", + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + }, + TargetEndpoint: "192.168.1.100:8000", }, wantMutatedBodyModel: model, }, @@ -218,8 +215,11 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ Model: model, ResolvedTargetModel: model, - TargetPod: "default/pod1", - TargetEndpoint: "192.168.1.100:8000", + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + }, + TargetEndpoint: "192.168.1.100:8000", }, wantMutatedBodyModel: model, }, @@ -236,8 +236,11 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ Model: modelSheddable, ResolvedTargetModel: modelSheddable, - TargetPod: "default/pod1", - TargetEndpoint: "192.168.1.100:8000", + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + }, + TargetEndpoint: "192.168.1.100:8000", }, wantMutatedBodyModel: modelSheddable, }, @@ -254,8 +257,11 @@ func TestDirector_HandleRequest(t *testing.T) { wantReqCtx: &handlers.RequestContext{ Model: modelWithResolvedTarget, ResolvedTargetModel: "resolved-target-model-A", - TargetPod: "default/pod1", - TargetEndpoint: "192.168.1.100:8000", + TargetPod: &backend.Pod{ + NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"}, + Address: "192.168.1.100", + }, + TargetEndpoint: "192.168.1.100:8000", }, wantMutatedBodyModel: "resolved-target-model-A", }, @@ -338,12 +344,7 @@ func TestDirector_HandleRequest(t *testing.T) { if test.schedulerMockSetup != nil { test.schedulerMockSetup(mockSched) } - - var sd SaturationDetector - if test.mockSaturationDetector != nil { - sd = test.mockSaturationDetector - } - director := NewDirector(ds, mockSched, sd) + director := NewDirector(ds, mockSched, test.mockSaturationDetector) reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -513,10 +514,15 @@ func pointer(v int32) *int32 { } func TestDirector_HandleResponse(t *testing.T) { + pr1 := &testPostResponse{ + NameRes: "pr1", + } + ctx := logutil.NewTestLoggerIntoContext(context.Background()) ds := datastore.NewDatastore(t.Context(), nil) mockSched := &mockScheduler{} - director := NewDirector(ds, mockSched, nil) + director := NewDirector(ds, mockSched, nil). + WithPostResponsePlugins(pr1) reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -527,7 +533,8 @@ func TestDirector_HandleResponse(t *testing.T) { Response: &handlers.Response{ // Simulate some response headers Headers: map[string]string{"X-Test-Response-Header": "TestValue"}, }, - TargetPod: "namespace1/test-pod-name", + + TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } _, err := director.HandleResponse(ctx, reqCtx) @@ -535,13 +542,26 @@ func TestDirector_HandleResponse(t *testing.T) { t.Fatalf("HandleResponse() returned unexpected error: %v", err) } - if diff := cmp.Diff("test-req-id-for-response", mockSched.lastRespOnResponse.RequestId); diff != "" { + if diff := cmp.Diff("test-req-id-for-response", pr1.lastRespOnResponse.RequestId); diff != "" { t.Errorf("Scheduler.OnResponse RequestId mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff(reqCtx.Response.Headers, mockSched.lastRespOnResponse.Headers); diff != "" { + if diff := cmp.Diff(reqCtx.Response.Headers, pr1.lastRespOnResponse.Headers); diff != "" { t.Errorf("Scheduler.OnResponse Headers mismatch (-want +got):\n%s", diff) } - if diff := cmp.Diff("namespace1/test-pod-name", mockSched.lastTargetPodOnResponse); diff != "" { + if diff := cmp.Diff("namespace1/test-pod-name", pr1.lastTargetPodOnResponse); diff != "" { t.Errorf("Scheduler.OnResponse TargetPodName mismatch (-want +got):\n%s", diff) } } + +type testPostResponse struct { + NameRes string + lastRespOnResponse *Response + lastTargetPodOnResponse string +} + +func (p *testPostResponse) Name() string { return p.NameRes } + +func (p *testPostResponse) PostResponse(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { + p.lastRespOnResponse = response + p.lastTargetPodOnResponse = targetPod.NamespacedName.String() +} diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go new file mode 100644 index 000000000..994ff56d3 --- /dev/null +++ b/pkg/epp/requestcontrol/plugins.go @@ -0,0 +1,36 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcontrol + +import ( + "context" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +const ( + PostResponsePluginType = "PostResponse" +) + +// PostResponse is called by the director after a successful response was sent. +// The given pod argument is the pod that served the request. +type PostResponsePlugin interface { + plugins.Plugin + PostResponse(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) +} diff --git a/pkg/epp/requestcontrol/types.go b/pkg/epp/requestcontrol/types.go new file mode 100644 index 000000000..8604e1dda --- /dev/null +++ b/pkg/epp/requestcontrol/types.go @@ -0,0 +1,31 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package requestcontrol + +// Response contains information from the response received to be passed to PostResponse plugins +type Response struct { + // RequestId is the Envoy generated Id for the request being processed + RequestId string + // Headers is a map of the response headers. Nil during body processing + Headers map[string]string + // Body Is the body of the response or nil during header processing + Body string + // IsStreaming indicates whether or not the response is being streamed by the model + IsStreaming bool + // EndOfStream when true indicates that this invocation contains the last chunk of the response + EndOfStream bool +} diff --git a/pkg/epp/scheduling/framework/plugins.go b/pkg/epp/scheduling/framework/plugins.go index a402cc262..2cc1ff4bf 100644 --- a/pkg/epp/scheduling/framework/plugins.go +++ b/pkg/epp/scheduling/framework/plugins.go @@ -19,60 +19,46 @@ package framework import ( "context" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) const ( - ProfilePickerType = "ProfilePicker" - FilterPluginType = "Filter" - ScorerPluginType = "Scorer" - PickerPluginType = "Picker" - PostCyclePluginType = "PostCycle" - PostResponsePluginType = "PostResponse" + ProfilePickerType = "ProfilePicker" + FilterPluginType = "Filter" + ScorerPluginType = "Scorer" + PickerPluginType = "Picker" + PostCyclePluginType = "PostCycle" ) -// Plugin defines the interface for scheduler plugins, combining scoring, filtering, -// and event handling capabilities. -type Plugin interface { - // Name returns the name of the plugin. - Name() string -} - // ProfilePicker selects the SchedulingProfiles to run from a list of candidate profiles, while taking into consideration the request properties // and the previously executed SchedluderProfile cycles along with their results. type ProfilePicker interface { - Plugin + plugins.Plugin Pick(request *types.LLMRequest, profiles map[string]*SchedulerProfile, executionResults map[string]*types.Result) map[string]*SchedulerProfile } // Filter defines the interface for filtering a list of pods based on context. type Filter interface { - Plugin + plugins.Plugin Filter(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod } // Scorer defines the interface for scoring a list of pods based on context. // Scorers must score pods with a value within the range of [0,1] where 1 is the highest score. type Scorer interface { - Plugin + plugins.Plugin Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 } // Picker picks the final pod(s) to send the request to. type Picker interface { - Plugin + plugins.Plugin Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.Result } // PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle. type PostCycle interface { - Plugin + plugins.Plugin PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.Result) } - -// PostResponse is called by the scheduler after a successful response was sent. -// The given pod argument is the pod that served the request. -type PostResponse interface { - Plugin - PostResponse(ctx context.Context, response *types.LLMResponse, targetPod types.Pod) -} diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go index 25528958b..ccbb4fba0 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile.go +++ b/pkg/epp/scheduling/framework/scheduler_profile.go @@ -23,6 +23,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -31,21 +32,19 @@ import ( // NewSchedulerProfile creates a new SchedulerProfile object and returns its pointer. func NewSchedulerProfile() *SchedulerProfile { return &SchedulerProfile{ - filters: []Filter{}, - scorers: []*WeightedScorer{}, - postCyclePlugins: []PostCycle{}, - PostResponsePlugins: []PostResponse{}, + filters: []Filter{}, + scorers: []*WeightedScorer{}, + postCyclePlugins: []PostCycle{}, // picker remains nil since profile doesn't support multiple pickers } } // SchedulerProfile provides a profile configuration for the scheduler which influence routing decisions. type SchedulerProfile struct { - filters []Filter - scorers []*WeightedScorer - picker Picker - postCyclePlugins []PostCycle - PostResponsePlugins []PostResponse // TODO this field should get out of the scheduler + filters []Filter + scorers []*WeightedScorer + picker Picker + postCyclePlugins []PostCycle } // WithFilters sets the given filter plugins as the Filter plugins. @@ -81,7 +80,7 @@ func (p *SchedulerProfile) WithPostCyclePlugins(plugins ...PostCycle) *Scheduler // Special Case: In order to add a scorer, one must use the scorer.NewWeightedScorer function in order to provide a weight. // if a scorer implements more than one interface, supplying a WeightedScorer is sufficient. The function will take the internal // scorer object and register it to all interfaces it implements. -func (p *SchedulerProfile) AddPlugins(pluginObjects ...Plugin) error { +func (p *SchedulerProfile) AddPlugins(pluginObjects ...plugins.Plugin) error { for _, plugin := range pluginObjects { if weightedScorer, ok := plugin.(*WeightedScorer); ok { p.scorers = append(p.scorers, weightedScorer) @@ -101,9 +100,6 @@ func (p *SchedulerProfile) AddPlugins(pluginObjects ...Plugin) error { if postCyclePlugin, ok := plugin.(PostCycle); ok { p.postCyclePlugins = append(p.postCyclePlugins, postCyclePlugin) } - if postResponsePlugin, ok := plugin.(PostResponse); ok { - p.PostResponsePlugins = append(p.PostResponsePlugins, postResponsePlugin) - } } return nil } diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index a1aa5bea1..bb3798a16 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -135,35 +135,3 @@ func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest) (ma return profileExecutionResults, nil } - -// OnResponse is invoked during the processing of a response from an inference pod. It will invoke -// any defined plugins that process the response. -func (s *Scheduler) OnResponse(ctx context.Context, response *types.LLMResponse, targetPodName string) { - // Snapshot pod metrics from the datastore to: - // 1. Reduce concurrent access to the datastore. - // 2. Ensure consistent data during the scheduling operation of a request. - pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll()) - var targetPod types.Pod - for _, pod := range pods { - if pod.GetPod().NamespacedName.String() == targetPodName { - targetPod = pod - break - } - } - - // WORKAROUND until PostResponse is out of Scheduler - profileExecutionResults := map[string]*types.Result{} - profiles := s.profilePicker.Pick(nil, s.profiles, profileExecutionResults) // all profiles - for _, profile := range profiles { - s.runPostResponsePlugins(ctx, response, targetPod, profile) - } -} - -func (s *Scheduler) runPostResponsePlugins(ctx context.Context, response *types.LLMResponse, targetPod types.Pod, profile *framework.SchedulerProfile) { - for _, plugin := range profile.PostResponsePlugins { - log.FromContext(ctx).V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name()) - before := time.Now() - plugin.PostResponse(ctx, response, targetPod) - metrics.RecordSchedulerPluginProcessingLatency(framework.PostResponsePluginType, plugin.Name(), time.Since(before)) - } -} diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 8fcaf3287..b8f1d7f10 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -25,8 +25,6 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - profilepicker "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile-picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -247,60 +245,6 @@ func TestSchedule(t *testing.T) { } } -func TestPostResponse(t *testing.T) { - pr1 := &testPostResponse{ - NameRes: "pr1", - ExtraHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, - ReceivedResponseHeaders: make(map[string]string), - } - - targetPod := k8stypes.NamespacedName{Name: "pod2"} - - tests := []struct { - name string - config *framework.SchedulerProfile - input []*backendmetrics.FakePodMetrics - responseHeaders map[string]string - wantUpdatedHeaders map[string]string - }{ - { - name: "Simple postResponse test", - config: &framework.SchedulerProfile{ - PostResponsePlugins: []framework.PostResponse{pr1}, - }, - input: []*backendmetrics.FakePodMetrics{ - {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - {Pod: &backend.Pod{NamespacedName: targetPod}}, - }, - responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"}, - wantUpdatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv", "Content-type": "application/json", "Content-Length": "1234"}, - }, - } - - for _, test := range tests { - schedulerConfig := NewSchedulerConfig(profilepicker.NewAllProfilesPicker(), map[string]*framework.SchedulerProfile{"default": test.config}) - scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, schedulerConfig) - - headers := map[string]string{} - for k, v := range test.responseHeaders { - headers[k] = v - } - resp := &types.LLMResponse{ - Headers: headers, - } - - scheduler.OnResponse(context.Background(), resp, targetPod.String()) - - if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" { - t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff) - } - - if diff := cmp.Diff(test.wantUpdatedHeaders, resp.Headers); diff != "" { - t.Errorf("Unexpected output (-wantUpdatedHeaders +resp.Headers): %v", diff) - } - } -} - type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -312,20 +256,3 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { } return pm } - -type testPostResponse struct { - NameRes string - ReceivedResponseHeaders map[string]string - ExtraHeaders map[string]string -} - -func (pr *testPostResponse) Name() string { return pr.NameRes } - -func (pr *testPostResponse) PostResponse(_ context.Context, response *types.LLMResponse, _ types.Pod) { - for key, value := range response.Headers { - pr.ReceivedResponseHeaders[key] = value - } - for key, value := range pr.ExtraHeaders { - response.Headers[key] = value - } -} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index c7f6fa53d..0d5fadb58 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -41,20 +41,6 @@ func (r *LLMRequest) String() string { return fmt.Sprintf("TargetModel: %s, Critical: %t, PromptLength: %d, Headers: %v", r.TargetModel, r.Critical, len(r.Prompt), r.Headers) } -// LLMResponse contains information from the response received to be passed to plugins -type LLMResponse struct { - // RequestId is the Envoy generated Id for the request being processed - RequestId string - // Headers is a map of the response headers. Nil during body processing - Headers map[string]string - // Body Is the body of the response or nil during header processing - Body string - // IsStreaming indicates whether or not the response is being streamed by the model - IsStreaming bool - // EndOfStream when true indicates that this invocation contains the last chunk of the response - EndOfStream bool -} - type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.MetricsState diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 3dd1d58bd..cc29ee462 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -48,7 +48,7 @@ type ExtProcServerRunner struct { SecureServing bool CertPath string RefreshPrometheusMetricsInterval time.Duration - Scheduler requestcontrol.Scheduler + Director *requestcontrol.Director SaturationDetector requestcontrol.SaturationDetector // This should only be used in tests. We won't need this once we do not inject metrics in the tests. @@ -141,12 +141,11 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { srv = grpc.NewServer() } - director := requestcontrol.NewDirector(r.Datastore, r.Scheduler, r.SaturationDetector) extProcServer := handlers.NewStreamingServer( r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore, - director, + r.Director, ) extProcPb.RegisterExternalProcessorServer( srv, diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 525852693..e740d499c 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -62,6 +62,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/saturationdetector" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" @@ -1588,7 +1589,7 @@ func BeforeSuite() func() { // Adjust from defaults serverRunner.PoolNamespacedName = types.NamespacedName{Name: "vllm-llama3-8b-instruct-pool", Namespace: "default"} serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf) - serverRunner.Scheduler = scheduling.NewScheduler(serverRunner.Datastore) + scheduler := scheduling.NewScheduler(serverRunner.Datastore) sdConfig := &saturationdetector.Config{ QueueDepthThreshold: saturationdetector.DefaultQueueDepthThreshold, @@ -1597,6 +1598,7 @@ func BeforeSuite() func() { } detector := saturationdetector.NewDetector(sdConfig, serverRunner.Datastore, logger.WithName("saturation-detector")) serverRunner.SaturationDetector = detector + serverRunner.Director = requestcontrol.NewDirector(serverRunner.Datastore, scheduler, detector) serverRunner.SecureServing = false if err := serverRunner.SetupWithManager(context.Background(), mgr); err != nil {