Skip to content

Commit bc29bd0

Browse files
authored
generalize scheduling cycle state concept (#818)
* generalize scheduling cycle state concept Signed-off-by: Nir Rozenbaum <[email protected]> * typo Signed-off-by: Nir Rozenbaum <[email protected]> * make linter happy Signed-off-by: Nir Rozenbaum <[email protected]> * make prefix state struct internal to package instead of public Signed-off-by: Nir Rozenbaum <[email protected]> --------- Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 62f226c commit bc29bd0

File tree

5 files changed

+203
-60
lines changed

5 files changed

+203
-60
lines changed

pkg/epp/scheduling/plugins/prefix/plugin.go

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,37 @@ type Indexer interface {
7878
Add(hashes []BlockHash, server ServerID)
7979
}
8080

81+
// BlockHash is a hash of the block of request body.
82+
type BlockHash uint64
83+
84+
type ServerID k8stypes.NamespacedName
85+
86+
func (s ServerID) String() string {
87+
return k8stypes.NamespacedName(s).String()
88+
}
89+
90+
var _ types.StateData = &schedulingContextState{}
91+
8192
// This is the state of this plugin to be used during a scheduling cycle.
82-
type SchedulingContextState struct {
93+
type schedulingContextState struct {
8394
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
8495
PrefixHashes []BlockHash
8596
// A map of server to its longest prefix cache match length.
8697
PrefixCacheServers map[ServerID]int
8798
}
8899

89-
// BlockHash is a hash of the block of request body.
90-
type BlockHash uint64
91-
92-
type ServerID k8stypes.NamespacedName
100+
func (s *schedulingContextState) Clone() types.StateData {
101+
prefixHashes := make([]BlockHash, len(s.PrefixHashes))
102+
copy(prefixHashes, s.PrefixHashes)
103+
prefixCacheServers := make(map[ServerID]int, len(s.PrefixCacheServers))
104+
for key, value := range s.PrefixCacheServers {
105+
prefixCacheServers[key] = value
106+
}
93107

94-
func (s ServerID) String() string {
95-
return k8stypes.NamespacedName(s).String()
108+
return &schedulingContextState{
109+
PrefixHashes: prefixHashes,
110+
PrefixCacheServers: prefixCacheServers,
111+
}
96112
}
97113

98114
func New(config Config) *Plugin {
@@ -104,31 +120,43 @@ func New(config Config) *Plugin {
104120
}
105121

106122
func (m *Plugin) Name() string {
107-
return "prefixCache"
123+
return "prefix-cache"
108124
}
109125

110126
func (m *Plugin) PreSchedule(ctx *types.SchedulingContext) {
111127
hashes := hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
112-
state := SchedulingContextState{
128+
state := &schedulingContextState{
113129
PrefixHashes: hashes,
114130
PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, DefaultNumServersToMatch),
115131
}
116-
ctx.SetPluginState(types.PluginName(m.Name()), state)
132+
133+
ctx.CycleState.Write(types.StateKey(m.Name()), state)
117134
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
118135
}
119136

120137
// If a request was routed to a server, record it in the cache:
121138
func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {
122139
targetPod := res.TargetPod.GetPod()
123-
state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState)
140+
state, err := m.getPrefixState(ctx.CycleState)
141+
if err != nil {
142+
ctx.Logger.Error(err, "failed to read prefix plugin cycle state")
143+
return
144+
}
124145
m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
125146
total := len(state.PrefixHashes)
126147
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
127148
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
128149
}
129150

130151
func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
131-
state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState)
152+
scores := make(map[types.Pod]float64, len(pods))
153+
154+
state, err := m.getPrefixState(ctx.CycleState)
155+
if err != nil {
156+
ctx.Logger.Error(err, "failed to read prefix plugin cycle state")
157+
return scores
158+
}
159+
132160
total := len(state.PrefixHashes)
133161
podScoreFunc := func(pod types.Pod) float64 {
134162
if total == 0 {
@@ -138,7 +166,6 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
138166
return float64(matchLen) / float64(total)
139167
}
140168

141-
scores := make(map[types.Pod]float64, len(pods))
142169
for _, pod := range pods {
143170
scores[pod] = podScoreFunc(pod)
144171
}
@@ -170,6 +197,21 @@ func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, hashes []Block
170197
return res
171198
}
172199

200+
func (m *Plugin) getPrefixState(cycleState *types.CycleState) (*schedulingContextState, error) {
201+
prefixStateKey := types.StateKey(m.Name())
202+
state, err := cycleState.Read(prefixStateKey)
203+
if err != nil {
204+
return nil, fmt.Errorf("failed reading %q from CycleState: %w", prefixStateKey, err)
205+
}
206+
207+
prefixSchedulingState, ok := state.(*schedulingContextState)
208+
if !ok {
209+
return nil, fmt.Errorf("invalid Prefix state, got type %T", state)
210+
}
211+
212+
return prefixSchedulingState, nil
213+
}
214+
173215
// hashPrompt divides the prompt into blocks and calculates a prefix hash for each block.
174216
// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
175217
// For block i, hash(i) = hash(block i content, hash(i-1)).

pkg/epp/scheduling/plugins/prefix/plugin_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ func TestPrefixPlugin(t *testing.T) {
3030
}
3131
ctx := types.NewSchedulingContext(context.Background(), req1, pods)
3232
plugin.PreSchedule(ctx)
33-
state := ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState)
33+
state, err := plugin.getPrefixState(ctx.CycleState)
34+
assert.NoError(t, err)
3435
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
3536
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
3637
// Total hashes = 2 (the first one is for the model)
@@ -54,7 +55,8 @@ func TestPrefixPlugin(t *testing.T) {
5455
}
5556
ctx = types.NewSchedulingContext(context.Background(), req2, pods)
5657
plugin.PreSchedule(ctx)
57-
state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState)
58+
state, err = plugin.getPrefixState(ctx.CycleState)
59+
assert.NoError(t, err)
5860
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
5961
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
6062
// Total hashes = 2 (the first one is for the model)
@@ -77,7 +79,8 @@ func TestPrefixPlugin(t *testing.T) {
7779
}
7880
ctx = types.NewSchedulingContext(context.Background(), req3, pods)
7981
plugin.PreSchedule(ctx)
80-
state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState)
82+
state, err = plugin.getPrefixState(ctx.CycleState)
83+
assert.NoError(t, err)
8184
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
8285
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
8386
// Total hashes = 3 (the first one is for the model)
@@ -99,7 +102,8 @@ func TestPrefixPlugin(t *testing.T) {
99102
}
100103
ctx = types.NewSchedulingContext(context.Background(), req4, pods)
101104
plugin.PreSchedule(ctx)
102-
state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState)
105+
state, err = plugin.getPrefixState(ctx.CycleState)
106+
assert.NoError(t, err)
103107
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
104108
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
105109
// Total hashes = 3 (the first one is for the model)
@@ -121,7 +125,8 @@ func TestPrefixPlugin(t *testing.T) {
121125
}
122126
ctx = types.NewSchedulingContext(context.Background(), req5, pods)
123127
plugin.PreSchedule(ctx)
124-
state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState)
128+
state, err = plugin.getPrefixState(ctx.CycleState)
129+
assert.NoError(t, err)
125130
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
126131
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
127132
// Total hashes = 4 (the first one is for the model)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package types
18+
19+
import (
20+
"errors"
21+
"sync"
22+
)
23+
24+
var (
25+
// ErrNotFound is returned when the requested key is not present in CycleState.
26+
ErrNotFound = errors.New("not found")
27+
)
28+
29+
// StateData is a generic type for arbitrary data stored in CycleState.
30+
type StateData interface {
31+
// Clone returns a copy of the StateData.
32+
Clone() StateData
33+
}
34+
35+
// StateKey is the type of keys stored in CycleState.
36+
type StateKey string
37+
38+
// NewCycleState initializes a new CycleState and returns its pointer.
39+
func NewCycleState() *CycleState {
40+
return &CycleState{}
41+
}
42+
43+
// CycleState provides a mechanism for plugins to store and retrieve arbitrary data.
44+
// StateData stored by one plugin can be read, altered, or deleted by another plugin.
45+
// CycleState does not provide any data protection, as all plugins are assumed to be
46+
// trusted.
47+
// Note: CycleState uses a sync.Map to back the storage because it is thread safe. It is optimized for "write once, read many times" scenarios.
48+
type CycleState struct {
49+
// key: StateKey, value: StateData
50+
storage sync.Map
51+
}
52+
53+
// Clone creates a copy of CycleState and returns its pointer. Clone returns
54+
// nil if the context being cloned is nil.
55+
func (c *CycleState) Clone() *CycleState {
56+
if c == nil {
57+
return nil
58+
}
59+
copy := NewCycleState()
60+
// Safe copy storage in case of overwriting.
61+
c.storage.Range(func(k, v interface{}) bool {
62+
copy.storage.Store(k, v.(StateData).Clone())
63+
return true
64+
})
65+
66+
return copy
67+
}
68+
69+
// Read retrieves data with the given "key" from CycleState. If the key is not
70+
// present, ErrNotFound is returned.
71+
//
72+
// See CycleState for notes on concurrency.
73+
func (c *CycleState) Read(key StateKey) (StateData, error) {
74+
if v, ok := c.storage.Load(key); ok {
75+
return v.(StateData), nil
76+
}
77+
return nil, ErrNotFound
78+
}
79+
80+
// Write stores the given "val" in CycleState with the given "key".
81+
//
82+
// See CycleState for notes on concurrency.
83+
func (c *CycleState) Write(key StateKey, val StateData) {
84+
c.storage.Store(key, val)
85+
}
86+
87+
// Delete deletes data with the given key from CycleState.
88+
//
89+
// See CycleState for notes on concurrency.
90+
func (c *CycleState) Delete(key StateKey) {
91+
c.storage.Delete(key)
92+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package types
18+
19+
import (
20+
"context"
21+
22+
"github.com/go-logr/logr"
23+
"sigs.k8s.io/controller-runtime/pkg/log"
24+
)
25+
26+
func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
27+
logger := log.FromContext(ctx).WithValues("request", req)
28+
return &SchedulingContext{
29+
Context: ctx,
30+
Logger: logger,
31+
Req: req,
32+
PodsSnapshot: pods,
33+
CycleState: NewCycleState(),
34+
}
35+
}
36+
37+
// SchedulingContext holds contextual information during a scheduling operation.
38+
type SchedulingContext struct {
39+
context.Context
40+
Logger logr.Logger
41+
Req *LLMRequest
42+
PodsSnapshot []Pod
43+
// CycleState can be used by plugins to store state during a scheduling cycle, to communicate
44+
// between different extension points.
45+
CycleState *CycleState
46+
}

pkg/epp/scheduling/types/types.go

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,8 @@ limitations under the License.
1717
package types
1818

1919
import (
20-
"context"
2120
"fmt"
22-
"sync"
2321

24-
"github.com/go-logr/logr"
25-
"sigs.k8s.io/controller-runtime/pkg/log"
2622
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2723
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2824
)
@@ -57,32 +53,6 @@ type ScoredPod struct {
5753
Score float64
5854
}
5955

60-
// SchedulingContext holds contextual information during a scheduling operation.
61-
type SchedulingContext struct {
62-
context.Context
63-
Logger logr.Logger
64-
Req *LLMRequest
65-
PodsSnapshot []Pod
66-
// PluginState can be used by plugins to store state during a scheduling cycle, to communicate
67-
// between different extension points.
68-
PluginState map[PluginName]any
69-
pluginStateMu *sync.RWMutex
70-
}
71-
72-
func (sc *SchedulingContext) GetPluginState(pluginName PluginName) any {
73-
sc.pluginStateMu.RLock()
74-
defer sc.pluginStateMu.RUnlock()
75-
return sc.PluginState[pluginName]
76-
}
77-
78-
func (sc *SchedulingContext) SetPluginState(pluginName PluginName, state any) {
79-
sc.pluginStateMu.Lock()
80-
defer sc.pluginStateMu.Unlock()
81-
sc.PluginState[pluginName] = state
82-
}
83-
84-
type PluginName string
85-
8656
func (pm *PodMetrics) String() string {
8757
if pm == nil {
8858
return ""
@@ -103,18 +73,6 @@ type PodMetrics struct {
10373
*backendmetrics.Metrics
10474
}
10575

106-
func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
107-
logger := log.FromContext(ctx).WithValues("request", req)
108-
return &SchedulingContext{
109-
Context: ctx,
110-
Logger: logger,
111-
Req: req,
112-
PodsSnapshot: pods,
113-
PluginState: make(map[PluginName]any),
114-
pluginStateMu: &sync.RWMutex{},
115-
}
116-
}
117-
11876
func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod {
11977
pm := make([]Pod, 0, len(pods))
12078
for _, pod := range pods {

0 commit comments

Comments
 (0)