@@ -125,8 +125,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
125
125
}
126
126
127
127
// compile-time type assertion
128
- var _ framework.Scorer = & Plugin {}
129
- var _ requestcontrol.PreRequest = & Plugin {}
128
+ var (
129
+ _ framework.Scorer = & Plugin {}
130
+ _ requestcontrol.PreRequest = & Plugin {}
131
+ )
130
132
131
133
// PrefixCachePluginFactory defines the factory function for the Prefix plugin.
132
134
func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
@@ -248,7 +250,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
248
250
for server := range cachedServers {
249
251
// Update servers with their longest prefix match.
250
252
res [server ]++
251
-
252
253
}
253
254
}
254
255
}
@@ -260,33 +261,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
260
261
// For block i, hash(i) = hash(block i content, hash(i-1)).
261
262
func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
262
263
loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
263
- prompt := []byte (request .Prompt )
264
- if len (prompt ) < cacheBlockSize {
265
- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
264
+ if request == nil || request .Body == nil {
265
+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
266
266
return nil
267
267
}
268
- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
269
- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
270
- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
268
+
269
+ userInput , err := getUserInputBytes (request )
270
+ if err != nil {
271
+ loggerDebug .Error (err , "Failed to get user input bytes" )
272
+ return nil
271
273
}
272
- // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
273
- // If the last block is smaller than cacheBlockSize, it will be ignored.
274
- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
275
- // Add the model to the first block hash so that different models have different hashes even with the same body.
276
274
277
- firstBlockSize := cacheBlockSize
278
- if len (prompt ) < cacheBlockSize {
279
- firstBlockSize = len ( prompt )
275
+ if len ( userInput ) < cacheBlockSize {
276
+ loggerDebug . Info ( "Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
277
+ return nil
280
278
}
281
- firstBlock := prompt [0 :firstBlockSize ]
282
- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
283
- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
284
-
285
- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
286
- block := prompt [i : i + cacheBlockSize ]
287
- prevBlockHash := res [len (res )- 1 ]
288
- block = append (block , toBytes (prevBlockHash )... )
289
- res = append (res , BlockHash (xxhash .Sum64 (block )))
279
+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
280
+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
281
+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
282
+ }
283
+ // Split the body into blocks of size cacheBlockSize.
284
+ // If the last block is smaller than cacheBlockSize, it will be ignored.
285
+ res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
286
+ // Add the model to the first block hash so that different models have different hashes even with the same body.
287
+ h := xxhash .New ()
288
+ _ , _ = h .Write ([]byte (request .TargetModel ))
289
+ prevBlockHash := BlockHash (h .Sum64 ())
290
+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
291
+ h .Reset ()
292
+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
293
+ _ , _ = h .Write (toBytes (prevBlockHash ))
294
+ res = append (res , BlockHash (h .Sum64 ()))
295
+
296
+ prevBlockHash = res [len (res )- 1 ]
290
297
}
291
298
return res
292
299
}
@@ -296,3 +303,12 @@ func toBytes(i BlockHash) []byte {
296
303
binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
297
304
return bytes
298
305
}
306
+
307
+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
308
+ if request .Body .Completions != nil { // assumed to be valid if not nil
309
+ return []byte (request .Body .Completions .Prompt ), nil
310
+ }
311
+
312
+ // must be chat-completions request at this point, return bytes of entire messages
313
+ return json .Marshal (request .Body .ChatCompletions .Messages )
314
+ }
0 commit comments