diff --git a/pkg/epp/scheduling/framework/plugins/picker/common.go b/pkg/epp/scheduling/framework/plugins/picker/common.go index 4bbc300da..c8655840f 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/common.go +++ b/pkg/epp/scheduling/framework/plugins/picker/common.go @@ -16,6 +16,13 @@ limitations under the License. package picker +import ( + "math/rand/v2" + "time" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + const ( DefaultMaxNumOfEndpoints = 1 // common default to all pickers ) @@ -24,3 +31,14 @@ const ( type pickerParameters struct { MaxNumOfEndpoints int `json:"maxNumOfEndpoints"` } + +func shuffleScoredPods(scoredPods []*types.ScoredPod) { + // Rand package is not safe for concurrent use, so we create a new instance. + // Source: https://pkg.go.dev/math/rand/v2#pkg-overview + randomGenerator := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)) + + // Shuffle in-place + randomGenerator.Shuffle(len(scoredPods), func(i, j int) { + scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i] + }) +} diff --git a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go index 325f735fa..47ad1f1ba 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go @@ -20,12 +20,9 @@ import ( "context" "encoding/json" "fmt" - "math/rand" "slices" - "time" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -85,15 +82,8 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates sorted by max score", "max-num-of-endpoints", p.maxNumOfEndpoints, "num-of-candidates", len(scoredPods), "scored-pods", scoredPods) - // TODO: merge this with the logic in RandomPicker - // Rand package is not safe for concurrent use, so we create a new instance. - // Source: https://pkg.go.dev/math/rand#pkg-overview - randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) - // Shuffle in-place - needed for random tie break when scores are equal - randomGenerator.Shuffle(len(scoredPods), func(i, j int) { - scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i] - }) + shuffleScoredPods(scoredPods) slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first if i.Score > j.Score { diff --git a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go index 87a1747fc..4c697d2f6 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go @@ -20,11 +20,8 @@ import ( "context" "encoding/json" "fmt" - "math/rand" - "time" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -84,15 +81,8 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates randomly", "max-num-of-endpoints", p.maxNumOfEndpoints, "num-of-candidates", len(scoredPods), "scored-pods", scoredPods) - // TODO: merge this with the logic in MaxScorePicker - // Rand package is not safe for concurrent use, so we create a new instance. - // Source: https://pkg.go.dev/math/rand#pkg-overview - randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) - // Shuffle in-place - randomGenerator.Shuffle(len(scoredPods), func(i, j int) { - scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i] - }) + shuffleScoredPods(scoredPods) // if we have enough pods to return keep only the relevant subset if p.maxNumOfEndpoints < len(scoredPods) {