Skip to content

Commit 4361b59

Browse files
authored
Refactor LLMRequest: Structured RequestData for Completions & Chat-Completions (#1446)
* - added more useful fields to types.LLMRequest: 1. cleaner API declaration 2. data fields are preserved, after-read transformations are done in plugins 3. prefix-cache scorer does not need naive templating - minor bugfixes and improvements Signed-off-by: Maroon Ayoub <[email protected]> * removed LLMRequestData::String Signed-off-by: Maroon Ayoub <[email protected]> * - rename LLMRequestData to LLMRequestBody - rename LLMRequest.Data to LLMRequest.Body - test refactoring after rebase Signed-off-by: Maroon Ayoub <[email protected]> --------- Signed-off-by: Maroon Ayoub <[email protected]>
1 parent a7d943e commit 4361b59

File tree

6 files changed

+590
-194
lines changed

6 files changed

+590
-194
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
103103
}
104104
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
105105

106-
prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
106+
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
107107
if err != nil {
108-
return reqCtx, err
108+
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
109109
}
110+
110111
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
111112
if infObjective == nil {
112113
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
@@ -124,7 +125,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
124125
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
125126
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
126127
TargetModel: reqCtx.TargetModelName,
127-
Prompt: prompt,
128+
Body: requestBody,
128129
Headers: reqCtx.Request.Headers,
129130
}
130131

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
125125
}
126126

127127
// compile-time type assertion
128-
var _ framework.Scorer = &Plugin{}
129-
var _ requestcontrol.PreRequest = &Plugin{}
128+
var (
129+
_ framework.Scorer = &Plugin{}
130+
_ requestcontrol.PreRequest = &Plugin{}
131+
)
130132

131133
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
132134
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
@@ -248,7 +250,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
248250
for server := range cachedServers {
249251
// Update servers with their longest prefix match.
250252
res[server]++
251-
252253
}
253254
}
254255
}
@@ -260,33 +261,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
260261
// For block i, hash(i) = hash(block i content, hash(i-1)).
261262
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
262263
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
263-
prompt := []byte(request.Prompt)
264-
if len(prompt) < cacheBlockSize {
265-
loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
264+
if request == nil || request.Body == nil {
265+
loggerDebug.Info("Request or request data is nil, skipping hashing")
266266
return nil
267267
}
268-
if len(prompt) > cacheBlockSize*maxPrefixBlocks {
269-
loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
270-
prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
268+
269+
userInput, err := getUserInputBytes(request)
270+
if err != nil {
271+
loggerDebug.Error(err, "Failed to get user input bytes")
272+
return nil
271273
}
272-
// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
273-
// If the last block is smaller than cacheBlockSize, it will be ignored.
274-
res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
275-
// Add the model to the first block hash so that different models have different hashes even with the same body.
276274

277-
firstBlockSize := cacheBlockSize
278-
if len(prompt) < cacheBlockSize {
279-
firstBlockSize = len(prompt)
275+
if len(userInput) < cacheBlockSize {
276+
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
277+
return nil
280278
}
281-
firstBlock := prompt[0:firstBlockSize]
282-
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
283-
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
284-
285-
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
286-
block := prompt[i : i+cacheBlockSize]
287-
prevBlockHash := res[len(res)-1]
288-
block = append(block, toBytes(prevBlockHash)...)
289-
res = append(res, BlockHash(xxhash.Sum64(block)))
279+
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
280+
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
281+
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
282+
}
283+
// Split the body into blocks of size cacheBlockSize.
284+
// If the last block is smaller than cacheBlockSize, it will be ignored.
285+
res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
286+
// Add the model to the first block hash so that different models have different hashes even with the same body.
287+
h := xxhash.New()
288+
_, _ = h.Write([]byte(request.TargetModel))
289+
prevBlockHash := BlockHash(h.Sum64())
290+
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
291+
h.Reset()
292+
_, _ = h.Write(userInput[i : i+cacheBlockSize])
293+
_, _ = h.Write(toBytes(prevBlockHash))
294+
res = append(res, BlockHash(h.Sum64()))
295+
296+
prevBlockHash = res[len(res)-1]
290297
}
291298
return res
292299
}
@@ -296,3 +303,12 @@ func toBytes(i BlockHash) []byte {
296303
binary.LittleEndian.PutUint64(bytes, uint64(i))
297304
return bytes
298305
}
306+
307+
func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
308+
if request.Body.Completions != nil { // assumed to be valid if not nil
309+
return []byte(request.Body.Completions.Prompt), nil
310+
}
311+
312+
// must be chat-completions request at this point, return bytes of entire messages
313+
return json.Marshal(request.Body.ChatCompletions.Messages)
314+
}

0 commit comments

Comments
 (0)