Skip to content

Commit 60ffd6f

Browse files
committed
- added more useful fields to types.LLMRequest:
1. cleaner API declaration 2. data fields are preserved, after-read transformations are done in plugins 3. prefix-cache scorer does not need naive templating - minor bugfixes and improvements Signed-off-by: Maroon Ayoub <[email protected]>
1 parent 40fdedb commit 60ffd6f

File tree

6 files changed

+700
-185
lines changed

6 files changed

+700
-185
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
100100
}
101101
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
102102

103-
prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
103+
requestData, err := requtil.ExtractRequestData(reqCtx.Request.Body)
104104
if err != nil {
105-
return reqCtx, err
105+
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
106106
}
107+
107108
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
108109
if infObjective == nil {
109110
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
@@ -121,7 +122,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
121122
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
122123
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
123124
TargetModel: reqCtx.TargetModelName,
124-
Prompt: prompt,
125+
Data: requestData,
125126
Headers: reqCtx.Request.Headers,
126127
}
127128

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -250,33 +250,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250250
// For block i, hash(i) = hash(block i content, hash(i-1)).
251251
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
252252
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
253-
prompt := []byte(request.Prompt)
254-
if len(prompt) < cacheBlockSize {
255-
loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
253+
if request == nil || request.Data == nil {
254+
loggerDebug.Info("Request or request data is nil, skipping hashing")
256255
return nil
257256
}
258-
if len(prompt) > cacheBlockSize*maxPrefixBlocks {
259-
loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
260-
prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
257+
258+
userInput, err := getUserInputBytes(request)
259+
if err != nil {
260+
loggerDebug.Error(err, "Failed to get user input bytes")
261+
return nil
262+
}
263+
264+
if len(userInput) < cacheBlockSize {
265+
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
266+
return nil
267+
}
268+
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
269+
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
270+
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
261271
}
262272
// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263273
// If the last block is smaller than cacheBlockSize, it will be ignored.
264-
res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
274+
res := make([]BlockHash, 0, 1+len(userInput)/cacheBlockSize)
265275
// Add the model to the first block hash so that different models have different hashes even with the same body.
266-
267-
firstBlockSize := cacheBlockSize
268-
if len(prompt) < cacheBlockSize {
269-
firstBlockSize = len(prompt)
270-
}
271-
firstBlock := prompt[0:firstBlockSize]
272-
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
273-
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
274-
275-
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
276-
block := prompt[i : i+cacheBlockSize]
277-
prevBlockHash := res[len(res)-1]
278-
block = append(block, toBytes(prevBlockHash)...)
279-
res = append(res, BlockHash(xxhash.Sum64(block)))
276+
h := xxhash.New()
277+
_, _ = h.Write([]byte(request.TargetModel))
278+
prevBlockHash := BlockHash(h.Sum64())
279+
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
280+
h.Reset()
281+
_, _ = h.Write(userInput[i : i+cacheBlockSize])
282+
_, _ = h.Write(toBytes(prevBlockHash))
283+
res = append(res, BlockHash(h.Sum64()))
284+
285+
prevBlockHash = res[len(res)-1]
280286
}
281287
return res
282288
}
@@ -286,3 +292,12 @@ func toBytes(i BlockHash) []byte {
286292
binary.LittleEndian.PutUint64(bytes, uint64(i))
287293
return bytes
288294
}
295+
296+
func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
297+
if request.Data.Completions != nil { // assumed to be valid if not nil
298+
return []byte(request.Data.Completions.Prompt), nil
299+
}
300+
301+
// must be chat-completions request at this point, return bytes of entire messages
302+
return json.Marshal(request.Data.ChatCompletions.Messages)
303+
}

0 commit comments

Comments
 (0)