@@ -17,11 +17,13 @@ limitations under the License.
1717package prefix
1818
1919import (
20+ "context"
2021 "encoding/binary"
2122 "fmt"
2223
2324 "github.com/cespare/xxhash/v2"
2425 k8stypes "k8s.io/apimachinery/pkg/types"
26+ "sigs.k8s.io/controller-runtime/pkg/log"
2527 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
2628 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
2729 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -131,24 +133,11 @@ func (m *Plugin) Name() string {
131133 return "prefix-cache"
132134}
133135
134- // PostCycle records in the plugin cache the result of the scheduling selection.
135- func (m * Plugin ) PostCycle (ctx * types.SchedulingContext , res * types.Result ) {
136- targetPod := res .TargetPod .GetPod ()
137- state , err := m .getPrefixState (ctx .CycleState )
138- if err != nil {
139- ctx .Logger .Error (err , "failed to read prefix plugin cycle state" )
140- return
141- }
142- m .indexer .Add (state .PrefixHashes , ServerID (targetPod .NamespacedName ))
143- total := len (state .PrefixHashes )
144- matchLen := state .PrefixCacheServers [ServerID (targetPod .NamespacedName )]
145- metrics .RecordPrefixCacheMatch (matchLen * m .HashBlockSize , total * m .HashBlockSize )
146- }
147-
148136// Score returns the scoring result for the given list of pods based on context.
149- func (m * Plugin ) Score (ctx * types.SchedulingContext , pods []types.Pod ) map [types.Pod ]float64 {
137+ func (m * Plugin ) Score (ctx context.Context , request * types.LLMRequest , cycleState * types.CycleState , pods []types.Pod ) map [types.Pod ]float64 {
138+ loggerTrace := log .FromContext (ctx ).V (logutil .TRACE )
150139 // pre score step, hashing prompt and find longest prefix match.
151- hashes := hashPrompt (ctx , m .HashBlockSize , m .MaxPrefixBlocksToMatch )
140+ hashes := hashPrompt (ctx , request , m .HashBlockSize , m .MaxPrefixBlocksToMatch )
152141 numServers := DefaultNumServersToMatch
153142 if numServers > len (pods ) {
154143 numServers = len (pods )
@@ -157,8 +146,8 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
157146 PrefixHashes : hashes ,
158147 PrefixCacheServers : m .matchLongestPrefix (ctx , hashes , numServers ),
159148 }
160- ctx . CycleState .Write (types .StateKey (m .Name ()), state )
161- ctx . Logger . V ( logutil . TRACE ) .Info (fmt .Sprintf ("cached servers: %+v" , state .PrefixCacheServers ), "hashes" , state .PrefixHashes )
149+ cycleState .Write (types .StateKey (m .Name ()), state )
150+ loggerTrace .Info (fmt .Sprintf ("cached servers: %+v" , state .PrefixCacheServers ), "hashes" , state .PrefixHashes )
162151 // calculate the scores of pods
163152 scores := make (map [types.Pod ]float64 , len (pods ))
164153
@@ -177,16 +166,31 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
177166 return scores
178167}
179168
169+ // PostCycle records in the plugin cache the result of the scheduling selection.
170+ func (m * Plugin ) PostCycle (ctx context.Context , cycleState * types.CycleState , res * types.Result ) {
171+ targetPod := res .TargetPod .GetPod ()
172+ state , err := m .getPrefixState (cycleState )
173+ if err != nil {
174+ log .FromContext (ctx ).Error (err , "failed to read prefix plugin cycle state" )
175+ return
176+ }
177+ m .indexer .Add (state .PrefixHashes , ServerID (targetPod .NamespacedName ))
178+ total := len (state .PrefixHashes )
179+ matchLen := state .PrefixCacheServers [ServerID (targetPod .NamespacedName )]
180+ metrics .RecordPrefixCacheMatch (matchLen * m .HashBlockSize , total * m .HashBlockSize )
181+ }
182+
180183// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
181- func (m * Plugin ) matchLongestPrefix (ctx * types.SchedulingContext , hashes []BlockHash , numServers int ) map [ServerID ]int {
184+ func (m * Plugin ) matchLongestPrefix (ctx context.Context , hashes []BlockHash , numServers int ) map [ServerID ]int {
185+ loggerTrace := log .FromContext (ctx ).V (logutil .TRACE )
182186 res := make (map [ServerID ]int )
183187 // Use a greedy strategy to search from the longest prefix.
184188 // NOTE: It's possible to further optimize this with a binary search.
185189 for i := len (hashes ) - 1 ; i >= 0 && len (res ) < numServers ; i -- {
186190 hash := hashes [i ]
187191 cachedServers := m .indexer .Get (hash )
188192 if len (cachedServers ) > 0 {
189- ctx . Logger . V ( logutil . TRACE ) .Info ("Found cached servers" , "cachedServers" , cachedServers , "total # blocks" , len (hashes ), "longest prefix" , i )
193+ loggerTrace .Info ("Found cached servers" , "cachedServers" , cachedServers , "total # blocks" , len (hashes ), "longest prefix" , i )
190194 for server := range cachedServers {
191195 // Update servers with their longest prefix match.
192196 // If we already found this server with longer prefix match, don't update it.
@@ -218,21 +222,22 @@ func (m *Plugin) getPrefixState(cycleState *types.CycleState) (*schedulingContex
218222// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
219223// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
220224// For block i, hash(i) = hash(block i content, hash(i-1)).
221- func hashPrompt (ctx * types.SchedulingContext , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
222- prompt := []byte (ctx .Req .Prompt )
225+ func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
226+ loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
227+ prompt := []byte (request .Prompt )
223228 if len (prompt ) < cacheBlockSize {
224- ctx . Logger . V ( logutil . DEBUG ) .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
229+ loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
225230 return nil
226231 }
227232 if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
228- ctx . Logger . V ( logutil . DEBUG ) .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
233+ loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
229234 prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
230235 }
231236 // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
232237 // If the last block is smaller than cacheBlockSize, it will be ignored.
233238 res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
234239 // Add the model to the first block hash so that different models have different hashes even with the same body.
235- res = append (res , BlockHash (xxhash .Sum64String (ctx . Req .TargetModel )))
240+ res = append (res , BlockHash (xxhash .Sum64String (request .TargetModel )))
236241 for i := 0 ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
237242 block := prompt [i : i + cacheBlockSize ]
238243 prevBlockHash := res [len (res )- 1 ]
0 commit comments