Skip to content

Commit fb5ce41

Browse files
committed
- added more useful fields to types.LLMRequest:
  1. cleaner API declaration
  2. data fields are preserved; after-read transformations are done in plugins
  3. prefix-cache scorer does not need naive templating
- minor bugfixes and improvements

Signed-off-by: Maroon Ayoub <[email protected]>
1 parent 40fdedb commit fb5ce41

File tree

6 files changed

+740
-146
lines changed

6 files changed

+740
-146
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
100100
}
101101
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
102102

103-
prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
103+
requestData, err := requtil.ExtractRequestData(reqCtx.Request.Body)
104104
if err != nil {
105-
return reqCtx, err
105+
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
106106
}
107+
107108
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
108109
if infObjective == nil {
109110
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
@@ -121,7 +122,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
121122
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
122123
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
123124
TargetModel: reqCtx.TargetModelName,
124-
Prompt: prompt,
125+
Data: requestData,
125126
Headers: reqCtx.Request.Headers,
126127
}
127128

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -250,33 +250,37 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250250
// For block i, hash(i) = hash(block i content, hash(i-1)).
251251
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
252252
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
253-
prompt := []byte(request.Prompt)
254-
if len(prompt) < cacheBlockSize {
255-
loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
253+
if request == nil || request.Data == nil {
254+
loggerDebug.Info("Request or request data is nil, skipping hashing")
256255
return nil
257256
}
258-
if len(prompt) > cacheBlockSize*maxPrefixBlocks {
259-
loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
260-
prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
257+
258+
userInput, err := getUserInputBytes(request)
259+
if err != nil {
260+
loggerDebug.Error(err, "Failed to get user input bytes")
261+
return nil
262+
}
263+
264+
if len(userInput) < cacheBlockSize {
265+
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
266+
return nil
267+
}
268+
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
269+
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
270+
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
261271
}
262272
// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263273
// If the last block is smaller than cacheBlockSize, it will be ignored.
264-
res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
274+
res := make([]BlockHash, 0, 1+len(userInput)/cacheBlockSize)
265275
// Add the model to the first block hash so that different models have different hashes even with the same body.
266-
267-
firstBlockSize := cacheBlockSize
268-
if len(prompt) < cacheBlockSize {
269-
firstBlockSize = len(prompt)
270-
}
271-
firstBlock := prompt[0:firstBlockSize]
272-
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
273-
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
274-
275-
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
276-
block := prompt[i : i+cacheBlockSize]
276+
res = append(res, BlockHash(xxhash.Sum64String(request.TargetModel)))
277+
h := xxhash.New()
278+
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
279+
h.Reset()
280+
_, _ = h.Write(userInput[i : i+cacheBlockSize])
277281
prevBlockHash := res[len(res)-1]
278-
block = append(block, toBytes(prevBlockHash)...)
279-
res = append(res, BlockHash(xxhash.Sum64(block)))
282+
_, _ = h.Write(toBytes(prevBlockHash))
283+
res = append(res, BlockHash(h.Sum64()))
280284
}
281285
return res
282286
}
@@ -286,3 +290,12 @@ func toBytes(i BlockHash) []byte {
286290
binary.LittleEndian.PutUint64(bytes, uint64(i))
287291
return bytes
288292
}
293+
294+
func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
295+
if request.Data.Completions != nil { // assumed to be valid if not nil
296+
return []byte(request.Data.Completions.Prompt), nil
297+
}
298+
299+
// must be chat-completions request at this point, return bytes of entire messages
300+
return json.Marshal(request.Data.ChatCompletions.Messages)
301+
}

0 commit comments

Comments
 (0)