Skip to content

Commit 05ec72c

Browse files
Frapschen authored and liu-cong committed
support vLLM cache salting in prefix aware scorer (kubernetes-sigs#1646)
* support vLLM cache salting in prefix aware scorer

* Apply suggestions from code review

Co-authored-by: Cong Liu <[email protected]>

* fix lint

---------

Co-authored-by: Cong Liu <[email protected]>
1 parent b0fbffb commit 05ec72c

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
263263
}
264264

265265
// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
266-
// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
266+
// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
267267
// For block i, hash(i) = hash(block i content, hash(i-1)).
268268
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
269269
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
@@ -292,6 +292,10 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i
292292
// Add the model to the first block hash so that different models have different hashes even with the same body.
293293
h := xxhash.New()
294294
_, _ = h.Write([]byte(request.TargetModel))
295+
if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
296+
_, _ = h.Write([]byte(cacheSalt))
297+
}
298+
295299
prevBlockHash := BlockHash(h.Sum64())
296300
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
297301
h.Reset()

pkg/epp/scheduling/types/types.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,27 @@ type LLMRequestBody struct {
7373
ChatCompletions *ChatCompletionsRequest `json:"chat_completions,omitempty"`
7474
}
7575

76+
func (r *LLMRequestBody) CacheSalt() string {
77+
if r.ChatCompletions == nil && r.Completions == nil {
78+
return ""
79+
}
80+
81+
if r.ChatCompletions != nil {
82+
return r.ChatCompletions.CacheSalt
83+
}
84+
85+
return r.Completions.CacheSalt
86+
}
87+
7688
// CompletionsRequest is a structured representation of the fields we parse out of the
7789
// /v1/completions request body.
7890
// This struct includes fields usable for plugins and scheduling decisions - and not the entire
7991
// API spec.
8092
type CompletionsRequest struct {
8193
// Prompt is the prompt that was sent in the request body.
8294
Prompt string `json:"prompt,omitempty"`
95+
// CacheSalt is an optional request parameter to isolate prefix caches for security reasons.
96+
CacheSalt string `json:"cache_salt,omitempty"`
8397
}
8498

8599
func (r *CompletionsRequest) String() string {
@@ -105,6 +119,8 @@ type ChatCompletionsRequest struct {
105119
ContinueFinalMessage bool `json:"continue_final_message,omitempty"`
106120
AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"`
107121
ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
122+
// CacheSalt is an optional request parameter to isolate prefix caches for security reasons.
123+
CacheSalt string `json:"cache_salt,omitempty"`
108124
}
109125

110126
func (r *ChatCompletionsRequest) String() string {

pkg/epp/util/request/body_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,44 @@ func TestExtractRequestData(t *testing.T) {
225225
},
226226
wantErr: true,
227227
},
228+
{
229+
name: "completions request with cache_salt",
230+
body: map[string]any{
231+
"model": "test",
232+
"prompt": "test prompt",
233+
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
234+
},
235+
want: &types.LLMRequestBody{
236+
Completions: &types.CompletionsRequest{
237+
Prompt: "test prompt",
238+
CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
239+
},
240+
},
241+
},
242+
{
243+
name: "chat completions request with cache_salt",
244+
body: map[string]any{
245+
"model": "test",
246+
"messages": []any{
247+
map[string]any{
248+
"role": "system", "content": "this is a system message",
249+
},
250+
map[string]any{
251+
"role": "user", "content": "hello",
252+
},
253+
},
254+
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
255+
},
256+
want: &types.LLMRequestBody{
257+
ChatCompletions: &types.ChatCompletionsRequest{
258+
Messages: []types.Message{
259+
{Role: "system", Content: "this is a system message"},
260+
{Role: "user", Content: "hello"},
261+
},
262+
CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
263+
},
264+
},
265+
},
228266
}
229267

230268
for _, tt := range tests {

0 commit comments

Comments
 (0)