Skip to content

Commit 40fdedb

Browse files
authored
fix: first hash of prefix cache with same model name (#1341)
* fix: first hash of prefix cache with same model name * fix: no hash if the prompt is smaller than cacheBlockSize * fix: optimize if else for more concise * chore: clean test comments
1 parent 3bb9e19 commit 40fdedb

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,16 @@ func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize i
263263
// If the last block is smaller than cacheBlockSize, it will be ignored.
264264
res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
265265
// Add the model to the first block hash so that different models have different hashes even with the same body.
266-
res = append(res, BlockHash(xxhash.Sum64String(request.TargetModel)))
267-
for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
266+
267+
firstBlockSize := cacheBlockSize
268+
if len(prompt) < cacheBlockSize {
269+
firstBlockSize = len(prompt)
270+
}
271+
firstBlock := prompt[0:firstBlockSize]
272+
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
273+
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
274+
275+
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
268276
block := prompt[i : i+cacheBlockSize]
269277
prevBlockHash := res[len(res)-1]
270278
block = append(block, toBytes(prevBlockHash)...)

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ func TestPrefixPlugin(t *testing.T) {
5757
assert.NoError(t, err)
5858
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
5959
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
60-
// Total hashes = 2 (the first one is for the model)
61-
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
60+
// Total hashes = 1 (the first one is for the prefix with model)
61+
assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
6262
assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers")
6363
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
6464
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
@@ -84,8 +84,8 @@ func TestPrefixPlugin(t *testing.T) {
8484
assert.NoError(t, err)
8585
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
8686
// Input size is 6, hash block size is 4, the last 2 characters are ignored.
87-
// Total hashes = 2 (the first one is for the model)
88-
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
87+
// Total hashes = 1 (the first one is for the prefix with model)
88+
assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
8989
assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers")
9090
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
9191
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
@@ -110,10 +110,10 @@ func TestPrefixPlugin(t *testing.T) {
110110
assert.NoError(t, err)
111111
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
112112
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
113-
// Total hashes = 3 (the first one is for the model)
114-
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
113+
// Total hashes = 2 (the first one is for the prefix with model)
114+
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
115115
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
116-
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
116+
assert.Equal(t, 0.5, scores[pod1], "score should be 0.5 - the model and the first prefix block match")
117117
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
118118

119119
schedulingResult = &types.SchedulingResult{
@@ -135,8 +135,8 @@ func TestPrefixPlugin(t *testing.T) {
135135
assert.NoError(t, err)
136136
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
137137
// Input size is 8, hash block size is 4, so 2 hashes will be calculated.
138-
// Total hashes = 3 (the first one is for the model)
139-
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
138+
// Total hashes = 2 (the first one is for the prefix with model)
139+
assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
140140
assert.Equal(t, 0, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
141141
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
142142
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
@@ -160,10 +160,10 @@ func TestPrefixPlugin(t *testing.T) {
160160
assert.NoError(t, err)
161161
t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
162162
// Input size is 12, hash block size is 4, so 3 hashes will be calculated.
163-
// Total hashes = 4 (the first one is for the model)
164-
assert.Equal(t, 4, len(state.PrefixHashes), "number of hashes is incorrect")
163+
// Total hashes = 3 (the first one is for the prefix with model)
164+
assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect")
165165
assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix")
166-
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
166+
assert.Equal(t, 2./3, scores[pod1], "score should be 2./3 - the model and the first 2 prefix blocks match")
167167
assert.Equal(t, float64(0), scores[pod2], "score for pod2")
168168

169169
schedulingResult = &types.SchedulingResult{
@@ -224,7 +224,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
224224
// Second cycle: validate internal state
225225
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, PrefixCachePluginType)
226226
assert.NoError(b, err)
227-
expectedHashes := int(math.Min(float64(maxPrefixBlocks+1), float64(len(req.Prompt)/blockSize+1))) // the extra one is for the model.
227+
expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize)))
228228
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")
229229
}
230230
}

0 commit comments

Comments
 (0)