Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions cmd/epp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/multi/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
profilepicker "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile-picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/scorer"
runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand Down Expand Up @@ -196,20 +198,21 @@ func run() error {
queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog)
kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog)

schedulerConfig := scheduling.NewSchedulerConfig().
schedulerProfile := framework.NewSchedulerProfile().
WithFilters(filter.NewSheddableCapacityFilter()).
WithScorers(scorer.NewWeightedScorer(&scorer.QueueScorer{}, queueScorerWeight),
scorer.NewWeightedScorer(&scorer.KVCacheScorer{}, kvCacheScorerWeight)).
WithScorers(framework.NewWeightedScorer(&scorer.QueueScorer{}, queueScorerWeight),
framework.NewWeightedScorer(&scorer.KVCacheScorer{}, kvCacheScorerWeight)).
WithPicker(picker.NewMaxScorePicker())

if prefixCacheScheduling == "true" {
prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog)
if err := schedulerConfig.AddPlugins(scorer.NewWeightedScorer(prefix.New(loadPrefixCacheConfig()), prefixScorerWeight)); err != nil {
if err := schedulerProfile.AddPlugins(framework.NewWeightedScorer(prefix.New(loadPrefixCacheConfig()), prefixScorerWeight)); err != nil {
setupLog.Error(err, "Failed to register scheduler plugins")
return err
}
}

schedulerConfig := scheduling.NewSchedulerConfig(profilepicker.NewAllProfilesPicker(), map[string]*framework.SchedulerProfile{"schedulerv2": schedulerProfile})
scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig)
}
serverRunner := &runserver.ExtProcServerRunner{
Expand Down
14 changes: 9 additions & 5 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
)

type Scheduler interface {
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result map[string]*schedulingtypes.Result, err error)
OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string)
}

Expand Down Expand Up @@ -108,23 +108,27 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
}

// Dispatch runs one or many scheduling cycles.
func (d *Director) Dispatch(ctx context.Context, llmReq *schedulingtypes.LLMRequest) ([]*schedulingtypes.Result, error) {
func (d *Director) Dispatch(ctx context.Context, llmReq *schedulingtypes.LLMRequest) (map[string]*schedulingtypes.Result, error) {
var err error
res, err := d.scheduler.Schedule(ctx, llmReq)
if err != nil {
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}

return []*schedulingtypes.Result{res}, nil
return res, nil // TODO handle multi cycle result after defining the PostDispatch extension point
}

func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestContext, results []*schedulingtypes.Result) (*handlers.RequestContext, error) {
func (d *Director) PostDispatch(ctx context.Context, reqCtx *handlers.RequestContext, results map[string]*schedulingtypes.Result) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx)
// currently we only get a single result; will refactor to pluggably implement the PostDispatch extension point
if len(results) == 0 {
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
}
targetPod := results[0].TargetPod.GetPod()
var targetPod *backend.Pod
// TODO should handle multi cycle results, this should be pluggable logic
for _, result := range results {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will just set the final targetPod. Can we instead just index on the key & get the result that way? We can leave the todo to indicate this is a transitory state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean doing this?

targetPod = results[key].TargetPod.GetPod()

instead of:

targetPod = result.TargetPod.GetPod()

not sure I got the intention. we have here a map from profile-name -> result.
this is a transitionary stage where only one profile is used, so only one result.
but since it's a map I must use the range loop.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just do targetPod = results[key].TargetPod.GetPod() and drop the for loop.

Since we only end up with one result as is even if there are multiple profiles, it reads more obvious if we just use the key, I think.

Copy link
Contributor Author

@nirrozenbaum nirrozenbaum May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so now the next question, what is the value of key in this line

targetPod = results[key].TargetPod.GetPod()?

which key value to use?
if we used default configuration it's just "default", but if we used schedulerv2, it's a different profile.
since we don't know the key, I'm using the loop...

Copy link
Contributor Author

@nirrozenbaum nirrozenbaum May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect this to be solved as soon as we implement the extension point

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only one we set is schedulerv2 right?
line 215 in the main file of this PR

Copy link
Collaborator

@kfswain kfswain May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this may be a nit, but the for loop would look out of place since the iteration isn't quite being used

Copy link
Contributor Author

@nirrozenbaum nirrozenbaum May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we used the env var to enable "schedulerv2" than the key is "schedulerv2".
if not and we use default configuration, the key is "default" (in NewScheduler func):
https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/pkg/epp/scheduling/scheduler.go#L73

this is why I didn't put here a const.

targetPod = result.TargetPod.GetPod()
}

pool, err := d.datastore.PoolGet()
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package plugins
package framework

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const (
PreSchedulerPluginType = "PreSchedule"
ProfilePickerType = "ProfilePicker"
PreCyclePluginType = "PreCycle"
FilterPluginType = "Filter"
ScorerPluginType = "Scorer"
PostSchedulePluginType = "PostSchedule"
PickerPluginType = "Picker"
PostCyclePluginType = "PostCycle"
PostResponsePluginType = "PostResponse"
)

Expand All @@ -36,11 +37,18 @@ type Plugin interface {
Name() string
}

// PreSchedule is called when the scheduler receives a new request. It can be used for various
// initialization work.
type PreSchedule interface {
// ProfilePicker selects the SchedulerProfiles to run from a list of candidate profiles, while taking into consideration the request properties
// and the previously executed SchedulerProfile cycles along with their results.
type ProfilePicker interface {
Plugin
PreSchedule(ctx *types.SchedulingContext)
Pick(request *types.LLMRequest, profiles map[string]*SchedulerProfile, executionResults map[string]*types.Result) map[string]*SchedulerProfile
}

// PreCycle is called when the scheduler receives a new request and invokes a SchedulerProfile cycle.
// It can be used for various initialization work.
type PreCycle interface {
Plugin
PreCycle(ctx *types.SchedulingContext)
}

// Filter defines the interface for filtering a list of pods based on context.
Expand All @@ -62,10 +70,10 @@ type Picker interface {
Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result
}

// PostSchedule is called by the scheduler after it selects a targetPod for the request.
type PostSchedule interface {
// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
type PostCycle interface {
Plugin
PostSchedule(ctx *types.SchedulingContext, res *types.Result)
PostCycle(ctx *types.SchedulingContext, res *types.Result)
}

// PostResponse is called by the scheduler after a successful response was sent.
Expand Down
15 changes: 15 additions & 0 deletions pkg/epp/scheduling/framework/plugins/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Scheduling Plugins

This package contains the scheduling plugin implementations.

Plugins are organized by the following rule. Follow this rule when adding a new
plugin.

```
plugins/
|__ filter/ (Plugins that implement the Filter interface only.)
|__ scorer/ (Plugins that implement the Scorer interface only.)
|__ picker/ (Plugins that implement the Picker interface only.)
|__ multi/  (Plugins that implement multiple plugin interfaces.)
|____ prefix/ (Prefix cache aware scheduling plugin.)
```
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,31 @@ limitations under the License.
package filter

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// compile-time type validation
var _ plugins.Filter = &DecisionTreeFilter{}
// compile-time type assertion
var _ framework.Filter = &DecisionTreeFilter{}

// DecisionTreeFilter applies the current filter, and then recursively applies next filters
// depending on the success or failure of the current filter.
// It can be used to construct a flow chart algorithm.
type DecisionTreeFilter struct {
Current plugins.Filter
Current framework.Filter
// NextOnSuccess filter will be applied after successfully applying the current filter.
// The filtered results will be passed to the next filter.
NextOnSuccess plugins.Filter
NextOnSuccess framework.Filter
// NextOnFailure filter will be applied if current filter results in no pods.
// The original input will be passed to the next filter.
NextOnFailure plugins.Filter
NextOnFailure framework.Filter
// NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
// success or failure of the current filter.
// NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
// NextOnSuccessOrFailure, in the success and failure scenarios, respectively.
NextOnSuccessOrFailure plugins.Filter
NextOnSuccessOrFailure framework.Filter
}

// Name returns the name of the filter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// compile-time type assertion
var _ framework.Filter = &filterAll{}

type filterAll struct{}

func (f *filterAll) Name() string {
Expand All @@ -44,7 +47,7 @@ func TestFilter(t *testing.T) {
tests := []struct {
name string
req *types.LLMRequest
filter plugins.Filter
filter framework.Filter
input []types.Pod
output []types.Pod
}{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ package filter
import (
"math"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// compile-time type validation
var _ plugins.Filter = &LeastKVCacheFilter{}
// compile-time type assertion
var _ framework.Filter = &LeastKVCacheFilter{}

// NewLeastKVCacheFilter initializes a new LeastKVCacheFilter and returns its pointer.
func NewLeastKVCacheFilter() *LeastKVCacheFilter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ package filter
import (
"math"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// compile-time type validation
var _ plugins.Filter = &LeastQueueFilter{}
// compile-time type assertion
var _ framework.Filter = &LeastQueueFilter{}

// NewLeastQueueFilter initializes a new LeastQueueFilter and returns its pointer.
func NewLeastQueueFilter() *LeastQueueFilter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import (
"time"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// compile-time type validation
var _ plugins.Filter = &LoraAffinityFilter{}
// compile-time type assertion
var _ framework.Filter = &LoraAffinityFilter{}

// NewLoraAffinityFilter initializes a new LoraAffinityFilter and returns its pointer.
func NewLoraAffinityFilter() *LoraAffinityFilter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package filter

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// compile-time type validation
var _ plugins.Filter = &LowQueueFilter{}
// compile-time type assertion
var _ framework.Filter = &LowQueueFilter{}

// NewLowQueueFilter initializes a new LowQueueFilter and returns its pointer.
func NewLowQueueFilter() *LowQueueFilter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package filter

import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// compile-time type validation
var _ plugins.Filter = &SheddableCapacityFilter{}
// compile-time type assertion
var _ framework.Filter = &SheddableCapacityFilter{}

// NewSheddableCapacityFilter initializes a new SheddableCapacityFilter and returns its pointer.
func NewSheddableCapacityFilter() *SheddableCapacityFilter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"github.com/cespare/xxhash/v2"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
Expand Down Expand Up @@ -88,7 +88,7 @@ func (s ServerID) String() string {
return k8stypes.NamespacedName(s).String()
}

// compile-time type validation
// compile-time type assertion
var _ types.StateData = &schedulingContextState{}

// This is the state of this plugin to be used during a scheduling cycle.
Expand All @@ -113,10 +113,10 @@ func (s *schedulingContextState) Clone() types.StateData {
}
}

// compile-time type validation
var _ plugins.PreSchedule = &Plugin{}
var _ plugins.Scorer = &Plugin{}
var _ plugins.PostSchedule = &Plugin{}
// compile-time type assertion
var _ framework.PreCycle = &Plugin{}
var _ framework.Scorer = &Plugin{}
var _ framework.PostCycle = &Plugin{}

// New initializes a new prefix Plugin and returns its pointer.
func New(config Config) *Plugin {
Expand All @@ -132,20 +132,20 @@ func (m *Plugin) Name() string {
return "prefix-cache"
}

// PreSchedule initializes the prefix plugin state for the current scheduling cycle.
func (m *Plugin) PreSchedule(ctx *types.SchedulingContext) {
// PreCycle initializes the prefix plugin state for the current scheduling cycle.
func (m *Plugin) PreCycle(ctx *types.SchedulingContext) {
hashes := hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
state := &schedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, DefaultNumServersToMatch),
}

ctx.CycleState.Write(types.StateKey(m.Name()), state)
ctx.Logger.V(logutil.TRACE).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
ctx.Logger.V(logutil.TRACE).Info(fmt.Sprintf("PreCycle, cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
}

// PostSchedule records in the plugin cache the result of the scheduling selection.
func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {
// PostCycle records in the plugin cache the result of the scheduling selection.
func (m *Plugin) PostCycle(ctx *types.SchedulingContext, res *types.Result) {
targetPod := res.TargetPod.GetPod()
state, err := m.getPrefixState(ctx.CycleState)
if err != nil {
Expand Down
Loading