Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/epp/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func (r *Runner) registerInTreePlugins() {
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory)
plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory)
plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory)
plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory)
plugins.Register(scorer.KvCacheUtilizationScorerType, scorer.KvCacheUtilizationScorerFactory)
plugins.Register(scorer.QueueScorerType, scorer.QueueScorerFactory)
Expand Down
1 change: 1 addition & 0 deletions pkg/epp/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ func registerNeededPlgugins() {
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory)
plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory)
plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory)
plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory)
}

Expand Down
129 changes: 129 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/picker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,132 @@ func TestPickMaxScorePicker(t *testing.T) {
})
}
}

func TestPickWeightedRandomPicker(t *testing.T) {
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
pod4 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}}
pod5 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}}

// A-Res algorithm uses U^(1/w) transformation which introduces statistical variance
// beyond simple proportional sampling. Generous tolerance is required to prevent
// flaky tests in CI environments, especially for multi-tier weights.
tests := []struct {
name string
input []*types.ScoredPod
maxPods int // maxNumOfEndpoints for this test
iterations int
expectedProbabilities map[string]float64 // pod name -> expected probability
tolerancePercent float64 // acceptable deviation percentage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the way you implemented a probability based test. I think we should give it a try and hopefully we don't see flaky tests in CI.
I have several minor comments:

  • can we set iterations as a const or a local variable and just use it across all the tests? I didn't understand why different tests have different number of iterations. also I assume 500 iterations should be enough to have representative results.
  • can we move expectedProb to be calculated in the code and not set as input? it looks like magic numbers and this is an easy calculation (summing the scores, dividing score / sum).
  • tolerance can also be a local variable or a const, isn't it? why should the tolerance be different between tests?
  • nit: I think it's easier to understand if we use tolerance value in the range of [0,1] (instead of dividing by 100).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used 1000 iterations because 500 iterations sometimes make tests fail. The others followed your suggestions
Thx!

}{
{
name: "High weight dominance test",
input: []*types.ScoredPod{
{Pod: pod1, Score: 10}, // Lower weight
{Pod: pod2, Score: 90}, // Higher weight (should dominate)
},
maxPods: 1,
iterations: 2000,
expectedProbabilities: map[string]float64{
"pod1": 0.10,
"pod2": 0.90,
},
tolerancePercent: 20.0,
},
{
name: "Equal weights test - A-Res uniform distribution",
input: []*types.ScoredPod{
{Pod: pod1, Score: 100}, // Equal weights (higher values for better numerical precision)
{Pod: pod2, Score: 100}, // Equal weights should yield uniform distribution
{Pod: pod3, Score: 100}, // Equal weights in A-Res
},
maxPods: 1,
iterations: 1500,
expectedProbabilities: map[string]float64{
"pod1": 0.333, // Equal weights should yield uniform distribution
"pod2": 0.333, // A-Res maintains equal probability for equal weights
"pod3": 0.333, // Each pod has theoretically equal chance
},
tolerancePercent: 20.0,
},
{
name: "Zero weight exclusion test - A-Res edge case",
input: []*types.ScoredPod{
{Pod: pod1, Score: 30}, // Normal weight, should be selected
{Pod: pod2, Score: 0}, // Zero weight, never selected in A-Res
},
maxPods: 1,
iterations: 500,
expectedProbabilities: map[string]float64{
"pod1": 1.0, // Only pod with positive weight
"pod2": 0.0, // Zero weight pods are filtered out
},
tolerancePercent: 5.0, // ±5% tolerance (should be exact for zero weights)
},
{
name: "Multi-tier weighted test - A-Res complex distribution",
input: []*types.ScoredPod{
{Pod: pod1, Score: 100}, // Highest weight
{Pod: pod2, Score: 90}, // High weight
{Pod: pod3, Score: 50}, // Medium weight
{Pod: pod4, Score: 30}, // Low weight
{Pod: pod5, Score: 20}, // Lowest weight
},
maxPods: 1,
iterations: 1000,
expectedProbabilities: map[string]float64{
"pod1": 0.345, // Highest weight gets highest probability
"pod2": 0.310, // High weight gets high probability
"pod3": 0.172, // Medium weight gets medium probability
"pod4": 0.103, // Low weight gets low probability
"pod5": 0.069, // Lowest weight gets lowest probability
},
tolerancePercent: 25.0,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
picker := NewWeightedRandomPicker(test.maxPods)
selectionCounts := make(map[string]int)

// Initialize selection counters for each pod
for _, pod := range test.input {
podName := pod.GetPod().NamespacedName.Name
selectionCounts[podName] = 0
}

// Run multiple iterations to gather statistical data
for i := 0; i < test.iterations; i++ {
result := picker.Pick(context.Background(), types.NewCycleState(), test.input)

// Count selections for probability analysis
if len(result.TargetPods) > 0 {
selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name
selectionCounts[selectedPodName]++
}
}

// Verify probability distribution
if test.expectedProbabilities != nil {
for podName, expectedProb := range test.expectedProbabilities {
actualCount := selectionCounts[podName]
actualProb := float64(actualCount) / float64(test.iterations)

tolerance := expectedProb * test.tolerancePercent / 100.0
lowerBound := expectedProb - tolerance
upperBound := expectedProb + tolerance

if actualProb < lowerBound || actualProb > upperBound {
t.Errorf("Pod %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)",
podName, expectedProb, test.tolerancePercent, actualProb, actualCount, test.iterations)
} else {
t.Logf("Pod %s: expected %.3f, got %.3f (count: %d/%d) ✓",
podName, expectedProb, actualProb, actualCount, test.iterations)
}
}
}
})
}
}
169 changes: 169 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package picker

import (
"context"
"encoding/json"
"fmt"
"math"
"math/rand"
"sort"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

const (
WeightedRandomPickerType = "weighted-random-picker"
)

// weightedScoredPod represents a scored pod with its A-Res sampling key
type weightedScoredPod struct {
*types.ScoredPod
key float64
}

var _ framework.Picker = &WeightedRandomPicker{}

func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
parameters := pickerParameters{
MaxNumOfEndpoints: DefaultMaxNumOfEndpoints,
}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err)
}
}

return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil
}

func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
if maxNumOfEndpoints <= 0 {
maxNumOfEndpoints = DefaultMaxNumOfEndpoints
}

return &WeightedRandomPicker{
typedName: plugins.TypedName{Type: WeightedRandomPickerType, Name: WeightedRandomPickerType},
maxNumOfEndpoints: maxNumOfEndpoints,
randomPicker: NewRandomPicker(maxNumOfEndpoints),
}
}

type WeightedRandomPicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int
randomPicker *RandomPicker // fallback for zero weights
}

func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker {
p.typedName.Name = name
return p
}

func (p *WeightedRandomPicker) TypedName() plugins.TypedName {
return p.typedName
}

// WeightedRandomPicker performs weighted random sampling using A-Res algorithm.
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf
// Algorithm:
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
// - Selects k items with largest keys for mathematically correct weighted sampling
// - More efficient than traditional cumulative probability approach
//
// Key characteristics:
// - Mathematically correct weighted random sampling
// - Single pass algorithm with O(n + k log k) complexity
func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v",
p.maxNumOfEndpoints, len(scoredPods), scoredPods))

// Check if all weights are zero or negative
allZeroWeights := true
for _, scoredPod := range scoredPods {
if scoredPod.Score > 0 {
allZeroWeights = false
break
}
}

// Delegate to RandomPicker for uniform selection when all weights are zero
if allZeroWeights {
log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection")
return p.randomPicker.Pick(ctx, cycleState, scoredPods)
}

randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))

// A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ)
weightedPods := make([]weightedScoredPod, 0, len(scoredPods))

for _, scoredPod := range scoredPods {
weight := float64(scoredPod.Score)

// Handle zero or negative weights
if weight <= 0 {
// Assign very small key for zero-weight pods (effectively excludes them)
weightedPods = append(weightedPods, weightedScoredPod{
ScoredPod: scoredPod,
key: 0,
})
continue
}

// Generate random number U in (0,1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be updated to// Generate random number U in [0.0, 1.0) here?

It will be clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No 0 is always excluded so the existing comment is correct.

	u := randomGenerator.Float64()
		if u == 0 {
			u = 1e-10 // Avoid log(0)   #<-- this will be like 0.0000000001
		}

u := randomGenerator.Float64()
if u == 0 {
u = 1e-10 // Avoid log(0)
}

// Calculate key = U^(1/weight)
key := math.Pow(u, 1.0/weight)

weightedPods = append(weightedPods, weightedScoredPod{
ScoredPod: scoredPod,
key: key,
})
}

// Sort by key in descending order (largest keys first)
sort.Slice(weightedPods, func(i, j int) bool {
return weightedPods[i].key > weightedPods[j].key
})

// Select top k pods
selectedCount := min(p.maxNumOfEndpoints, len(weightedPods))

scoredPods = make([]*types.ScoredPod, selectedCount)
for i := range selectedCount {
scoredPods[i] = weightedPods[i].ScoredPod
}

targetPods := make([]types.Pod, len(scoredPods))
for i, scoredPod := range scoredPods {
targetPods[i] = scoredPod
}

return &types.ProfileRunResult{TargetPods: targetPods}
}