From a807b406c1f5e95b72da45b99c4d7fa822183df8 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Sat, 22 Feb 2025 06:52:07 +0000 Subject: [PATCH 1/5] Currently the logic tracks the models by Spec.ModelName, since this is not guaranteed to be unique within the cluster, we could run into two issues: 1) If the model name changes on the same InferenceModel object, we don't delete the original model entry in the datastore. 2) We don't enforce the semantics that the modelName with the oldest creation timestamp is retained. While the api is assuming that this is enforced by another controller via the Ready condition, we don't have this controller yet, and so currently the behavior is unpredictable depending on InferenceModel events order. To address the above, the PR makes changes to both the InferenceModel reconciler and the Model APIs in the datastore to ensure thread safe updates of the entries. In the store, the sync.Map was replaced with two maps to track the InferenceModel entries by both ModelName and InferenceModel object NamespacedName. This is needed to properly handle deletions when the object doesn't exist anymore (could be handled in other ways, but this seemed like a reasonable approach). The PR increases the datastore pkg unit test coverage the Pool and Model APIs. We still need to followup with adding unit test coverage to the pods APIs, which is currently non-existent. --- cmd/epp/main.go | 6 +- pkg/epp/backend/provider_test.go | 50 +- .../controller/inferencemodel_reconciler.go | 113 ++++- .../inferencemodel_reconciler_test.go | 439 +++++++----------- .../inferencepool_reconciler_test.go | 93 ++-- pkg/epp/controller/pod_reconciler.go | 2 +- pkg/epp/controller/pod_reconciler_test.go | 148 ++---- pkg/epp/datastore/datastore.go | 128 ++++- pkg/epp/datastore/datastore_test.go | 172 ++++++- pkg/epp/handlers/request.go | 2 +- pkg/epp/server/runserver.go | 4 +- pkg/epp/test/utils.go | 9 +- pkg/epp/util/testing/diff.go | 27 ++ pkg/epp/util/testing/wrappers.go | 117 ++++- test/e2e/e2e_suite_test.go | 5 - test/integration/hermetic_test.go | 11 +- 16 files changed, 784 insertions(+), 542 deletions(-) create mode 100644 pkg/epp/util/testing/diff.go diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 5d399a42a..b66024ecb 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -149,6 +149,8 @@ func run() error { return err } + ctx := ctrl.SetupSignalHandler() + // Setup runner. datastore := datastore.NewDatastore() provider := backend.NewProvider(&vllm.PodMetricsClientImpl{}, datastore) @@ -165,7 +167,7 @@ func run() error { CertPath: *certPath, Provider: provider, } - if err := serverRunner.SetupWithManager(mgr); err != nil { + if err := serverRunner.SetupWithManager(ctx, mgr); err != nil { setupLog.Error(err, "Failed to setup ext-proc controllers") return err } @@ -188,7 +190,7 @@ func run() error { // Start the manager. This blocks until a signal is received. setupLog.Info("Controller manager starting") - if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { + if err := mgr.Start(ctx); err != nil { setupLog.Error(err, "Error starting controller manager") return err } diff --git a/pkg/epp/backend/provider_test.go b/pkg/epp/backend/provider_test.go index 1e11afe2c..f2db09feb 100644 --- a/pkg/epp/backend/provider_test.go +++ b/pkg/epp/backend/provider_test.go @@ -19,7 +19,6 @@ package backend import ( "context" "errors" - "sync" "testing" "time" @@ -37,6 +36,9 @@ var ( Name: "pod1", }, }, + } + pod1WithMetrics = &datastore.PodMetrics{ + Pod: pod1.Pod, Metrics: datastore.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -53,6 +55,9 @@ var ( Name: "pod2", }, }, + } + pod2WithMetrics = &datastore.PodMetrics{ + Pod: pod2.Pod, Metrics: datastore.Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.2, @@ -69,35 +74,30 @@ func TestProvider(t *testing.T) { tests := []struct { name string pmc PodMetricsClient - datastore datastore.Datastore + storePods []*datastore.PodMetrics want []*datastore.PodMetrics }{ { name: "Probing metrics success", pmc: &FakePodMetricsClient{ Res: map[types.NamespacedName]*datastore.PodMetrics{ - pod1.NamespacedName: pod1, - pod2.NamespacedName: pod2, + pod1.NamespacedName: pod1WithMetrics, + pod2.NamespacedName: pod2WithMetrics, }, }, - datastore: datastore.NewFakeDatastore(populateMap(pod1, pod2), nil, nil), - want: []*datastore.PodMetrics{ - pod1, - pod2, - }, + storePods: []*datastore.PodMetrics{pod1, pod2}, + want: []*datastore.PodMetrics{pod1WithMetrics, pod2WithMetrics}, }, { name: "Only pods in the datastore are probed", pmc: &FakePodMetricsClient{ Res: map[types.NamespacedName]*datastore.PodMetrics{ - pod1.NamespacedName: pod1, - pod2.NamespacedName: pod2, + pod1.NamespacedName: pod1WithMetrics, + pod2.NamespacedName: pod2WithMetrics, }, }, - datastore: datastore.NewFakeDatastore(populateMap(pod1), nil, nil), - want: []*datastore.PodMetrics{ - pod1, - }, + storePods: []*datastore.PodMetrics{pod1}, + want: []*datastore.PodMetrics{pod1WithMetrics}, }, { name: "Probing metrics error", @@ -106,13 +106,12 @@ func TestProvider(t *testing.T) { pod2.NamespacedName: errors.New("injected error"), }, Res: map[types.NamespacedName]*datastore.PodMetrics{ - pod1.NamespacedName: pod1, + pod1.NamespacedName: pod1WithMetrics, }, }, - datastore: datastore.NewFakeDatastore(populateMap(pod1, pod2), nil, nil), - + storePods: []*datastore.PodMetrics{pod1, pod2}, want: []*datastore.PodMetrics{ - pod1, + pod1WithMetrics, // Failed to fetch pod2 metrics so it remains the default values. { Pod: datastore.Pod{NamespacedName: pod2.NamespacedName}, @@ -128,12 +127,13 @@ func TestProvider(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - p := NewProvider(test.pmc, test.datastore) + ds := datastore.NewFakeDatastore(test.storePods, nil, nil) + p := NewProvider(test.pmc, ds) ctx, cancel := context.WithCancel(context.Background()) defer cancel() _ = p.Init(ctx, time.Millisecond, time.Millisecond) assert.EventuallyWithT(t, func(t *assert.CollectT) { - metrics := test.datastore.PodGetAll() + metrics := ds.PodGetAll() diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(func(a, b *datastore.PodMetrics) bool { return a.String() < b.String() })) @@ -142,11 +142,3 @@ func TestProvider(t *testing.T) { }) } } - -func populateMap(pods ...*datastore.PodMetrics) *sync.Map { - newMap := &sync.Map{} - for _, pod := range pods { - newMap.Store(pod.NamespacedName, &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: pod.NamespacedName, Address: pod.Address}}) - } - return newMap -} diff --git a/pkg/epp/controller/inferencemodel_reconciler.go b/pkg/epp/controller/inferencemodel_reconciler.go index 9de77989c..778b8e3c2 100644 --- a/pkg/epp/controller/inferencemodel_reconciler.go +++ b/pkg/epp/controller/inferencemodel_reconciler.go @@ -18,8 +18,8 @@ package controller import ( "context" + "fmt" - "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" @@ -34,6 +34,10 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +const ( + modelNameKey = "spec.modelName" +) + type InferenceModelReconciler struct { client.Client Scheme *runtime.Scheme @@ -43,44 +47,103 @@ type InferenceModelReconciler struct { } func (c *InferenceModelReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - logger := log.FromContext(ctx) - loggerDefault := logger.V(logutil.DEFAULT) - loggerDefault.Info("Reconciling InferenceModel", "name", req.NamespacedName) + if req.Namespace != c.PoolNamespacedName.Namespace { + return ctrl.Result{}, nil + } + logger := log.FromContext(ctx).V(logutil.DEFAULT).WithValues("inferenceModel", req.Name) + ctx = ctrl.LoggerInto(ctx, logger) + + logger.Info("Reconciling InferenceModel") infModel := &v1alpha2.InferenceModel{} + notFound := false if err := c.Get(ctx, req.NamespacedName, infModel); err != nil { - if errors.IsNotFound(err) { - loggerDefault.Info("InferenceModel not found. Removing from datastore since object must be deleted", "name", req.NamespacedName) - c.Datastore.ModelDelete(infModel.Spec.ModelName) - return ctrl.Result{}, nil + if !errors.IsNotFound(err) { + logger.Error(err, "Unable to get InferenceModel") + return ctrl.Result{}, err } - loggerDefault.Error(err, "Unable to get InferenceModel", "name", req.NamespacedName) + notFound = true + } + + if notFound || !infModel.DeletionTimestamp.IsZero() || infModel.Spec.PoolRef.Name != c.PoolNamespacedName.Name { + // InferenceModel object got deleted or changed the referenced pool. + err := c.handleModelDeleted(ctx, req.NamespacedName) return ctrl.Result{}, err - } else if !infModel.DeletionTimestamp.IsZero() { - loggerDefault.Info("InferenceModel is marked for deletion. Removing from datastore", "name", req.NamespacedName) - c.Datastore.ModelDelete(infModel.Spec.ModelName) - return ctrl.Result{}, nil } - c.updateDatastore(logger, infModel) + // Add or update if the InferenceModel instance has a creation timestamp older than the existing entry of the model. + logger = logger.WithValues("poolRef", infModel.Spec.PoolRef).WithValues("modelName", infModel.Spec.ModelName) + if !c.Datastore.ModelSetIfOlder(infModel) { + logger.Info("Skipping InferenceModel, existing instance has older creation timestamp") + + } + logger.Info("Added/Updated InferenceModel") + return ctrl.Result{}, nil } -func (c *InferenceModelReconciler) updateDatastore(logger logr.Logger, infModel *v1alpha2.InferenceModel) { - loggerDefault := logger.V(logutil.DEFAULT) +func (c *InferenceModelReconciler) handleModelDeleted(ctx context.Context, req types.NamespacedName) error { + logger := log.FromContext(ctx) + + // We will lookup the modelName associated with this object to search for + // other instance referencing the same ModelName if exist to store the oldest in + // its place. This ensures that the InferenceModel with the oldest creation + // timestamp is active. + existing, exists := c.Datastore.ModelGetByObjName(req) + if !exists { + // No entry exists in the first place, nothing to do. + return nil + } + // Delete the internal object, it may be replaced with another version below. + c.Datastore.ModelDelete(req) + logger.Info("InferenceModel removed from datastore", "poolRef", existing.Spec.PoolRef, "modelName", existing.Spec.ModelName) - if infModel.Spec.PoolRef.Name == c.PoolNamespacedName.Name { - loggerDefault.Info("Updating datastore", "poolRef", infModel.Spec.PoolRef, "serverPoolName", c.PoolNamespacedName) - loggerDefault.Info("Adding/Updating InferenceModel", "modelName", infModel.Spec.ModelName) - c.Datastore.ModelSet(infModel) - return + // List all InferenceModels with a matching ModelName. + var models v1alpha2.InferenceModelList + if err := c.List(ctx, &models, client.MatchingFields{modelNameKey: existing.Spec.ModelName}, client.InNamespace(c.PoolNamespacedName.Namespace)); err != nil { + return fmt.Errorf("listing models that match the modelName %s: %w", existing.Spec.ModelName, err) + } + if len(models.Items) == 0 { + // No other instances of InferenceModels with this ModelName exists. + return nil + } + + var oldest *v1alpha2.InferenceModel + for i := range models.Items { + m := &models.Items[i] + if m.Spec.ModelName != existing.Spec.ModelName || // The index should filter those out, but just in case! + m.Spec.PoolRef.Name != c.PoolNamespacedName.Name || // We don't care about other pools, we could setup an index on this too! + m.Name == existing.Name { // We don't care about the same object, it could be in the list if it was only marked for deletion, but not yet deleted. + continue + } + if oldest == nil || m.ObjectMeta.CreationTimestamp.Before(&oldest.ObjectMeta.CreationTimestamp) { + oldest = m + } } - loggerDefault.Info("Removing/Not adding InferenceModel", "modelName", infModel.Spec.ModelName) - // If we get here. The model is not relevant to this pool, remove. - c.Datastore.ModelDelete(infModel.Spec.ModelName) + if oldest != nil && c.Datastore.ModelSetIfOlder(oldest) { + logger.Info("InferenceModel replaced.", + "poolRef", oldest.Spec.PoolRef, + "modelName", oldest.Spec.ModelName, + "newInferenceModel", types.NamespacedName{Name: oldest.Name, Namespace: oldest.Namespace}) + } + + return nil +} + +func indexInferenceModelsByModelName(obj client.Object) []string { + m, ok := obj.(*v1alpha2.InferenceModel) + if !ok { + return nil + } + return []string{m.Spec.ModelName} } -func (c *InferenceModelReconciler) SetupWithManager(mgr ctrl.Manager) error { +func (c *InferenceModelReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { + // Create an index on ModelName for InferenceModel objects. + indexer := mgr.GetFieldIndexer() + if err := indexer.IndexField(ctx, &v1alpha2.InferenceModel{}, modelNameKey, indexInferenceModelsByModelName); err != nil { + return fmt.Errorf("setting index on ModelName for InferenceModel: %w", err) + } return ctrl.NewControllerManagedBy(mgr). For(&v1alpha2.InferenceModel{}). WithEventFilter(predicate.Funcs{ diff --git a/pkg/epp/controller/inferencemodel_reconciler_test.go b/pkg/epp/controller/inferencemodel_reconciler_test.go index cea7bf427..b9b664f24 100644 --- a/pkg/epp/controller/inferencemodel_reconciler_test.go +++ b/pkg/epp/controller/inferencemodel_reconciler_test.go @@ -18,302 +18,219 @@ package controller import ( "context" - "sync" "testing" + "github.com/google/go-cmp/cmp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) var ( - infModel1 = &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - ModelName: "fake model1", - PoolRef: v1alpha2.PoolObjectReference{Name: "test-pool"}, + pool = utiltest.MakeInferencePool("test-pool1").Namespace("ns1").ObjRef() + infModel1 = utiltest.MakeInferenceModel("model1"). + Namespace(pool.Namespace). + ModelName("fake model1"). + Criticality(v1alpha2.Standard). + CreationTimestamp(metav1.Unix(1000, 0)). + PoolName(pool.Name).ObjRef() + infModel1Pool2 = utiltest.MakeInferenceModel(infModel1.Name). + Namespace(infModel1.Namespace). + ModelName(infModel1.Spec.ModelName). + Criticality(*infModel1.Spec.Criticality). + CreationTimestamp(metav1.Unix(1001, 0)). + PoolName("test-pool2").ObjRef() + infModel1NS2 = utiltest.MakeInferenceModel(infModel1.Name). + Namespace("ns2"). + ModelName(infModel1.Spec.ModelName). + Criticality(*infModel1.Spec.Criticality). + CreationTimestamp(metav1.Unix(1002, 0)). + PoolName(pool.Name).ObjRef() + infModel1Critical = utiltest.MakeInferenceModel(infModel1.Name). + Namespace(infModel1.Namespace). + ModelName(infModel1.Spec.ModelName). + Criticality(v1alpha2.Critical). + CreationTimestamp(metav1.Unix(1003, 0)). + PoolName(pool.Name).ObjRef() + infModel1Deleted = utiltest.MakeInferenceModel(infModel1.Name). + Namespace(infModel1.Namespace). + ModelName(infModel1.Spec.ModelName). + CreationTimestamp(metav1.Unix(1004, 0)). + DeletionTimestamp(). + PoolName(pool.Name).ObjRef() + // Same ModelName, different object with newer creation timestamp + infModel1Newer = utiltest.MakeInferenceModel("model1-newer"). + Namespace(pool.Namespace). + ModelName("fake model1"). + Criticality(v1alpha2.Standard). + CreationTimestamp(metav1.Unix(1005, 0)). + PoolName(pool.Name).ObjRef() + // Same ModelName, different object with older creation timestamp + infModel1Older = utiltest.MakeInferenceModel("model1-older"). + Namespace(pool.Namespace). + ModelName("fake model1"). + Criticality(v1alpha2.Standard). + CreationTimestamp(metav1.Unix(999, 0)). + PoolName(pool.Name).ObjRef() + + infModel2 = utiltest.MakeInferenceModel("model2"). + Namespace(pool.Namespace). + ModelName("fake model2"). + CreationTimestamp(metav1.Unix(1000, 0)). + PoolName(pool.Name).ObjRef() + infModel2NS2 = utiltest.MakeInferenceModel(infModel2.Name). + Namespace("ns2"). + ModelName(infModel2.Spec.ModelName). + CreationTimestamp(metav1.Unix(1000, 0)). + PoolName(pool.Name).ObjRef() +) + +func TestInferenceModelReconciler(t *testing.T) { + tests := []struct { + name string + modelsInStore []*v1alpha2.InferenceModel + modelsInAPIServer []*v1alpha2.InferenceModel + model *v1alpha2.InferenceModel + incomingReq *types.NamespacedName + wantModels []*v1alpha2.InferenceModel + wantResult ctrl.Result + }{ + { + name: "Empty store, add new model", + model: infModel1, + wantModels: []*v1alpha2.InferenceModel{infModel1}, }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-service", + { + name: "Existing model changed pools", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1Pool2, + wantModels: []*v1alpha2.InferenceModel{}, }, - } - infModel1Modified = &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - ModelName: "fake model1", - PoolRef: v1alpha2.PoolObjectReference{Name: "test-poolio"}, + { + name: "Not found, delete existing model", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + incomingReq: &types.NamespacedName{Name: infModel1.Name, Namespace: infModel1.Namespace}, + wantModels: []*v1alpha2.InferenceModel{}, }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-service", + { + name: "Deletion timestamp set, delete existing model", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1Deleted, + wantModels: []*v1alpha2.InferenceModel{}, }, - } - infModel2 = &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - ModelName: "fake model", - PoolRef: v1alpha2.PoolObjectReference{Name: "test-pool"}, + { + name: "Model referencing a different pool, different pool name but same namespace", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1NS2, + wantModels: []*v1alpha2.InferenceModel{infModel1}, }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-service-2", + { + name: "Model referencing a different pool, same pool name but different namespace", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel2NS2, + wantModels: []*v1alpha2.InferenceModel{infModel1}, }, - } -) - -func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { - logger := logutil.NewTestLogger() - - tests := []struct { - name string - datastore datastore.Datastore - incomingService *v1alpha2.InferenceModel - wantInferenceModels *sync.Map - }{ { - name: "No Services registered; valid, new service incoming.", - datastore: datastore.NewFakeDatastore(nil, nil, &v1alpha2.InferencePool{ - Spec: v1alpha2.InferencePoolSpec{ - Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{"app": "vllm"}, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool", - ResourceVersion: "Old and boring", - }, - }), - - incomingService: infModel1, - wantInferenceModels: populateServiceMap(infModel1), + name: "Existing model changed pools, replaced with another", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1Pool2, + modelsInAPIServer: []*v1alpha2.InferenceModel{infModel1Newer}, + wantModels: []*v1alpha2.InferenceModel{infModel1Newer}, + }, + { + name: "Not found, delete existing model, replaced with another", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + incomingReq: &types.NamespacedName{Name: infModel1.Name, Namespace: infModel1.Namespace}, + modelsInAPIServer: []*v1alpha2.InferenceModel{infModel1Newer}, + wantModels: []*v1alpha2.InferenceModel{infModel1Newer}, + }, + { + name: "Deletion timestamp set, delete existing model, replaced with another", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1Deleted, + modelsInAPIServer: []*v1alpha2.InferenceModel{infModel1Newer}, + wantModels: []*v1alpha2.InferenceModel{infModel1Newer}, }, { - name: "Removing existing service.", - datastore: datastore.NewFakeDatastore(nil, populateServiceMap(infModel1), &v1alpha2.InferencePool{ - Spec: v1alpha2.InferencePoolSpec{ - Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{"app": "vllm"}, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool", - ResourceVersion: "Old and boring", - }, - }), - incomingService: infModel1Modified, - wantInferenceModels: populateServiceMap(), + name: "Older instance of the model observed", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1Older, + wantModels: []*v1alpha2.InferenceModel{infModel1Older}, }, { - name: "Unrelated service, do nothing.", - datastore: datastore.NewFakeDatastore(nil, populateServiceMap(infModel1), &v1alpha2.InferencePool{ - Spec: v1alpha2.InferencePoolSpec{ - Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{"app": "vllm"}, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool", - ResourceVersion: "Old and boring", - }, - }), - incomingService: &v1alpha2.InferenceModel{ - Spec: v1alpha2.InferenceModelSpec{ - ModelName: "fake model", - PoolRef: v1alpha2.PoolObjectReference{Name: "test-poolio"}, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "unrelated-service", - }, - }, - wantInferenceModels: populateServiceMap(infModel1), + name: "Model changed criticality", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel1Critical, + wantModels: []*v1alpha2.InferenceModel{infModel1Critical}, }, { - name: "Add to existing", - datastore: datastore.NewFakeDatastore(nil, populateServiceMap(infModel1), &v1alpha2.InferencePool{ - Spec: v1alpha2.InferencePoolSpec{ - Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{"app": "vllm"}, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool", - ResourceVersion: "Old and boring", - }, - }), - incomingService: infModel2, - wantInferenceModels: populateServiceMap(infModel1, infModel2), + name: "Model not found, no matching existing model to delete", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + incomingReq: &types.NamespacedName{Name: "non-existent-model", Namespace: pool.Namespace}, + wantModels: []*v1alpha2.InferenceModel{infModel1}, + }, + { + name: "Add to existing", + modelsInStore: []*v1alpha2.InferenceModel{infModel1}, + model: infModel2, + wantModels: []*v1alpha2.InferenceModel{infModel1, infModel2}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - pool, err := test.datastore.PoolGet() - if err != nil { - t.Fatalf("failed to get pool: %v", err) + // Create a fake client with no InferenceModel objects. + scheme := runtime.NewScheme() + _ = v1alpha2.AddToScheme(scheme) + initObjs := []client.Object{} + if test.model != nil { + initObjs = append(initObjs, test.model) } - reconciler := &InferenceModelReconciler{ - Datastore: test.datastore, - PoolNamespacedName: types.NamespacedName{Name: pool.Name}, + for _, m := range test.modelsInAPIServer { + initObjs = append(initObjs, m) } - reconciler.updateDatastore(logger, test.incomingService) - - test.wantInferenceModels.Range(func(k, v any) bool { - _, exist := test.datastore.ModelGet(k.(string)) - if !exist { - t.Fatalf("failed to get model %s", k) - } - return true - }) - }) - } -} - -func TestReconcile_ResourceNotFound(t *testing.T) { - // Set up the scheme. - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) - - // Create a fake client with no InferenceModel objects. - fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() - - // Create a minimal datastore. - datastore := datastore.NewFakeDatastore(nil, nil, &v1alpha2.InferencePool{ - ObjectMeta: metav1.ObjectMeta{Name: "test-pool"}, - }) - - // Create the reconciler. - reconciler := &InferenceModelReconciler{ - Client: fakeClient, - Scheme: scheme, - Record: record.NewFakeRecorder(10), - Datastore: datastore, - PoolNamespacedName: types.NamespacedName{Name: "test-pool"}, - } - - // Create a request for a non-existent resource. - req := ctrl.Request{NamespacedName: types.NamespacedName{Name: "non-existent-model", Namespace: "default"}} - - // Call Reconcile. - result, err := reconciler.Reconcile(context.Background(), req) - if err != nil { - t.Fatalf("expected no error when resource is not found, got %v", err) - } - - // Check that no requeue is requested. - if result.Requeue || result.RequeueAfter != 0 { - t.Errorf("expected no requeue, got %+v", result) - } -} - -func TestReconcile_ModelMarkedForDeletion(t *testing.T) { - // Set up the scheme. - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) - - // Create an InferenceModel object. - now := metav1.Now() - existingModel := &v1alpha2.InferenceModel{ - ObjectMeta: metav1.ObjectMeta{ - Name: "existing-model", - Namespace: "default", - DeletionTimestamp: &now, - Finalizers: []string{"finalizer"}, - }, - Spec: v1alpha2.InferenceModelSpec{ - ModelName: "fake-model", - PoolRef: v1alpha2.PoolObjectReference{Name: "test-pool"}, - }, - } - - // Create a fake client with the existing model. - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(existingModel).Build() - - // Create a minimal datastore. - datastore := datastore.NewFakeDatastore(nil, nil, &v1alpha2.InferencePool{ - ObjectMeta: metav1.ObjectMeta{Name: "test-pool"}, - }) - - // Create the reconciler. - reconciler := &InferenceModelReconciler{ - Client: fakeClient, - Scheme: scheme, - Record: record.NewFakeRecorder(10), - Datastore: datastore, - PoolNamespacedName: types.NamespacedName{Name: "test-pool", Namespace: "default"}, - } - - // Create a request for the existing resource. - req := ctrl.Request{NamespacedName: types.NamespacedName{Name: "existing-model", Namespace: "default"}} - - // Call Reconcile. - result, err := reconciler.Reconcile(context.Background(), req) - if err != nil { - t.Fatalf("expected no error when resource exists, got %v", err) - } - - // Check that no requeue is requested. - if result.Requeue || result.RequeueAfter != 0 { - t.Errorf("expected no requeue, got %+v", result) - } - - // Verify that the datastore was not updated. - if _, exist := datastore.ModelGet(existingModel.Spec.ModelName); exist { - t.Errorf("expected datastore to not contain model %q", existingModel.Spec.ModelName) - } -} - -func TestReconcile_ResourceExists(t *testing.T) { - // Set up the scheme. - scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(initObjs...). + WithIndex(&v1alpha2.InferenceModel{}, modelNameKey, indexInferenceModelsByModelName). + Build() - // Create an InferenceModel object. - existingModel := &v1alpha2.InferenceModel{ - ObjectMeta: metav1.ObjectMeta{ - Name: "existing-model", - Namespace: "default", - }, - Spec: v1alpha2.InferenceModelSpec{ - ModelName: "fake-model", - PoolRef: v1alpha2.PoolObjectReference{Name: "test-pool"}, - }, - } - - // Create a fake client with the existing model. - fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(existingModel).Build() - - // Create a minimal datastore. - datastore := datastore.NewFakeDatastore(nil, nil, &v1alpha2.InferencePool{ - ObjectMeta: metav1.ObjectMeta{Name: "test-pool"}, - }) - - // Create the reconciler. - reconciler := &InferenceModelReconciler{ - Client: fakeClient, - Scheme: scheme, - Record: record.NewFakeRecorder(10), - Datastore: datastore, - PoolNamespacedName: types.NamespacedName{Name: "test-pool", Namespace: "default"}, - } - - // Create a request for the existing resource. - req := ctrl.Request{NamespacedName: types.NamespacedName{Name: "existing-model", Namespace: "default"}} + datastore := datastore.NewFakeDatastore(nil, test.modelsInStore, pool) + reconciler := &InferenceModelReconciler{ + Client: fakeClient, + Scheme: scheme, + Record: record.NewFakeRecorder(10), + Datastore: datastore, + PoolNamespacedName: types.NamespacedName{Name: pool.Name, Namespace: pool.Namespace}, + } + if test.incomingReq == nil { + test.incomingReq = &types.NamespacedName{Name: test.model.Name, Namespace: test.model.Namespace} + } - // Call Reconcile. - result, err := reconciler.Reconcile(context.Background(), req) - if err != nil { - t.Fatalf("expected no error when resource exists, got %v", err) - } + // Call Reconcile. + result, err := reconciler.Reconcile(context.Background(), ctrl.Request{NamespacedName: *test.incomingReq}) + if err != nil { + t.Fatalf("expected no error when resource is not found, got %v", err) + } - // Check that no requeue is requested. - if result.Requeue || result.RequeueAfter != 0 { - t.Errorf("expected no requeue, got %+v", result) - } + if diff := cmp.Diff(result, test.wantResult); diff != "" { + t.Errorf("Unexpected result diff (+got/-want): %s", diff) + } - // Verify that the datastore was updated. - if _, exist := datastore.ModelGet(existingModel.Spec.ModelName); !exist { - t.Errorf("expected datastore to contain model %q", existingModel.Spec.ModelName) - } -} + if len(test.wantModels) != len(datastore.ModelGetAll()) { + t.Errorf("Unexpected; want: %d, got:%d", len(test.wantModels), len(datastore.ModelGetAll())) + } -func populateServiceMap(services ...*v1alpha2.InferenceModel) *sync.Map { - returnVal := &sync.Map{} + if diff := diffStore(datastore, diffStoreParams{wantPool: pool, wantModels: test.wantModels}); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } - for _, service := range services { - returnVal.Store(service.Spec.ModelName, service) + }) } - return returnVal } diff --git a/pkg/epp/controller/inferencepool_reconciler_test.go b/pkg/epp/controller/inferencepool_reconciler_test.go index 26b81d9a4..f35b8dc05 100644 --- a/pkg/epp/controller/inferencepool_reconciler_test.go +++ b/pkg/epp/controller/inferencepool_reconciler_test.go @@ -23,7 +23,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" @@ -32,42 +31,44 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" + utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) var ( selector_v1 = map[string]string{"app": "vllm_v1"} selector_v2 = map[string]string{"app": "vllm_v2"} - pool1 = &v1alpha2.InferencePool{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pool1", - Namespace: "pool1-ns", - }, - Spec: v1alpha2.InferencePoolSpec{ - Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{"app": "vllm_v1"}, - TargetPortNumber: 8080, - }, - } - pool2 = &v1alpha2.InferencePool{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pool2", - Namespace: "pool2-ns", - }, - } - pods = []corev1.Pod{ + pool1 = utiltest.MakeInferencePool("pool1"). + Namespace("pool1-ns"). + Selector(selector_v1). + TargetPortNumber(8080).ObjRef() + pool2 = utiltest.MakeInferencePool("pool2").Namespace("pool2-ns").ObjRef() + pods = []*corev1.Pod{ // Two ready pods matching pool1 - utiltesting.MakePod("pod1", "pool1-ns").Labels(selector_v1).ReadyCondition().Obj(), - utiltesting.MakePod("pod2", "pool1-ns").Labels(selector_v1).ReadyCondition().Obj(), + utiltest.MakePod("pod1"). + Namespace("pool1-ns"). + Labels(selector_v1).ReadyCondition().ObjRef(), + utiltest.MakePod("pod2"). + Namespace("pool1-ns"). + Labels(selector_v1). + ReadyCondition().ObjRef(), // A not ready pod matching pool1 - utiltesting.MakePod("pod3", "pool1-ns").Labels(selector_v1).Obj(), + utiltest.MakePod("pod3"). + Namespace("pool1-ns"). + Labels(selector_v1).ObjRef(), // A pod not matching pool1 namespace - utiltesting.MakePod("pod4", "pool2-ns").Labels(selector_v1).ReadyCondition().Obj(), + utiltest.MakePod("pod4"). + Namespace("pool2-ns"). + Labels(selector_v1). + ReadyCondition().ObjRef(), // A ready pod matching pool1 with a new selector - utiltesting.MakePod("pod5", "pool1-ns").Labels(selector_v2).ReadyCondition().Obj(), + utiltest.MakePod("pod5"). + Namespace("pool1-ns"). + Labels(selector_v2). + ReadyCondition().ObjRef(), } ) -func TestReconcile_InferencePoolReconciler(t *testing.T) { +func TestInferencePoolReconciler(t *testing.T) { // The best practice is to use table-driven tests, however in this scaenario it seems // more logical to do a single test with steps that depend on each other. @@ -79,7 +80,7 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) { // Create a fake client with the pool and the pods. initialObjects := []client.Object{pool1, pool2} for i := range pods { - initialObjects = append(initialObjects, &pods[i]) + initialObjects = append(initialObjects, pods[i]) } fakeClient := fake.NewClientBuilder(). WithScheme(scheme). @@ -98,11 +99,10 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffPool(datastore, pool1, []string{"pod1", "pod2"}); diff != "" { + if diff := diffStore(datastore, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } - // Step 2: update the pool selector to include more pods newPool1 := &v1alpha2.InferencePool{} if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { t.Errorf("Unexpected pool get error: %v", err) @@ -115,7 +115,7 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" { + if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -130,7 +130,7 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" { + if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } @@ -144,19 +144,42 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) { if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - if diff := diffPool(datastore, nil, []string{}); diff != "" { + if diff := diffStore(datastore, diffStoreParams{wantPods: []string{}}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } } -func diffPool(datastore datastore.Datastore, wantPool *v1alpha2.InferencePool, wantPods []string) string { +type diffStoreParams struct { + wantPool *v1alpha2.InferencePool + wantPods []string + wantModels []*v1alpha2.InferenceModel +} + +func diffStore(datastore datastore.Datastore, params diffStoreParams) string { gotPool, _ := datastore.PoolGet() - if diff := cmp.Diff(wantPool, gotPool); diff != "" { - return diff + if diff := cmp.Diff(params.wantPool, gotPool); diff != "" { + return "pool:" + diff + } + + // Default wantPods if not set because PodGetAll returns an empty slice when empty. + if params.wantPods == nil { + params.wantPods = []string{} } gotPods := []string{} for _, pm := range datastore.PodGetAll() { gotPods = append(gotPods, pm.NamespacedName.Name) } - return cmp.Diff(wantPods, gotPods, cmpopts.SortSlices(func(a, b string) bool { return a < b })) + if diff := cmp.Diff(params.wantPods, gotPods, cmpopts.SortSlices(func(a, b string) bool { return a < b })); diff != "" { + return "pods:" + diff + } + + // Default wantModels if not set because ModelGetAll returns an empty slice when empty. + if params.wantModels == nil { + params.wantModels = []*v1alpha2.InferenceModel{} + } + gotModels := datastore.ModelGetAll() + if diff := utiltest.DiffModelLists(params.wantModels, gotModels); diff != "" { + return "models:" + diff + } + return "" } diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 5b0c25c99..717d9f60e 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -75,7 +75,7 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} if !pod.DeletionTimestamp.IsZero() || !c.Datastore.PoolLabelsMatch(pod.Labels) || !podIsReady(pod) { - logger.V(logutil.DEFAULT).Info("Pod removed or not added", "name", namespacedName) + logger.V(logutil.DEBUG).Info("Pod removed or not added", "name", namespacedName) c.Datastore.PodDelete(namespacedName) } else { if c.Datastore.PodUpdateOrAddIfNotExist(pod) { diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index 8a39dbabd..575762130 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -18,13 +18,11 @@ package controller import ( "context" - "sync" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" @@ -33,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + utiltest "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) var ( @@ -42,8 +41,7 @@ var ( basePod11 = &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: "address-11", ScrapePath: "/metrics", ScrapePort: 8000}} ) -func TestUpdateDatastore_PodReconciler(t *testing.T) { - now := metav1.Now() +func TestPodReconciler(t *testing.T) { tests := []struct { name string datastore datastore.Datastore @@ -53,7 +51,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }{ { name: "Add new pod", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -61,28 +59,15 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }), - incomingPod: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: basePod3.NamespacedName.Name, - Labels: map[string]string{ - "some-key": "some-val", - }, - }, - Status: corev1.PodStatus{ - PodIP: basePod3.Address, - Conditions: []corev1.PodCondition{ - { - Type: corev1.PodReady, - Status: corev1.ConditionTrue, - }, - }, - }, - }, + incomingPod: utiltest.MakePod(basePod3.NamespacedName.Name). + Labels(map[string]string{"some-key": "some-val"}). + IP(basePod3.Address). + ReadyCondition().ObjRef(), wantPods: []datastore.Pod{basePod1.Pod, basePod2.Pod, basePod3.Pod}, }, { name: "Update pod1 address", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -90,28 +75,15 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }), - incomingPod: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: basePod11.NamespacedName.Name, - Labels: map[string]string{ - "some-key": "some-val", - }, - }, - Status: corev1.PodStatus{ - PodIP: basePod11.Address, - Conditions: []corev1.PodCondition{ - { - Type: corev1.PodReady, - Status: corev1.ConditionTrue, - }, - }, - }, - }, + incomingPod: utiltest.MakePod(basePod11.NamespacedName.Name). + Labels(map[string]string{"some-key": "some-val"}). + IP(basePod11.Address). + ReadyCondition().ObjRef(), wantPods: []datastore.Pod{basePod11.Pod, basePod2.Pod}, }, { name: "Delete pod with DeletionTimestamp", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -119,29 +91,15 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }), - incomingPod: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod1", - Labels: map[string]string{ - "some-key": "some-val", - }, - DeletionTimestamp: &now, - Finalizers: []string{"finalizer"}, - }, - Status: corev1.PodStatus{ - Conditions: []corev1.PodCondition{ - { - Type: corev1.PodReady, - Status: corev1.ConditionTrue, - }, - }, - }, - }, + incomingPod: utiltest.MakePod("pod1"). + Labels(map[string]string{"some-key": "some-val"}). + DeletionTimestamp(). + ReadyCondition().ObjRef(), wantPods: []datastore.Pod{basePod2.Pod}, }, { name: "Delete notfound pod", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -154,7 +112,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, { name: "New pod, not ready, valid selector", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -162,27 +120,13 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }), - incomingPod: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod3", - Labels: map[string]string{ - "some-key": "some-val", - }, - }, - Status: corev1.PodStatus{ - Conditions: []corev1.PodCondition{ - { - Type: corev1.PodReady, - Status: corev1.ConditionFalse, - }, - }, - }, - }, + incomingPod: utiltest.MakePod("pod3"). + Labels(map[string]string{"some-key": "some-val"}).ObjRef(), wantPods: []datastore.Pod{basePod1.Pod, basePod2.Pod}, }, { name: "Remove pod that does not match selector", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -190,27 +134,14 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }), - incomingPod: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod1", - Labels: map[string]string{ - "some-wrong-key": "some-val", - }, - }, - Status: corev1.PodStatus{ - Conditions: []corev1.PodCondition{ - { - Type: corev1.PodReady, - Status: corev1.ConditionTrue, - }, - }, - }, - }, + incomingPod: utiltest.MakePod("pod1"). + Labels(map[string]string{"some-wrong-key": "some-val"}). + ReadyCondition().ObjRef(), wantPods: []datastore.Pod{basePod2.Pod}, }, { name: "Remove pod that is not ready", - datastore: datastore.NewFakeDatastore(populateMap(basePod1, basePod2), nil, &v1alpha2.InferencePool{ + datastore: datastore.NewFakeDatastore([]*datastore.PodMetrics{basePod1, basePod2}, nil, &v1alpha2.InferencePool{ Spec: v1alpha2.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{ @@ -218,22 +149,9 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }), - incomingPod: &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod1", - Labels: map[string]string{ - "some-wrong-key": "some-val", - }, - }, - Status: corev1.PodStatus{ - Conditions: []corev1.PodCondition{ - { - Type: corev1.PodReady, - Status: corev1.ConditionFalse, - }, - }, - }, - }, + incomingPod: utiltest.MakePod("pod1"). + Labels(map[string]string{"some-wrong-key": "some-val"}). + ReadyCondition().ObjRef(), wantPods: []datastore.Pod{basePod2.Pod}, }, } @@ -274,11 +192,3 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }) } } - -func populateMap(pods ...*datastore.PodMetrics) *sync.Map { - newMap := &sync.Map{} - for _, pod := range pods { - newMap.Store(pod.NamespacedName, &datastore.PodMetrics{Pod: datastore.Pod{NamespacedName: pod.NamespacedName, Address: pod.Address, ScrapePort: pod.ScrapePort, ScrapePath: pod.ScrapePath}}) - } - return newMap -} diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index c5bbddcfd..71a93f6a6 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -32,6 +32,10 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +var ( + errPoolNotSynced = errors.New("InferencePool is not initialized in data store") +) + // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) type Datastore interface { // InferencePool operations @@ -41,9 +45,11 @@ type Datastore interface { PoolLabelsMatch(podLabels map[string]string) bool // InferenceModel operations - ModelSet(infModel *v1alpha2.InferenceModel) - ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) - ModelDelete(modelName string) + ModelSetIfOlder(infModel *v1alpha2.InferenceModel) bool + ModelGetByModelName(modelName string) (*v1alpha2.InferenceModel, bool) + ModelGetByObjName(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) + ModelDelete(namespacedName types.NamespacedName) + ModelGetAll() []*v1alpha2.InferenceModel // PodMetrics operations PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool @@ -61,22 +67,29 @@ type Datastore interface { func NewDatastore() Datastore { store := &datastore{ - poolMu: sync.RWMutex{}, - models: &sync.Map{}, - pods: &sync.Map{}, + poolMu: sync.RWMutex{}, + modelsMu: sync.RWMutex{}, + modelsByModelName: make(map[string]*v1alpha2.InferenceModel), + modelsByObjName: make(map[types.NamespacedName]*v1alpha2.InferenceModel), + pods: &sync.Map{}, } return store } // Used for test only -func NewFakeDatastore(pods, models *sync.Map, pool *v1alpha2.InferencePool) Datastore { +func NewFakeDatastore(pods []*PodMetrics, models []*v1alpha2.InferenceModel, pool *v1alpha2.InferencePool) Datastore { store := NewDatastore() - if pods != nil { - store.(*datastore).pods = pods + + for _, pod := range pods { + // Making a copy since in tests we may use the same global PodMetric across tests. + p := *pod + store.(*datastore).pods.Store(pod.NamespacedName, &p) } - if models != nil { - store.(*datastore).models = models + + for _, m := range models { + store.ModelSetIfOlder(m) } + if pool != nil { store.(*datastore).pool = pool } @@ -85,9 +98,13 @@ func NewFakeDatastore(pods, models *sync.Map, pool *v1alpha2.InferencePool) Data type datastore struct { // poolMu is used to synchronize access to the inferencePool. - poolMu sync.RWMutex - pool *v1alpha2.InferencePool - models *sync.Map + poolMu sync.RWMutex + pool *v1alpha2.InferencePool + modelsMu sync.RWMutex + // key: types.NamespacedName, value: *InferenceModel + modelsByObjName map[types.NamespacedName]*v1alpha2.InferenceModel + // key: InferenceModel.Spec.ModelName, value: *InferenceModel + modelsByModelName map[string]*v1alpha2.InferenceModel // key: types.NamespacedName, value: *PodMetrics pods *sync.Map } @@ -96,7 +113,10 @@ func (ds *datastore) Clear() { ds.poolMu.Lock() defer ds.poolMu.Unlock() ds.pool = nil - ds.models.Clear() + ds.modelsMu.Lock() + ds.modelsByModelName = make(map[string]*v1alpha2.InferenceModel) + ds.modelsByObjName = make(map[types.NamespacedName]*v1alpha2.InferenceModel) + ds.modelsMu.Unlock() ds.pods.Clear() } @@ -111,7 +131,7 @@ func (ds *datastore) PoolGet() (*v1alpha2.InferencePool, error) { ds.poolMu.RLock() defer ds.poolMu.RUnlock() if !ds.PoolHasSynced() { - return nil, errors.New("InferencePool is not initialized in data store") + return nil, errPoolNotSynced } return ds.pool, nil } @@ -123,26 +143,84 @@ func (ds *datastore) PoolHasSynced() bool { } func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { + ds.poolMu.RLock() + defer ds.poolMu.RUnlock() poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector) podSet := labels.Set(podLabels) return poolSelector.Matches(podSet) } // /// InferenceModel APIs /// -func (ds *datastore) ModelSet(infModel *v1alpha2.InferenceModel) { - ds.models.Store(infModel.Spec.ModelName, infModel) +func (ds *datastore) ModelSetIfOlder(infModel *v1alpha2.InferenceModel) bool { + ds.modelsMu.Lock() + defer ds.modelsMu.Unlock() + + // Check first if the existing model is older. + // One exception is if the incoming model object is the same, in which case, we should not + // check for creation timestamp since that means the object was re-created, and so we should override. + existing, exists := ds.modelsByModelName[infModel.Spec.ModelName] + if exists { + diffObj := infModel.Name != existing.Name || infModel.Namespace != existing.Namespace + if diffObj && existing.ObjectMeta.CreationTimestamp.Before(&infModel.ObjectMeta.CreationTimestamp) { + return false + } + } + + // Deleting the model first ensures that the two maps are always aligned. + namespacedName := types.NamespacedName{Name: infModel.Name, Namespace: infModel.Namespace} + ds.modelDeleteByObjName(namespacedName) + ds.modelDeleteByModelName(infModel.Spec.ModelName) + ds.modelsByModelName[infModel.Spec.ModelName] = infModel + ds.modelsByObjName[namespacedName] = infModel + return true } -func (ds *datastore) ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) { - infModel, ok := ds.models.Load(modelName) - if ok { - return infModel.(*v1alpha2.InferenceModel), true +func (ds *datastore) ModelGetByModelName(modelName string) (*v1alpha2.InferenceModel, bool) { + ds.modelsMu.RLock() + defer ds.modelsMu.RUnlock() + m, exists := ds.modelsByModelName[modelName] + return m, exists +} + +func (ds *datastore) ModelGetByObjName(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { + ds.modelsMu.RLock() + defer ds.modelsMu.RUnlock() + m, exists := ds.modelsByObjName[namespacedName] + return m, exists +} + +func (ds *datastore) ModelDelete(namespacedName types.NamespacedName) { + ds.modelsMu.Lock() + defer ds.modelsMu.Unlock() + ds.modelDeleteByObjName(namespacedName) +} + +func (ds *datastore) modelDeleteByObjName(namespacedName types.NamespacedName) { + infModel, ok := ds.modelsByObjName[namespacedName] + if !ok { + return } - return nil, false + delete(ds.modelsByObjName, namespacedName) + delete(ds.modelsByModelName, infModel.Spec.ModelName) } -func (ds *datastore) ModelDelete(modelName string) { - ds.models.Delete(modelName) +func (ds *datastore) modelDeleteByModelName(modelName string) { + infModel, ok := ds.modelsByModelName[modelName] + if !ok { + return + } + delete(ds.modelsByObjName, types.NamespacedName{Name: infModel.Name, Namespace: infModel.Namespace}) + delete(ds.modelsByModelName, modelName) +} + +func (ds *datastore) ModelGetAll() []*v1alpha2.InferenceModel { + ds.modelsMu.RLock() + defer ds.modelsMu.RUnlock() + res := []*v1alpha2.InferenceModel{} + for _, v := range ds.modelsByObjName { + res = append(res, v) + } + return res } // /// Pods/endpoints APIs /// diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 2af365413..e6c172d58 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -19,49 +19,175 @@ package datastore import ( "testing" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" ) -func TestHasSynced(t *testing.T) { +func TestPool(t *testing.T) { + pool1Selector := map[string]string{"app": "vllm_v1"} + pool1 := testutil.MakeInferencePool("pool1"). + Namespace("default"). + Selector(pool1Selector).ObjRef() tests := []struct { - name string - inferencePool *v1alpha2.InferencePool - hasSynced bool + name string + inferencePool *v1alpha2.InferencePool + labels map[string]string + wantSynced bool + wantPool *v1alpha2.InferencePool + wantErr error + wantLabelsMatch bool }{ { - name: "Ready when InferencePool exists in data store", - inferencePool: &v1alpha2.InferencePool{ - ObjectMeta: v1.ObjectMeta{ - Name: "test-pool", - Namespace: "default", - }, - }, - hasSynced: true, + name: "Ready when InferencePool exists in data store", + inferencePool: pool1, + labels: pool1Selector, + wantSynced: true, + wantPool: pool1, + wantLabelsMatch: true, }, { - name: "Not ready when InferencePool is nil in data store", - inferencePool: nil, - hasSynced: false, + name: "Labels not matched", + inferencePool: pool1, + labels: map[string]string{"app": "vllm_v2"}, + wantSynced: true, + wantPool: pool1, + wantLabelsMatch: false, + }, + { + name: "Not ready when InferencePool is nil in data store", + wantErr: errPoolNotSynced, + wantSynced: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { datastore := NewDatastore() - // Set the inference pool - if tt.inferencePool != nil { - datastore.PoolSet(tt.inferencePool) + datastore.PoolSet(tt.inferencePool) + gotPool, gotErr := datastore.PoolGet() + if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { + t.Errorf("Unexpected error diff (+got/-want): %s", diff) + } + if diff := cmp.Diff(tt.wantPool, gotPool); diff != "" { + t.Errorf("Unexpected pool diff (+got/-want): %s", diff) } - // Check if the data store has been initialized - hasSynced := datastore.PoolHasSynced() - if hasSynced != tt.hasSynced { - t.Errorf("IsInitialized() = %v, want %v", hasSynced, tt.hasSynced) + gotSynced := datastore.PoolHasSynced() + if diff := cmp.Diff(tt.wantSynced, gotSynced); diff != "" { + t.Errorf("Unexpected synced diff (+got/-want): %s", diff) + } + if tt.labels != nil { + gotLabelsMatch := datastore.PoolLabelsMatch(tt.labels) + if diff := cmp.Diff(tt.wantLabelsMatch, gotLabelsMatch); diff != "" { + t.Errorf("Unexpected labels match diff (+got/-want): %s", diff) + } } }) } } +func TestModel(t *testing.T) { + chatModel := "chat" + tsModel := "tweet-summary" + model1ts := testutil.MakeInferenceModel("model1"). + CreationTimestamp(metav1.Unix(1000, 0)). + ModelName(tsModel).ObjRef() + // Same model name as model1ts, different object name. + model2ts := testutil.MakeInferenceModel("model2"). + CreationTimestamp(metav1.Unix(1001, 0)). + ModelName(tsModel).ObjRef() + // Same model name as model1ts, newer timestamp + model1tsNewer := testutil.MakeInferenceModel("model1"). + CreationTimestamp(metav1.Unix(1002, 0)). + Criticality(v1alpha2.Critical). + ModelName(tsModel).ObjRef() + model2tsNewer := testutil.MakeInferenceModel("model2"). + CreationTimestamp(metav1.Unix(1003, 0)). + ModelName(tsModel).ObjRef() + // Same object name as model2ts, different model name. + model2chat := testutil.MakeInferenceModel(model2ts.Name). + CreationTimestamp(metav1.Unix(1005, 0)). + ModelName(chatModel).ObjRef() + + ds := NewDatastore() + dsImpl := ds.(*datastore) + + // Step 1: add model1 with tweet-summary as modelName. + ds.ModelSetIfOlder(model1ts) + if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model1ts}); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } + + // Step 2: set model1 with the same modelName, but with criticality set and newer creation timestamp, should update. + ds.ModelSetIfOlder(model1tsNewer) + if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model1tsNewer}); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } + + // Step 3: set model2 with the same modelName, but newer creation timestamp, should not update. + ds.ModelSetIfOlder(model2tsNewer) + if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model1tsNewer}); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } + + // Step 4: set model2 with the same modelName, but older creation timestamp, should update. + ds.ModelSetIfOlder(model2ts) + if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model2ts}); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } + + // Step 5: set model2 updated with a new modelName, should update modelName. + ds.ModelSetIfOlder(model2chat) + if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model2chat}); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } + + // Step 6: set model1 with the tweet-summary modelName, both models should exist. + ds.ModelSetIfOlder(model1ts) + if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model2chat, model1ts}); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } + + // Step 7: getting the models by model name, chat -> model2; tweet-summary -> model1 + gotChat, exists := ds.ModelGetByModelName(chatModel) + if !exists { + t.Error("Chat model should exist!") + } + if diff := cmp.Diff(model2chat, gotChat); diff != "" { + t.Errorf("Unexpected chat model diff: %s", diff) + } + gotSummary, exists := ds.ModelGetByModelName(tsModel) + if !exists { + t.Error("Summary model should exist!") + } + if diff := cmp.Diff(model1ts, gotSummary); diff != "" { + t.Errorf("Unexpected summary model diff: %s", diff) + } + + // Step 6: delete model1, summary model should not exist. + ds.ModelDelete(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) + _, exists = ds.ModelGetByModelName(tsModel) + if exists { + t.Error("Summary model should not exist!") + } + +} + +func diffModelMaps(ds *datastore, want []*v1alpha2.InferenceModel) string { + byObjName := ds.ModelGetAll() + byModelName := []*v1alpha2.InferenceModel{} + for _, v := range ds.modelsByModelName { + byModelName = append(byModelName, v) + } + if diff := testutil.DiffModelLists(byObjName, byModelName); diff != "" { + return "Inconsistent maps diff: " + diff + } + return testutil.DiffModelLists(want, byObjName) +} + func TestRandomWeightedDraw(t *testing.T) { logger := logutil.NewTestLogger() tests := []struct { diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index c6cfdda29..32062da33 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -64,7 +64,7 @@ func (s *Server) HandleRequestBody( // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. // This might be a security risk in the future where adapters not registered in the InferenceModel // are able to be requested by using their distinct name. - modelObj, exist := s.datastore.ModelGet(model) + modelObj, exist := s.datastore.ModelGetByModelName(model) if !exist { return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)} } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 6e6b68b1a..f3d9b6ac0 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -85,7 +85,7 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { } // SetupWithManager sets up the runner with the given manager. -func (r *ExtProcServerRunner) SetupWithManager(mgr ctrl.Manager) error { +func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { // Create the controllers and register them with the manager if err := (&controller.InferencePoolReconciler{ Datastore: r.Datastore, @@ -109,7 +109,7 @@ func (r *ExtProcServerRunner) SetupWithManager(mgr ctrl.Manager) error { Namespace: r.PoolNamespace, }, Record: mgr.GetEventRecorderFor("InferenceModel"), - }).SetupWithManager(mgr); err != nil { + }).SetupWithManager(ctx, mgr); err != nil { return fmt.Errorf("failed setting up InferenceModelReconciler: %w", err) } diff --git a/pkg/epp/test/utils.go b/pkg/epp/test/utils.go index 6a75ed2ff..a916bda2e 100644 --- a/pkg/epp/test/utils.go +++ b/pkg/epp/test/utils.go @@ -53,14 +53,15 @@ func StartExtProc( pmc := &backend.FakePodMetricsClient{Res: pms} datastore := datastore.NewDatastore() for _, m := range models { - datastore.ModelSet(m) + datastore.ModelSetIfOlder(m) } for _, pm := range pods { - pod := utiltesting.MakePod(pm.NamespacedName.Name, pm.NamespacedName.Namespace). + pod := utiltesting.MakePod(pm.NamespacedName.Name). + Namespace(pm.NamespacedName.Namespace). ReadyCondition(). IP(pm.Address). - Obj() - datastore.PodUpdateOrAddIfNotExist(&pod) + ObjRef() + datastore.PodUpdateOrAddIfNotExist(pod) datastore.PodUpdateMetricsIfExist(pm.NamespacedName, &pm.Metrics) } pp := backend.NewProvider(pmc, datastore) diff --git a/pkg/epp/util/testing/diff.go b/pkg/epp/util/testing/diff.go new file mode 100644 index 000000000..34b0b8caf --- /dev/null +++ b/pkg/epp/util/testing/diff.go @@ -0,0 +1,27 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" +) + +func DiffModelLists(want, got []*v1alpha2.InferenceModel) string { + return cmp.Diff(want, got, cmpopts.SortSlices(func(a, b *v1alpha2.InferenceModel) bool { return a.Name < b.Name })) +} diff --git a/pkg/epp/util/testing/wrappers.go b/pkg/epp/util/testing/wrappers.go index 7c9a29394..bfcf2690c 100644 --- a/pkg/epp/util/testing/wrappers.go +++ b/pkg/epp/util/testing/wrappers.go @@ -19,6 +19,7 @@ package testing import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" ) // PodWrapper wraps a Pod. @@ -27,12 +28,11 @@ type PodWrapper struct { } // MakePod creates a wrapper for a Pod. -func MakePod(podName, ns string) *PodWrapper { +func MakePod(podName string) *PodWrapper { return &PodWrapper{ corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: podName, - Namespace: ns, + Name: podName, }, Spec: corev1.PodSpec{}, Status: corev1.PodStatus{}, @@ -40,6 +40,11 @@ func MakePod(podName, ns string) *PodWrapper { } } +func (p *PodWrapper) Namespace(ns string) *PodWrapper { + p.ObjectMeta.Namespace = ns + return p +} + // Labels sets the pod labels. func (p *PodWrapper) Labels(labels map[string]string) *PodWrapper { p.ObjectMeta.Labels = labels @@ -60,7 +65,109 @@ func (p *PodWrapper) IP(ip string) *PodWrapper { return p } +func (p *PodWrapper) DeletionTimestamp() *PodWrapper { + now := metav1.Now() + p.ObjectMeta.DeletionTimestamp = &now + p.ObjectMeta.Finalizers = []string{"finalizer"} + return p +} + // Obj returns the wrapped Pod. -func (p *PodWrapper) Obj() corev1.Pod { - return p.Pod +func (p *PodWrapper) ObjRef() *corev1.Pod { + return &p.Pod +} + +// InferenceModelWrapper wraps an InferenceModel. +type InferenceModelWrapper struct { + v1alpha2.InferenceModel +} + +// MakeInferenceModel creates a wrapper for a InferenceModel. +func MakeInferenceModel(name string) *InferenceModelWrapper { + return &InferenceModelWrapper{ + v1alpha2.InferenceModel{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + Spec: v1alpha2.InferenceModelSpec{}, + }, + } +} + +func (m *InferenceModelWrapper) Namespace(ns string) *InferenceModelWrapper { + m.ObjectMeta.Namespace = ns + return m +} + +// Obj returns the wrapped InferenceModel. +func (m *InferenceModelWrapper) ObjRef() *v1alpha2.InferenceModel { + return &m.InferenceModel +} + +func (m *InferenceModelWrapper) ModelName(modelName string) *InferenceModelWrapper { + m.Spec.ModelName = modelName + return m +} + +func (m *InferenceModelWrapper) PoolName(poolName string) *InferenceModelWrapper { + m.Spec.PoolRef = v1alpha2.PoolObjectReference{Name: poolName} + return m +} + +func (m *InferenceModelWrapper) Criticality(criticality v1alpha2.Criticality) *InferenceModelWrapper { + m.Spec.Criticality = &criticality + return m +} + +func (m *InferenceModelWrapper) DeletionTimestamp() *InferenceModelWrapper { + now := metav1.Now() + m.ObjectMeta.DeletionTimestamp = &now + m.ObjectMeta.Finalizers = []string{"finalizer"} + return m +} + +func (m *InferenceModelWrapper) CreationTimestamp(t metav1.Time) *InferenceModelWrapper { + m.ObjectMeta.CreationTimestamp = t + return m +} + +// InferencePoolWrapper wraps an InferencePool. +type InferencePoolWrapper struct { + v1alpha2.InferencePool +} + +// MakeInferencePool creates a wrapper for a InferencePool. +func MakeInferencePool(name string) *InferencePoolWrapper { + return &InferencePoolWrapper{ + v1alpha2.InferencePool{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + Spec: v1alpha2.InferencePoolSpec{}, + }, + } +} + +func (m *InferencePoolWrapper) Namespace(ns string) *InferencePoolWrapper { + m.ObjectMeta.Namespace = ns + return m +} + +func (m *InferencePoolWrapper) Selector(selector map[string]string) *InferencePoolWrapper { + s := make(map[v1alpha2.LabelKey]v1alpha2.LabelValue) + for k, v := range selector { + s[v1alpha2.LabelKey(k)] = v1alpha2.LabelValue(v) + } + m.Spec.Selector = s + return m +} + +func (m *InferencePoolWrapper) TargetPortNumber(p int32) *InferencePoolWrapper { + m.Spec.TargetPortNumber = p + return m +} + +// Obj returns the wrapped InferencePool. +func (m *InferencePoolWrapper) ObjRef() *v1alpha2.InferencePool { + return &m.InferencePool } diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index 14ee738f3..3d068c9f7 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -245,11 +245,6 @@ func createModelServer(k8sClient client.Client, secretPath, deployPath string) { // Wait for the deployment to be available. testutils.DeploymentAvailable(ctx, k8sClient, deploy, modelReadyTimeout, interval) - - // Wait for the service to exist. - testutils.EventuallyExists(ctx, func() error { - return k8sClient.Get(ctx, types.NamespacedName{Namespace: nsName, Name: modelServerName}, &corev1.Service{}) - }, existsTimeout, interval) } // createEnvoy creates the envoy proxy resources used for testing from the given filePath. diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index 85c49913a..dc64758ea 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -360,11 +360,12 @@ func setUpHermeticServer(podMetrics []*datastore.PodMetrics) (client extProcPb.E go func() { serverRunner.Datastore.PodDeleteAll() for _, pm := range podMetrics { - pod := utiltesting.MakePod(pm.NamespacedName.Name, pm.NamespacedName.Namespace). + pod := utiltesting.MakePod(pm.NamespacedName.Name). + Namespace(pm.NamespacedName.Namespace). ReadyCondition(). IP(pm.Address). - Obj() - serverRunner.Datastore.PodUpdateOrAddIfNotExist(&pod) + ObjRef() + serverRunner.Datastore.PodUpdateOrAddIfNotExist(pod) serverRunner.Datastore.PodUpdateMetricsIfExist(pm.NamespacedName, &pm.Metrics) } serverRunner.Provider = backend.NewProvider(pmc, serverRunner.Datastore) @@ -429,7 +430,7 @@ func BeforeSuit(t *testing.T) func() { serverRunner.Datastore = datastore.NewDatastore() serverRunner.SecureServing = false - if err := serverRunner.SetupWithManager(mgr); err != nil { + if err := serverRunner.SetupWithManager(context.Background(), mgr); err != nil { logutil.Fatal(logger, err, "Failed to setup server runner") } @@ -475,7 +476,7 @@ func BeforeSuit(t *testing.T) func() { } assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, modelExist := serverRunner.Datastore.ModelGet("my-model") + _, modelExist := serverRunner.Datastore.ModelGetByModelName("my-model") synced := serverRunner.Datastore.PoolHasSynced() && modelExist assert.True(t, synced, "Timeout waiting for the pool and models to sync") }, 10*time.Second, 10*time.Millisecond) From 119cee0c4385223d46174146d48ca75a8796e312 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Wed, 26 Feb 2025 02:46:20 +0000 Subject: [PATCH 2/5] Convert unit test to a table --- pkg/epp/datastore/datastore_test.go | 171 ++++++++++++++++++---------- 1 file changed, 112 insertions(+), 59 deletions(-) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index e6c172d58..793f8ee1a 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -89,7 +89,7 @@ func TestPool(t *testing.T) { } } -func TestModel(t *testing.T) { +func TestModel1(t *testing.T) { chatModel := "chat" tsModel := "tweet-summary" model1ts := testutil.MakeInferenceModel("model1"). @@ -112,68 +112,121 @@ func TestModel(t *testing.T) { CreationTimestamp(metav1.Unix(1005, 0)). ModelName(chatModel).ObjRef() - ds := NewDatastore() - dsImpl := ds.(*datastore) - - // Step 1: add model1 with tweet-summary as modelName. - ds.ModelSetIfOlder(model1ts) - if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model1ts}); diff != "" { - t.Errorf("Unexpected models diff: %s", diff) - } - - // Step 2: set model1 with the same modelName, but with criticality set and newer creation timestamp, should update. - ds.ModelSetIfOlder(model1tsNewer) - if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model1tsNewer}); diff != "" { - t.Errorf("Unexpected models diff: %s", diff) - } - - // Step 3: set model2 with the same modelName, but newer creation timestamp, should not update. - ds.ModelSetIfOlder(model2tsNewer) - if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model1tsNewer}); diff != "" { - t.Errorf("Unexpected models diff: %s", diff) - } - - // Step 4: set model2 with the same modelName, but older creation timestamp, should update. - ds.ModelSetIfOlder(model2ts) - if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model2ts}); diff != "" { - t.Errorf("Unexpected models diff: %s", diff) - } - - // Step 5: set model2 updated with a new modelName, should update modelName. - ds.ModelSetIfOlder(model2chat) - if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model2chat}); diff != "" { - t.Errorf("Unexpected models diff: %s", diff) - } - - // Step 6: set model1 with the tweet-summary modelName, both models should exist. - ds.ModelSetIfOlder(model1ts) - if diff := diffModelMaps(dsImpl, []*v1alpha2.InferenceModel{model2chat, model1ts}); diff != "" { - t.Errorf("Unexpected models diff: %s", diff) - } + tests := []struct { + name string + existingModels []*v1alpha2.InferenceModel + op func(ds Datastore) bool + wantOpResult bool + wantModels []*v1alpha2.InferenceModel + }{ + { + name: "Add model1 with tweet-summary as modelName", + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model1ts) + }, + wantModels: []*v1alpha2.InferenceModel{model1ts}, + wantOpResult: true, + }, + { + name: "Set model1 with the same modelName, but with diff criticality and newer creation timestamp, should update.", + existingModels: []*v1alpha2.InferenceModel{model1ts}, + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model1tsNewer) + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model1tsNewer}, + }, + { + name: "set model2 with the same modelName, but newer creation timestamp, should not update.", + existingModels: []*v1alpha2.InferenceModel{model1tsNewer}, + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model2tsNewer) + }, + wantOpResult: false, + wantModels: []*v1alpha2.InferenceModel{model1tsNewer}, + }, + { + name: "Set model2 with the same modelName, but older creation timestamp, should update", + existingModels: []*v1alpha2.InferenceModel{model1tsNewer}, + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model2ts) + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model2ts}, + }, + { + name: "Set model2 updated with a new modelName, should update modelName", + existingModels: []*v1alpha2.InferenceModel{model2ts}, + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model2chat) + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model2chat}, + }, + { + name: "Set model1 with the tweet-summary modelName, both models should exist", + existingModels: []*v1alpha2.InferenceModel{model2chat}, + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model1ts) + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + }, + { + name: "Set model1 with the tweet-summary modelName, both models should exist", + existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + op: func(ds Datastore) bool { + return ds.ModelSetIfOlder(model1ts) + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + }, + { + name: "Getting by model name, chat -> model2", + existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + op: func(ds Datastore) bool { + gotChat, exists := ds.ModelGetByModelName(chatModel) + return exists && cmp.Diff(model2chat, gotChat) == "" + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + }, + { + name: "Getting by obj name, model1 -> tweet-summary", + existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + op: func(ds Datastore) bool { + got, exists := ds.ModelGetByObjName(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) + return exists && cmp.Diff(model1ts, got) == "" + }, + wantOpResult: true, + wantModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + }, + { + name: "Getting by model name, chat -> model2", + existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, + op: func(ds Datastore) bool { + ds.ModelDelete(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) + _, exists := ds.ModelGetByModelName(tsModel) + return exists - // Step 7: getting the models by model name, chat -> model2; tweet-summary -> model1 - gotChat, exists := ds.ModelGetByModelName(chatModel) - if !exists { - t.Error("Chat model should exist!") - } - if diff := cmp.Diff(model2chat, gotChat); diff != "" { - t.Errorf("Unexpected chat model diff: %s", diff) - } - gotSummary, exists := ds.ModelGetByModelName(tsModel) - if !exists { - t.Error("Summary model should exist!") - } - if diff := cmp.Diff(model1ts, gotSummary); diff != "" { - t.Errorf("Unexpected summary model diff: %s", diff) + }, + wantOpResult: false, + wantModels: []*v1alpha2.InferenceModel{model2chat}, + }, } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ds := NewFakeDatastore(nil, test.existingModels, nil) + gotOpResult := test.op(ds) + if gotOpResult != test.wantOpResult { + t.Errorf("Unexpected operation result, want: %v, got: %v", test.wantOpResult, gotOpResult) + } + if diff := diffModelMaps(ds.(*datastore), test.wantModels); diff != "" { + t.Errorf("Unexpected models diff: %s", diff) + } - // Step 6: delete model1, summary model should not exist. - ds.ModelDelete(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) - _, exists = ds.ModelGetByModelName(tsModel) - if exists { - t.Error("Summary model should not exist!") + }) } - } func diffModelMaps(ds *datastore, want []*v1alpha2.InferenceModel) string { From f307bc525c8b6d23773b42bd41b4dcff71c83c6a Mon Sep 17 00:00:00 2001 From: ahg-g Date: Wed, 26 Feb 2025 04:30:43 +0000 Subject: [PATCH 3/5] remove the dual map for the models store, and rely on linear search when looking up the model by object name --- .../controller/inferencemodel_reconciler.go | 8 +- pkg/epp/datastore/datastore.go | 116 +++++++----------- pkg/epp/datastore/datastore_test.go | 39 ++---- pkg/epp/handlers/request.go | 2 +- test/integration/hermetic_test.go | 2 +- 5 files changed, 58 insertions(+), 109 deletions(-) diff --git a/pkg/epp/controller/inferencemodel_reconciler.go b/pkg/epp/controller/inferencemodel_reconciler.go index 778b8e3c2..72a11a4af 100644 --- a/pkg/epp/controller/inferencemodel_reconciler.go +++ b/pkg/epp/controller/inferencemodel_reconciler.go @@ -85,17 +85,15 @@ func (c *InferenceModelReconciler) Reconcile(ctx context.Context, req ctrl.Reque func (c *InferenceModelReconciler) handleModelDeleted(ctx context.Context, req types.NamespacedName) error { logger := log.FromContext(ctx) - // We will lookup the modelName associated with this object to search for - // other instance referencing the same ModelName if exist to store the oldest in + // We will lookup and delete the modelName associated with this object, and search for + // other instances referencing the same modelName if exist, and store the oldest in // its place. This ensures that the InferenceModel with the oldest creation // timestamp is active. - existing, exists := c.Datastore.ModelGetByObjName(req) + existing, exists := c.Datastore.ModelDelete(req) if !exists { // No entry exists in the first place, nothing to do. return nil } - // Delete the internal object, it may be replaced with another version below. - c.Datastore.ModelDelete(req) logger.Info("InferenceModel removed from datastore", "poolRef", existing.Spec.PoolRef, "modelName", existing.Spec.ModelName) // List all InferenceModels with a matching ModelName. diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 71a93f6a6..3c7001d90 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -46,9 +46,8 @@ type Datastore interface { // InferenceModel operations ModelSetIfOlder(infModel *v1alpha2.InferenceModel) bool - ModelGetByModelName(modelName string) (*v1alpha2.InferenceModel, bool) - ModelGetByObjName(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) - ModelDelete(namespacedName types.NamespacedName) + ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) + ModelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) ModelGetAll() []*v1alpha2.InferenceModel // PodMetrics operations @@ -67,11 +66,9 @@ type Datastore interface { func NewDatastore() Datastore { store := &datastore{ - poolMu: sync.RWMutex{}, - modelsMu: sync.RWMutex{}, - modelsByModelName: make(map[string]*v1alpha2.InferenceModel), - modelsByObjName: make(map[types.NamespacedName]*v1alpha2.InferenceModel), - pods: &sync.Map{}, + poolAndModelsMu: sync.RWMutex{}, + models: make(map[string]*v1alpha2.InferenceModel), + pods: &sync.Map{}, } return store } @@ -97,39 +94,33 @@ func NewFakeDatastore(pods []*PodMetrics, models []*v1alpha2.InferenceModel, poo } type datastore struct { - // poolMu is used to synchronize access to the inferencePool. - poolMu sync.RWMutex - pool *v1alpha2.InferencePool - modelsMu sync.RWMutex - // key: types.NamespacedName, value: *InferenceModel - modelsByObjName map[types.NamespacedName]*v1alpha2.InferenceModel + // poolAndModelsMu is used to synchronize access to pool and the models map. + poolAndModelsMu sync.RWMutex + pool *v1alpha2.InferencePool // key: InferenceModel.Spec.ModelName, value: *InferenceModel - modelsByModelName map[string]*v1alpha2.InferenceModel + models map[string]*v1alpha2.InferenceModel // key: types.NamespacedName, value: *PodMetrics pods *sync.Map } func (ds *datastore) Clear() { - ds.poolMu.Lock() - defer ds.poolMu.Unlock() + ds.poolAndModelsMu.Lock() + defer ds.poolAndModelsMu.Unlock() ds.pool = nil - ds.modelsMu.Lock() - ds.modelsByModelName = make(map[string]*v1alpha2.InferenceModel) - ds.modelsByObjName = make(map[types.NamespacedName]*v1alpha2.InferenceModel) - ds.modelsMu.Unlock() + ds.models = make(map[string]*v1alpha2.InferenceModel) ds.pods.Clear() } // /// InferencePool APIs /// func (ds *datastore) PoolSet(pool *v1alpha2.InferencePool) { - ds.poolMu.Lock() - defer ds.poolMu.Unlock() + ds.poolAndModelsMu.Lock() + defer ds.poolAndModelsMu.Unlock() ds.pool = pool } func (ds *datastore) PoolGet() (*v1alpha2.InferencePool, error) { - ds.poolMu.RLock() - defer ds.poolMu.RUnlock() + ds.poolAndModelsMu.RLock() + defer ds.poolAndModelsMu.RUnlock() if !ds.PoolHasSynced() { return nil, errPoolNotSynced } @@ -137,14 +128,14 @@ func (ds *datastore) PoolGet() (*v1alpha2.InferencePool, error) { } func (ds *datastore) PoolHasSynced() bool { - ds.poolMu.RLock() - defer ds.poolMu.RUnlock() + ds.poolAndModelsMu.RLock() + defer ds.poolAndModelsMu.RUnlock() return ds.pool != nil } func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { - ds.poolMu.RLock() - defer ds.poolMu.RUnlock() + ds.poolAndModelsMu.RLock() + defer ds.poolAndModelsMu.RUnlock() poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector) podSet := labels.Set(podLabels) return poolSelector.Matches(podSet) @@ -152,72 +143,53 @@ func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { // /// InferenceModel APIs /// func (ds *datastore) ModelSetIfOlder(infModel *v1alpha2.InferenceModel) bool { - ds.modelsMu.Lock() - defer ds.modelsMu.Unlock() + ds.poolAndModelsMu.Lock() + defer ds.poolAndModelsMu.Unlock() // Check first if the existing model is older. // One exception is if the incoming model object is the same, in which case, we should not // check for creation timestamp since that means the object was re-created, and so we should override. - existing, exists := ds.modelsByModelName[infModel.Spec.ModelName] + existing, exists := ds.models[infModel.Spec.ModelName] if exists { diffObj := infModel.Name != existing.Name || infModel.Namespace != existing.Namespace if diffObj && existing.ObjectMeta.CreationTimestamp.Before(&infModel.ObjectMeta.CreationTimestamp) { return false } } - - // Deleting the model first ensures that the two maps are always aligned. - namespacedName := types.NamespacedName{Name: infModel.Name, Namespace: infModel.Namespace} - ds.modelDeleteByObjName(namespacedName) - ds.modelDeleteByModelName(infModel.Spec.ModelName) - ds.modelsByModelName[infModel.Spec.ModelName] = infModel - ds.modelsByObjName[namespacedName] = infModel + // Delete the model + ds.modelDelete(types.NamespacedName{Name: infModel.Name, Namespace: infModel.Namespace}) + ds.models[infModel.Spec.ModelName] = infModel return true } -func (ds *datastore) ModelGetByModelName(modelName string) (*v1alpha2.InferenceModel, bool) { - ds.modelsMu.RLock() - defer ds.modelsMu.RUnlock() - m, exists := ds.modelsByModelName[modelName] - return m, exists -} - -func (ds *datastore) ModelGetByObjName(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { - ds.modelsMu.RLock() - defer ds.modelsMu.RUnlock() - m, exists := ds.modelsByObjName[namespacedName] +func (ds *datastore) ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) { + ds.poolAndModelsMu.RLock() + defer ds.poolAndModelsMu.RUnlock() + m, exists := ds.models[modelName] return m, exists } -func (ds *datastore) ModelDelete(namespacedName types.NamespacedName) { - ds.modelsMu.Lock() - defer ds.modelsMu.Unlock() - ds.modelDeleteByObjName(namespacedName) -} - -func (ds *datastore) modelDeleteByObjName(namespacedName types.NamespacedName) { - infModel, ok := ds.modelsByObjName[namespacedName] - if !ok { - return +func (ds *datastore) modelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { + for _, m := range ds.models { + if m.Name == namespacedName.Name && m.Namespace == namespacedName.Namespace { + delete(ds.models, m.Spec.ModelName) + return m, true + } } - delete(ds.modelsByObjName, namespacedName) - delete(ds.modelsByModelName, infModel.Spec.ModelName) + return nil, false } -func (ds *datastore) modelDeleteByModelName(modelName string) { - infModel, ok := ds.modelsByModelName[modelName] - if !ok { - return - } - delete(ds.modelsByObjName, types.NamespacedName{Name: infModel.Name, Namespace: infModel.Namespace}) - delete(ds.modelsByModelName, modelName) +func (ds *datastore) ModelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { + ds.poolAndModelsMu.Lock() + defer ds.poolAndModelsMu.Unlock() + return ds.modelDelete(namespacedName) } func (ds *datastore) ModelGetAll() []*v1alpha2.InferenceModel { - ds.modelsMu.RLock() - defer ds.modelsMu.RUnlock() + ds.poolAndModelsMu.RLock() + defer ds.poolAndModelsMu.RUnlock() res := []*v1alpha2.InferenceModel{} - for _, v := range ds.modelsByObjName { + for _, v := range ds.models { res = append(res, v) } return res diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 793f8ee1a..972025458 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -89,7 +89,7 @@ func TestPool(t *testing.T) { } } -func TestModel1(t *testing.T) { +func TestModel(t *testing.T) { chatModel := "chat" tsModel := "tweet-summary" model1ts := testutil.MakeInferenceModel("model1"). @@ -185,32 +185,22 @@ func TestModel1(t *testing.T) { name: "Getting by model name, chat -> model2", existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, op: func(ds Datastore) bool { - gotChat, exists := ds.ModelGetByModelName(chatModel) + gotChat, exists := ds.ModelGet(chatModel) return exists && cmp.Diff(model2chat, gotChat) == "" }, wantOpResult: true, wantModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, }, { - name: "Getting by obj name, model1 -> tweet-summary", + name: "Delete the model", existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, op: func(ds Datastore) bool { - got, exists := ds.ModelGetByObjName(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) - return exists && cmp.Diff(model1ts, got) == "" - }, - wantOpResult: true, - wantModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, - }, - { - name: "Getting by model name, chat -> model2", - existingModels: []*v1alpha2.InferenceModel{model2chat, model1ts}, - op: func(ds Datastore) bool { - ds.ModelDelete(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) - _, exists := ds.ModelGetByModelName(tsModel) - return exists + _, existed := ds.ModelDelete(types.NamespacedName{Name: model1ts.Name, Namespace: model1ts.Namespace}) + _, exists := ds.ModelGet(tsModel) + return existed && !exists }, - wantOpResult: false, + wantOpResult: true, wantModels: []*v1alpha2.InferenceModel{model2chat}, }, } @@ -221,7 +211,8 @@ func TestModel1(t *testing.T) { if gotOpResult != test.wantOpResult { t.Errorf("Unexpected operation result, want: %v, got: %v", test.wantOpResult, gotOpResult) } - if diff := diffModelMaps(ds.(*datastore), test.wantModels); diff != "" { + + if diff := testutil.DiffModelLists(test.wantModels, ds.ModelGetAll()); diff != "" { t.Errorf("Unexpected models diff: %s", diff) } @@ -229,18 +220,6 @@ func TestModel1(t *testing.T) { } } -func diffModelMaps(ds *datastore, want []*v1alpha2.InferenceModel) string { - byObjName := ds.ModelGetAll() - byModelName := []*v1alpha2.InferenceModel{} - for _, v := range ds.modelsByModelName { - byModelName = append(byModelName, v) - } - if diff := testutil.DiffModelLists(byObjName, byModelName); diff != "" { - return "Inconsistent maps diff: " + diff - } - return testutil.DiffModelLists(want, byObjName) -} - func TestRandomWeightedDraw(t *testing.T) { logger := logutil.NewTestLogger() tests := []struct { diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 32062da33..c6cfdda29 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -64,7 +64,7 @@ func (s *Server) HandleRequestBody( // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. // This might be a security risk in the future where adapters not registered in the InferenceModel // are able to be requested by using their distinct name. - modelObj, exist := s.datastore.ModelGetByModelName(model) + modelObj, exist := s.datastore.ModelGet(model) if !exist { return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)} } diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index dc64758ea..2ea66dba4 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -476,7 +476,7 @@ func BeforeSuit(t *testing.T) func() { } assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, modelExist := serverRunner.Datastore.ModelGetByModelName("my-model") + _, modelExist := serverRunner.Datastore.ModelGet("my-model") synced := serverRunner.Datastore.PoolHasSynced() && modelExist assert.True(t, synced, "Timeout waiting for the pool and models to sync") }, 10*time.Second, 10*time.Millisecond) From 02aa4c9ab3e6b099eee7288ae3ccecd4ec8c9492 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Wed, 26 Feb 2025 14:41:59 +0000 Subject: [PATCH 4/5] Added ModelResync to handle a race condition --- .../controller/inferencemodel_reconciler.go | 38 +++----------- .../inferencemodel_reconciler_test.go | 2 +- pkg/epp/datastore/datastore.go | 51 +++++++++++++++---- pkg/epp/datastore/datastore_test.go | 9 ---- 4 files changed, 49 insertions(+), 51 deletions(-) diff --git a/pkg/epp/controller/inferencemodel_reconciler.go b/pkg/epp/controller/inferencemodel_reconciler.go index 72a11a4af..4fbf24842 100644 --- a/pkg/epp/controller/inferencemodel_reconciler.go +++ b/pkg/epp/controller/inferencemodel_reconciler.go @@ -34,10 +34,6 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -const ( - modelNameKey = "spec.modelName" -) - type InferenceModelReconciler struct { client.Client Scheme *runtime.Scheme @@ -96,35 +92,13 @@ func (c *InferenceModelReconciler) handleModelDeleted(ctx context.Context, req t } logger.Info("InferenceModel removed from datastore", "poolRef", existing.Spec.PoolRef, "modelName", existing.Spec.ModelName) - // List all InferenceModels with a matching ModelName. - var models v1alpha2.InferenceModelList - if err := c.List(ctx, &models, client.MatchingFields{modelNameKey: existing.Spec.ModelName}, client.InNamespace(c.PoolNamespacedName.Namespace)); err != nil { - return fmt.Errorf("listing models that match the modelName %s: %w", existing.Spec.ModelName, err) + updated, err := c.Datastore.ModelResync(ctx, c.Client, existing.Spec.ModelName) + if err != nil { + return err } - if len(models.Items) == 0 { - // No other instances of InferenceModels with this ModelName exists. - return nil + if updated { + logger.Info("Model replaced.", "modelName", existing.Spec.ModelName) } - - var oldest *v1alpha2.InferenceModel - for i := range models.Items { - m := &models.Items[i] - if m.Spec.ModelName != existing.Spec.ModelName || // The index should filter those out, but just in case! - m.Spec.PoolRef.Name != c.PoolNamespacedName.Name || // We don't care about other pools, we could setup an index on this too! - m.Name == existing.Name { // We don't care about the same object, it could be in the list if it was only marked for deletion, but not yet deleted. - continue - } - if oldest == nil || m.ObjectMeta.CreationTimestamp.Before(&oldest.ObjectMeta.CreationTimestamp) { - oldest = m - } - } - if oldest != nil && c.Datastore.ModelSetIfOlder(oldest) { - logger.Info("InferenceModel replaced.", - "poolRef", oldest.Spec.PoolRef, - "modelName", oldest.Spec.ModelName, - "newInferenceModel", types.NamespacedName{Name: oldest.Name, Namespace: oldest.Namespace}) - } - return nil } @@ -139,7 +113,7 @@ func indexInferenceModelsByModelName(obj client.Object) []string { func (c *InferenceModelReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error { // Create an index on ModelName for InferenceModel objects. indexer := mgr.GetFieldIndexer() - if err := indexer.IndexField(ctx, &v1alpha2.InferenceModel{}, modelNameKey, indexInferenceModelsByModelName); err != nil { + if err := indexer.IndexField(ctx, &v1alpha2.InferenceModel{}, datastore.ModelNameIndexKey, indexInferenceModelsByModelName); err != nil { return fmt.Errorf("setting index on ModelName for InferenceModel: %w", err) } return ctrl.NewControllerManagedBy(mgr). diff --git a/pkg/epp/controller/inferencemodel_reconciler_test.go b/pkg/epp/controller/inferencemodel_reconciler_test.go index b9b664f24..87323e807 100644 --- a/pkg/epp/controller/inferencemodel_reconciler_test.go +++ b/pkg/epp/controller/inferencemodel_reconciler_test.go @@ -198,7 +198,7 @@ func TestInferenceModelReconciler(t *testing.T) { fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(initObjs...). - WithIndex(&v1alpha2.InferenceModel{}, modelNameKey, indexInferenceModelsByModelName). + WithIndex(&v1alpha2.InferenceModel{}, datastore.ModelNameIndexKey, indexInferenceModelsByModelName). Build() datastore := datastore.NewFakeDatastore(nil, test.modelsInStore, pool) diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 3c7001d90..cd5d290f2 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -19,6 +19,7 @@ package datastore import ( "context" "errors" + "fmt" "math/rand" "sync" @@ -32,6 +33,10 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +const ( + ModelNameIndexKey = "spec.modelName" +) + var ( errPoolNotSynced = errors.New("InferencePool is not initialized in data store") ) @@ -48,6 +53,7 @@ type Datastore interface { ModelSetIfOlder(infModel *v1alpha2.InferenceModel) bool ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) ModelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) + ModelResync(ctx context.Context, ctrlClient client.Client, modelName string) (bool, error) ModelGetAll() []*v1alpha2.InferenceModel // PodMetrics operations @@ -156,12 +162,43 @@ func (ds *datastore) ModelSetIfOlder(infModel *v1alpha2.InferenceModel) bool { return false } } - // Delete the model - ds.modelDelete(types.NamespacedName{Name: infModel.Name, Namespace: infModel.Namespace}) + // Set the model. ds.models[infModel.Spec.ModelName] = infModel return true } +func (ds *datastore) ModelResync(ctx context.Context, c client.Client, modelName string) (bool, error) { + ds.poolAndModelsMu.Lock() + defer ds.poolAndModelsMu.Unlock() + + var models v1alpha2.InferenceModelList + if err := c.List(ctx, &models, client.MatchingFields{ModelNameIndexKey: modelName}, client.InNamespace(ds.pool.Namespace)); err != nil { + return false, fmt.Errorf("listing models that match the modelName %s: %w", modelName, err) + } + if len(models.Items) == 0 { + // No other instances of InferenceModels with this ModelName exists. + return false, nil + } + + var oldest *v1alpha2.InferenceModel + for i := range models.Items { + m := &models.Items[i] + if m.Spec.ModelName != modelName || // The index should filter those out, but just in case! + m.Spec.PoolRef.Name != ds.pool.Name || // We don't care about other pools, we could setup an index on this too! + !m.DeletionTimestamp.IsZero() { // ignore objects marked for deletion + continue + } + if oldest == nil || m.ObjectMeta.CreationTimestamp.Before(&oldest.ObjectMeta.CreationTimestamp) { + oldest = m + } + } + if oldest == nil { + return false, nil + } + ds.models[modelName] = oldest + return true, nil +} + func (ds *datastore) ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) { ds.poolAndModelsMu.RLock() defer ds.poolAndModelsMu.RUnlock() @@ -169,7 +206,9 @@ func (ds *datastore) ModelGet(modelName string) (*v1alpha2.InferenceModel, bool) return m, exists } -func (ds *datastore) modelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { +func (ds *datastore) ModelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { + ds.poolAndModelsMu.Lock() + defer ds.poolAndModelsMu.Unlock() for _, m := range ds.models { if m.Name == namespacedName.Name && m.Namespace == namespacedName.Namespace { delete(ds.models, m.Spec.ModelName) @@ -179,12 +218,6 @@ func (ds *datastore) modelDelete(namespacedName types.NamespacedName) (*v1alpha2 return nil, false } -func (ds *datastore) ModelDelete(namespacedName types.NamespacedName) (*v1alpha2.InferenceModel, bool) { - ds.poolAndModelsMu.Lock() - defer ds.poolAndModelsMu.Unlock() - return ds.modelDelete(namespacedName) -} - func (ds *datastore) ModelGetAll() []*v1alpha2.InferenceModel { ds.poolAndModelsMu.RLock() defer ds.poolAndModelsMu.RUnlock() diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 972025458..edc96626e 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -154,15 +154,6 @@ func TestModel(t *testing.T) { wantOpResult: true, wantModels: []*v1alpha2.InferenceModel{model2ts}, }, - { - name: "Set model2 updated with a new modelName, should update modelName", - existingModels: []*v1alpha2.InferenceModel{model2ts}, - op: func(ds Datastore) bool { - return ds.ModelSetIfOlder(model2chat) - }, - wantOpResult: true, - wantModels: []*v1alpha2.InferenceModel{model2chat}, - }, { name: "Set model1 with the tweet-summary modelName, both models should exist", existingModels: []*v1alpha2.InferenceModel{model2chat}, From 4d90fbbc0e023d8f4e03bd6e3309b62694d3ce6b Mon Sep 17 00:00:00 2001 From: Abdullah Gharaibeh <40361897+ahg-g@users.noreply.github.com> Date: Wed, 26 Feb 2025 10:25:15 -0800 Subject: [PATCH 5/5] Update pkg/epp/controller/inferencemodel_reconciler.go --- pkg/epp/controller/inferencemodel_reconciler.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/epp/controller/inferencemodel_reconciler.go b/pkg/epp/controller/inferencemodel_reconciler.go index 4fbf24842..7cf188087 100644 --- a/pkg/epp/controller/inferencemodel_reconciler.go +++ b/pkg/epp/controller/inferencemodel_reconciler.go @@ -92,6 +92,7 @@ func (c *InferenceModelReconciler) handleModelDeleted(ctx context.Context, req t } logger.Info("InferenceModel removed from datastore", "poolRef", existing.Spec.PoolRef, "modelName", existing.Spec.ModelName) + // TODO(#409): replace this backfill logic with one that is based on InferenceModel Ready conditions once those are set by an external controller. updated, err := c.Datastore.ModelResync(ctx, c.Client, existing.Spec.ModelName) if err != nil { return err