From b1230a323c609de19d61cacff738c7045cdd2acc Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Thu, 28 Aug 2025 15:33:01 +0300 Subject: [PATCH 1/2] if request id was not supplied in header, generate uuid Signed-off-by: Nir Rozenbaum --- pkg/epp/handlers/server.go | 16 ++++++-- pkg/epp/server/server_test.go | 17 +++----- test/integration/epp/hermetic_test.go | 57 +++++++++++++++++++++++++++ test/integration/util.go | 5 +++ 4 files changed, 80 insertions(+), 15 deletions(-) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 89e09f959..6a5c116d5 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -26,6 +26,7 @@ import ( extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/go-logr/logr" + "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" @@ -186,11 +187,18 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) switch v := req.Request.(type) { case *extProcPb.ProcessingRequest_RequestHeaders: - if requestID := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestID) > 0 { - logger = logger.WithValues(requtil.RequestIdHeaderKey, requestID) - loggerTrace = logger.V(logutil.TRACE) - ctx = log.IntoContext(ctx, logger) + requestID := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey) + // request ID is a must for maintaining a state per request in plugins that hold internal state and use PluginState. + // if request id was not supplied as a header, we generate it ourselves. + if len(requestID) == 0 { + requestID = uuid.NewString() + loggerTrace.Info("RequestID header is not found in the request, generated a request id") + reqCtx.Request.Headers[requtil.RequestIdHeaderKey] = requestID // update in headers so director can consume it } + logger = logger.WithValues(requtil.RequestIdHeaderKey, requestID) + loggerTrace = logger.V(logutil.TRACE) + ctx = log.IntoContext(ctx, logger) + err = s.HandleRequestHeaders(reqCtx, v) case *extProcPb.ProcessingRequest_RequestBody: loggerTrace.Info("Incoming body chunk", "EoS", v.RequestBody.EndOfStream) diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index dfcc4bb41..aff6d4644 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -42,13 +42,10 @@ const ( ) func TestServer(t *testing.T) { - theHeaderValue := "body" - requestHeader := "x-test" - expectedRequestHeaders := map[string]string{metadata.DestinationEndpointKey: fmt.Sprintf("%s:%d", podAddress, poolPort), - "Content-Length": "42", ":method": "POST", requestHeader: theHeaderValue} - expectedResponseHeaders := map[string]string{"x-went-into-resp-headers": "true", ":method": "POST", requestHeader: theHeaderValue} - expectedSchedulerHeaders := map[string]string{":method": "POST", requestHeader: theHeaderValue} + "Content-Length": "42", ":method": "POST", "x-test": "body", "x-request-id": "test-request-id"} + expectedResponseHeaders := map[string]string{"x-went-into-resp-headers": "true", ":method": "POST", "x-test": "body"} + expectedSchedulerHeaders := map[string]string{":method": "POST", "x-test": "body", "x-request-id": "test-request-id"} t.Run("server", func(t *testing.T) { model := testutil.MakeInferenceObjective("v1"). @@ -66,9 +63,10 @@ func TestServer(t *testing.T) { // Send request headers - no response expected headers := utils.BuildEnvoyGRPCHeaders(map[string]string{ - requestHeader: theHeaderValue, + "x-test": "body", ":method": "POST", metadata.FlowFairnessIDKey: "a-very-interesting-fairness-id", + "x-request-id": "test-request-id", }, true) request := &pb.ProcessingRequest{ Request: &pb.ProcessingRequest_RequestHeaders{ @@ -130,9 +128,6 @@ func TestServer(t *testing.T) { } // Check headers passed to the scheduler - if len(director.requestHeaders) != 2 { - t.Errorf("Incorrect number of request headers %d instead of 2", len(director.requestHeaders)) - } for expectedKey, expectedValue := range expectedSchedulerHeaders { got, ok := director.requestHeaders[expectedKey] if !ok { @@ -143,7 +138,7 @@ func TestServer(t *testing.T) { } // Send response headers - headers = utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, false) + headers = utils.BuildEnvoyGRPCHeaders(map[string]string{"x-test": "body", ":method": "POST"}, false) request = &pb.ProcessingRequest{ Request: &pb.ProcessingRequest_ResponseHeaders{ ResponseHeaders: headers, diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 384ac5e2e..a215adcf5 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -76,6 +76,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" epptestutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration" ) @@ -187,6 +188,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { @@ -250,6 +257,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { @@ -279,6 +292,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { @@ -308,6 +327,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { @@ -330,6 +355,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { Key: metadata.ModelNameRewriteKey, Value: modelSheddableTarget, }, + { + Key: requtil.RequestIdHeaderKey, + Value: "test-request-id", + }, }, }, }, @@ -368,6 +397,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { @@ -394,6 +429,10 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { Key: metadata.ModelNameRewriteKey, Value: modelDirect, }, + { + Key: requtil.RequestIdHeaderKey, + Value: "test-request-id", + }, }, }, }, @@ -432,6 +471,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, // Response flow tests @@ -778,6 +823,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { @@ -811,6 +862,12 @@ func TestFullDuplexStreamed_KubeInferenceObjectiveRequest(t *testing.T) { RawValue: []byte("mom"), }, }, + &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: requtil.RequestIdHeaderKey, + RawValue: []byte("test-request-id"), + }, + }, ), }, { diff --git a/test/integration/util.go b/test/integration/util.go index 005244982..d78b76e28 100644 --- a/test/integration/util.go +++ b/test/integration/util.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request" ) const ( @@ -130,6 +131,10 @@ func GenerateStreamedRequestSet(logger logr.Logger, prompt, model, targetModel s Key: metadata.ModelNameRewriteKey, Value: targetModel, }, + { + Key: requtil.RequestIdHeaderKey, + Value: "test-request-id", + }, }, }, }, From b100b9f786068100b41f0230e130b82cd0009478 Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Fri, 29 Aug 2025 08:10:47 +0300 Subject: [PATCH 2/2] convert map to sync.map in plugin state Signed-off-by: Nir Rozenbaum --- pkg/epp/plugins/plugin_state.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/epp/plugins/plugin_state.go b/pkg/epp/plugins/plugin_state.go index 2da35571f..0281900d8 100644 --- a/pkg/epp/plugins/plugin_state.go +++ b/pkg/epp/plugins/plugin_state.go @@ -51,7 +51,7 @@ func NewPluginState(ctx context.Context) *PluginState { // Note: PluginState uses a sync.Map to back the storage, because it is thread safe. // It's aimed to optimize for the "write once and read many times" scenarios. type PluginState struct { - // key: RequestID, value: map[StateKey]StateData + // key: RequestID, value: sync.Map[StateKey]StateData storage sync.Map // key: RequestID, value: time.Time requestToLastAccessTime sync.Map @@ -66,9 +66,9 @@ func (s *PluginState) Read(requestID string, key StateKey) (StateData, error) { return nil, ErrNotFound } - stateData := stateMap.(map[StateKey]StateData) - if value, ok := stateData[key]; ok { - return value, nil + stateData := stateMap.(*sync.Map) + if value, ok := stateData.Load(key); ok { + return value.(StateData), nil } return nil, ErrNotFound @@ -77,15 +77,15 @@ func (s *PluginState) Read(requestID string, key StateKey) (StateData, error) { // Write stores the given "val" in PluginState with the given "key" in the context of the given "requestID". func (s *PluginState) Write(requestID string, key StateKey, val StateData) { s.requestToLastAccessTime.Store(requestID, time.Now()) - var stateData map[StateKey]StateData + var stateData *sync.Map stateMap, ok := s.storage.Load(requestID) if ok { - stateData = stateMap.(map[StateKey]StateData) + stateData = stateMap.(*sync.Map) } else { - stateData = map[StateKey]StateData{} + stateData = &sync.Map{} } - stateData[key] = val + stateData.Store(key, val) s.storage.Store(requestID, stateData) }