@@ -123,8 +123,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
123
123
}
124
124
125
125
// compile-time type assertion
126
- var _ framework.Scorer = & Plugin {}
127
- var _ requestcontrol.PreRequest = & Plugin {}
126
+ var (
127
+ _ framework.Scorer = & Plugin {}
128
+ _ requestcontrol.PreRequest = & Plugin {}
129
+ )
128
130
129
131
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
130
132
func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
@@ -238,7 +240,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
238
240
for server := range cachedServers {
239
241
// Update servers with their longest prefix match.
240
242
res [server ]++
241
-
242
243
}
243
244
}
244
245
}
@@ -250,33 +251,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250
251
// For block i, hash(i) = hash(block i content, hash(i-1)).
251
252
func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
252
253
loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
253
- prompt := []byte (request .Prompt )
254
- if len (prompt ) < cacheBlockSize {
255
- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
254
+ if request == nil || request .Data == nil {
255
+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
256
256
return nil
257
257
}
258
- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
259
- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
260
- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
258
+
259
+ userInput , err := getUserInputBytes (request )
260
+ if err != nil {
261
+ loggerDebug .Error (err , "Failed to get user input bytes" )
262
+ return nil
261
263
}
262
- // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263
- // If the last block is smaller than cacheBlockSize, it will be ignored.
264
- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
265
- // Add the model to the first block hash so that different models have different hashes even with the same body.
266
264
267
- firstBlockSize := cacheBlockSize
268
- if len (prompt ) < cacheBlockSize {
269
- firstBlockSize = len ( prompt )
265
+ if len ( userInput ) < cacheBlockSize {
266
+ loggerDebug . Info ( "Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
267
+ return nil
270
268
}
271
- firstBlock := prompt [0 :firstBlockSize ]
272
- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
273
- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
274
-
275
- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
276
- block := prompt [i : i + cacheBlockSize ]
277
- prevBlockHash := res [len (res )- 1 ]
278
- block = append (block , toBytes (prevBlockHash )... )
279
- res = append (res , BlockHash (xxhash .Sum64 (block )))
269
+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
270
+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
271
+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
272
+ }
273
+ // Split the body into blocks of size cacheBlockSize.
274
+ // If the last block is smaller than cacheBlockSize, it will be ignored.
275
+ res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
276
+ // Add the model to the first block hash so that different models have different hashes even with the same body.
277
+ h := xxhash .New ()
278
+ _ , _ = h .Write ([]byte (request .TargetModel ))
279
+ prevBlockHash := BlockHash (h .Sum64 ())
280
+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
281
+ h .Reset ()
282
+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
283
+ _ , _ = h .Write (toBytes (prevBlockHash ))
284
+ res = append (res , BlockHash (h .Sum64 ()))
285
+
286
+ prevBlockHash = res [len (res )- 1 ]
280
287
}
281
288
return res
282
289
}
@@ -286,3 +293,12 @@ func toBytes(i BlockHash) []byte {
286
293
binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
287
294
return bytes
288
295
}
296
+
297
+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
298
+ if request .Data .Completions != nil { // assumed to be valid if not nil
299
+ return []byte (request .Data .Completions .Prompt ), nil
300
+ }
301
+
302
+ // must be chat-completions request at this point, return bytes of entire messages
303
+ return json .Marshal (request .Data .ChatCompletions .Messages )
304
+ }
0 commit comments