Skip to content

Commit 198e6ca

Browse files
Frapschen and liu-cong
authored
support vLLM cache salting in prefix aware scorer (#1646)
* support vLLM cache salting in prefix aware scorer * Apply suggestions from code review Co-authored-by: Cong Liu <[email protected]> * fix lint --------- Co-authored-by: Cong Liu <[email protected]>
1 parent 435b414 commit 198e6ca

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
257257
}
258258

259259
// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
260-
// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
260+
// hash[0] is calculated including the model name and cache_salt (if provided), since different models generally don't share prefix cache.
261261
// For block i, hash(i) = hash(block i content, hash(i-1)).
262262
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
263263
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
@@ -286,6 +286,10 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i
286286
// Add the model to the first block hash so that different models have different hashes even with the same body.
287287
h := xxhash.New()
288288
_, _ = h.Write([]byte(request.TargetModel))
289+
if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
290+
_, _ = h.Write([]byte(cacheSalt))
291+
}
292+
289293
prevBlockHash := BlockHash(h.Sum64())
290294
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
291295
h.Reset()

pkg/epp/scheduling/types/types.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,27 @@ type LLMRequestBody struct {
5656
ChatCompletions *ChatCompletionsRequest `json:"chat_completions,omitempty"`
5757
}
5858

59+
func (r *LLMRequestBody) CacheSalt() string {
60+
if r.ChatCompletions == nil && r.Completions == nil {
61+
return ""
62+
}
63+
64+
if r.ChatCompletions != nil {
65+
return r.ChatCompletions.CacheSalt
66+
}
67+
68+
return r.Completions.CacheSalt
69+
}
70+
5971
// CompletionsRequest is a structured representation of the fields we parse out of the
// /v1/completions request body.
// This struct includes fields usable for plugins and scheduling decisions - and not the entire
// API spec.
type CompletionsRequest struct {
	// Prompt is the prompt that was sent in the request body.
	Prompt string `json:"prompt,omitempty"`
	// CacheSalt is an optional request parameter to isolate prefix caches for security reasons.
	CacheSalt string `json:"cache_salt,omitempty"`
}
6781

6882
func (r *CompletionsRequest) String() string {
@@ -88,6 +102,8 @@ type ChatCompletionsRequest struct {
88102
ContinueFinalMessage bool `json:"continue_final_message,omitempty"`
89103
AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"`
90104
ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
105+
// CacheSalt is an optional request parameter to isolate prefix caches for security reasons.
106+
CacheSalt string `json:"cache_salt,omitempty"`
91107
}
92108

93109
func (r *ChatCompletionsRequest) String() string {

pkg/epp/util/request/body_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,44 @@ func TestExtractRequestData(t *testing.T) {
225225
},
226226
wantErr: true,
227227
},
228+
{
229+
name: "completions request with cache_salt",
230+
body: map[string]any{
231+
"model": "test",
232+
"prompt": "test prompt",
233+
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
234+
},
235+
want: &types.LLMRequestBody{
236+
Completions: &types.CompletionsRequest{
237+
Prompt: "test prompt",
238+
CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
239+
},
240+
},
241+
},
242+
{
243+
name: "chat completions request with cache_salt",
244+
body: map[string]any{
245+
"model": "test",
246+
"messages": []any{
247+
map[string]any{
248+
"role": "system", "content": "this is a system message",
249+
},
250+
map[string]any{
251+
"role": "user", "content": "hello",
252+
},
253+
},
254+
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
255+
},
256+
want: &types.LLMRequestBody{
257+
ChatCompletions: &types.ChatCompletionsRequest{
258+
Messages: []types.Message{
259+
{Role: "system", Content: "this is a system message"},
260+
{Role: "user", Content: "hello"},
261+
},
262+
CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
263+
},
264+
},
265+
},
228266
}
229267

230268
for _, tt := range tests {

0 commit comments

Comments
 (0)