@@ -250,33 +250,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250
250
// For block i, hash(i) = hash(block i content, hash(i-1)).
251
251
func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
252
252
loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
253
- prompt := []byte (request .Prompt )
254
- if len (prompt ) < cacheBlockSize {
255
- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
253
+ if request == nil || request .Data == nil {
254
+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
256
255
return nil
257
256
}
258
- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
259
- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
260
- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
257
+
258
+ userInput , err := getUserInputBytes (request )
259
+ if err != nil {
260
+ loggerDebug .Error (err , "Failed to get user input bytes" )
261
+ return nil
262
+ }
263
+
264
+ if len (userInput ) < cacheBlockSize {
265
+ loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
266
+ return nil
267
+ }
268
+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
269
+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
270
+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
261
271
}
262
272
// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263
273
// If the last block is smaller than cacheBlockSize, it will be ignored.
264
- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
274
+ res := make ([]BlockHash , 0 , 1 + len (userInput )/ cacheBlockSize )
265
275
// Add the model to the first block hash so that different models have different hashes even with the same body.
266
-
267
- firstBlockSize := cacheBlockSize
268
- if len (prompt ) < cacheBlockSize {
269
- firstBlockSize = len (prompt )
270
- }
271
- firstBlock := prompt [0 :firstBlockSize ]
272
- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
273
- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
274
-
275
- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
276
- block := prompt [i : i + cacheBlockSize ]
277
- prevBlockHash := res [len (res )- 1 ]
278
- block = append (block , toBytes (prevBlockHash )... )
279
- res = append (res , BlockHash (xxhash .Sum64 (block )))
276
+ h := xxhash .New ()
277
+ _ , _ = h .Write ([]byte (request .TargetModel ))
278
+ prevBlockHash := BlockHash (h .Sum64 ())
279
+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
280
+ h .Reset ()
281
+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
282
+ _ , _ = h .Write (toBytes (prevBlockHash ))
283
+ res = append (res , BlockHash (h .Sum64 ()))
284
+
285
+ prevBlockHash = res [len (res )- 1 ]
280
286
}
281
287
return res
282
288
}
@@ -286,3 +292,12 @@ func toBytes(i BlockHash) []byte {
286
292
binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
287
293
return bytes
288
294
}
295
+
296
+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
297
+ if request .Data .Completions != nil { // assumed to be valid if not nil
298
+ return []byte (request .Data .Completions .Prompt ), nil
299
+ }
300
+
301
+ // must be chat-completions request at this point, return bytes of entire messages
302
+ return json .Marshal (request .Data .ChatCompletions .Messages )
303
+ }
0 commit comments