@@ -250,33 +250,37 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250
250
// For block i, hash(i) = hash(block i content, hash(i-1)).
251
251
func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
252
252
loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
253
- prompt := []byte (request .Prompt )
254
- if len (prompt ) < cacheBlockSize {
255
- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
253
+ if request == nil || request .Data == nil {
254
+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
256
255
return nil
257
256
}
258
- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
259
- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
260
- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
257
+
258
+ userInput , err := getUserInputBytes (request )
259
+ if err != nil {
260
+ loggerDebug .Error (err , "Failed to get user input bytes" )
261
+ return nil
262
+ }
263
+
264
+ if len (userInput ) < cacheBlockSize {
265
+ loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
266
+ return nil
267
+ }
268
+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
269
+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
270
+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
261
271
}
262
272
// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263
273
// If the last block is smaller than cacheBlockSize, it will be ignored.
264
- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
274
+ res := make ([]BlockHash , 0 , 1 + len (userInput )/ cacheBlockSize )
265
275
// Add the model to the first block hash so that different models have different hashes even with the same body.
266
-
267
- firstBlockSize := cacheBlockSize
268
- if len (prompt ) < cacheBlockSize {
269
- firstBlockSize = len (prompt )
270
- }
271
- firstBlock := prompt [0 :firstBlockSize ]
272
- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
273
- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
274
-
275
- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
276
- block := prompt [i : i + cacheBlockSize ]
276
+ res = append (res , BlockHash (xxhash .Sum64String (request .TargetModel )))
277
+ h := xxhash .New ()
278
+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
279
+ h .Reset ()
280
+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
277
281
prevBlockHash := res [len (res )- 1 ]
278
- block = append ( block , toBytes (prevBlockHash )... )
279
- res = append (res , BlockHash (xxhash .Sum64 (block )))
282
+ _ , _ = h . Write ( toBytes (prevBlockHash ))
283
+ res = append (res , BlockHash (h .Sum64 ()))
280
284
}
281
285
return res
282
286
}
@@ -286,3 +290,12 @@ func toBytes(i BlockHash) []byte {
286
290
binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
287
291
return bytes
288
292
}
293
+
294
+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
295
+ if request .Data .Completions != nil { // assumed to be valid if not nil
296
+ return []byte (request .Data .Completions .Prompt ), nil
297
+ }
298
+
299
+ // must be chat-completions request at this point, return bytes of entire messages
300
+ return json .Marshal (request .Data .ChatCompletions .Messages )
301
+ }
0 commit comments