Skip to content

Commit 1b7e9ae

Browse files
Copilot and rootfs committed
Remove duplicate code in FindSimilar functions
Refactored FindSimilar() to delegate to FindSimilarWithThreshold() with the default threshold instead of duplicating the entire implementation. This eliminates 226 lines of duplicate code across inmemory_cache.go and milvus_cache.go.

Co-authored-by: rootfs <[email protected]>
1 parent 11324dd commit 1b7e9ae

File tree

2 files changed

+3
-226
lines changed

2 files changed

+3
-226
lines changed

src/semantic-router/pkg/cache/inmemory_cache.go

Lines changed: 2 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -207,131 +207,9 @@ func (c *InMemoryCache) AddEntry(requestID string, model string, query string, r
207207
return nil
208208
}
209209

210-
// FindSimilar searches for semantically similar cached requests
210+
// FindSimilar searches for semantically similar cached requests using the default threshold
211211
func (c *InMemoryCache) FindSimilar(model string, query string) ([]byte, bool, error) {
212-
start := time.Now()
213-
214-
if !c.enabled {
215-
observability.Debugf("InMemoryCache.FindSimilar: cache disabled")
216-
return nil, false, nil
217-
}
218-
queryPreview := query
219-
if len(query) > 50 {
220-
queryPreview = query[:50] + "..."
221-
}
222-
observability.Debugf("InMemoryCache.FindSimilar: searching for model='%s', query='%s' (len=%d chars)",
223-
model, queryPreview, len(query))
224-
225-
// Generate semantic embedding for similarity comparison
226-
queryEmbedding, err := candle_binding.GetEmbedding(query, 0) // Auto-detect dimension
227-
if err != nil {
228-
metrics.RecordCacheOperation("memory", "find_similar", "error", time.Since(start).Seconds())
229-
return nil, false, fmt.Errorf("failed to generate embedding: %w", err)
230-
}
231-
232-
c.mu.RLock()
233-
var (
234-
bestIndex = -1
235-
bestEntry CacheEntry
236-
bestSimilarity float32
237-
entriesChecked int
238-
expiredCount int
239-
)
240-
// Capture the lookup time after acquiring the read lock so TTL checks aren’t skewed by embedding work or lock wait
241-
now := time.Now()
242-
243-
// Compare with completed entries for the same model, tracking only the best match
244-
for entryIndex, entry := range c.entries {
245-
// Skip incomplete entries
246-
if entry.ResponseBody == nil {
247-
continue
248-
}
249-
250-
// Only consider entries for the same model
251-
if entry.Model != model {
252-
continue
253-
}
254-
255-
// Skip entries that have expired before considering them
256-
if c.isExpired(entry, now) {
257-
expiredCount++
258-
continue
259-
}
260-
261-
// Compute semantic similarity using dot product
262-
var dotProduct float32
263-
for i := 0; i < len(queryEmbedding) && i < len(entry.Embedding); i++ {
264-
dotProduct += queryEmbedding[i] * entry.Embedding[i]
265-
}
266-
267-
entriesChecked++
268-
if bestIndex == -1 || dotProduct > bestSimilarity {
269-
bestSimilarity = dotProduct
270-
bestIndex = entryIndex
271-
}
272-
}
273-
// Snapshot the best entry before releasing the read lock
274-
if bestIndex >= 0 {
275-
bestEntry = c.entries[bestIndex]
276-
}
277-
278-
// Unlock the read lock since we need the write lock to update the access info
279-
c.mu.RUnlock()
280-
281-
// Log if any expired entries were skipped
282-
if expiredCount > 0 {
283-
observability.Debugf("InMemoryCache: excluded %d expired entries during search (TTL: %ds)",
284-
expiredCount, c.ttlSeconds)
285-
observability.LogEvent("cache_expired_entries_found", map[string]interface{}{
286-
"backend": "memory",
287-
"expired_count": expiredCount,
288-
"ttl_seconds": c.ttlSeconds,
289-
})
290-
}
291-
292-
// Handle case where no suitable entries exist
293-
if bestIndex < 0 {
294-
atomic.AddInt64(&c.missCount, 1)
295-
observability.Debugf("InMemoryCache.FindSimilar: no entries found with responses")
296-
metrics.RecordCacheOperation("memory", "find_similar", "miss", time.Since(start).Seconds())
297-
metrics.RecordCacheMiss()
298-
return nil, false, nil
299-
}
300-
301-
// Check if the best match meets the similarity threshold
302-
if bestSimilarity >= c.similarityThreshold {
303-
atomic.AddInt64(&c.hitCount, 1)
304-
305-
c.mu.Lock()
306-
c.updateAccessInfo(bestIndex, bestEntry)
307-
c.mu.Unlock()
308-
309-
observability.Debugf("InMemoryCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes",
310-
bestSimilarity, c.similarityThreshold, len(bestEntry.ResponseBody))
311-
observability.LogEvent("cache_hit", map[string]interface{}{
312-
"backend": "memory",
313-
"similarity": bestSimilarity,
314-
"threshold": c.similarityThreshold,
315-
"model": model,
316-
})
317-
metrics.RecordCacheOperation("memory", "find_similar", "hit", time.Since(start).Seconds())
318-
metrics.RecordCacheHit()
319-
return bestEntry.ResponseBody, true, nil
320-
}
321-
322-
atomic.AddInt64(&c.missCount, 1)
323-
observability.Debugf("InMemoryCache.FindSimilar: CACHE MISS - best_similarity=%.4f < threshold=%.4f (checked %d entries)",
324-
bestSimilarity, c.similarityThreshold, entriesChecked)
325-
observability.LogEvent("cache_miss", map[string]interface{}{
326-
"backend": "memory",
327-
"best_similarity": bestSimilarity,
328-
"threshold": c.similarityThreshold,
329-
"model": model,
330-
"entries_checked": entriesChecked,
331-
})
332-
metrics.RecordCacheOperation("memory", "find_similar", "miss", time.Since(start).Seconds())
333-
metrics.RecordCacheMiss()
334-
return nil, false, nil
212+
return c.FindSimilarWithThreshold(model, query, c.similarityThreshold)
335213
}
336214

337215
// FindSimilarWithThreshold searches for semantically similar cached requests using a specific threshold

src/semantic-router/pkg/cache/milvus_cache.go

Lines changed: 1 addition & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -487,108 +487,7 @@ func (c *MilvusCache) addEntry(id string, requestID string, model string, query
487487

488488
// FindSimilar searches for semantically similar cached requests
489489
func (c *MilvusCache) FindSimilar(model string, query string) ([]byte, bool, error) {
490-
start := time.Now()
491-
492-
if !c.enabled {
493-
observability.Debugf("MilvusCache.FindSimilar: cache disabled")
494-
return nil, false, nil
495-
}
496-
queryPreview := query
497-
if len(query) > 50 {
498-
queryPreview = query[:50] + "..."
499-
}
500-
observability.Debugf("MilvusCache.FindSimilar: searching for model='%s', query='%s' (len=%d chars)",
501-
model, queryPreview, len(query))
502-
503-
// Generate semantic embedding for similarity comparison
504-
queryEmbedding, err := candle_binding.GetEmbedding(query, 0) // Auto-detect dimension
505-
if err != nil {
506-
metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds())
507-
return nil, false, fmt.Errorf("failed to generate embedding: %w", err)
508-
}
509-
510-
ctx := context.Background()
511-
512-
// Define search parameters
513-
searchParam, err := entity.NewIndexHNSWSearchParam(c.config.Search.Params.Ef)
514-
if err != nil {
515-
return nil, false, fmt.Errorf("failed to create search parameters: %w", err)
516-
}
517-
518-
// Use Milvus Search for efficient similarity search
519-
searchResult, err := c.client.Search(
520-
ctx,
521-
c.collectionName,
522-
[]string{},
523-
fmt.Sprintf("model == \"%s\" && response_body != \"\"", model),
524-
[]string{"response_body"},
525-
[]entity.Vector{entity.FloatVector(queryEmbedding)},
526-
c.config.Collection.VectorField.Name,
527-
entity.MetricType(c.config.Collection.VectorField.MetricType),
528-
c.config.Search.TopK,
529-
searchParam,
530-
)
531-
if err != nil {
532-
observability.Debugf("MilvusCache.FindSimilar: search failed: %v", err)
533-
atomic.AddInt64(&c.missCount, 1)
534-
metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds())
535-
metrics.RecordCacheMiss()
536-
return nil, false, nil
537-
}
538-
539-
if len(searchResult) == 0 || searchResult[0].ResultCount == 0 {
540-
atomic.AddInt64(&c.missCount, 1)
541-
observability.Debugf("MilvusCache.FindSimilar: no entries found")
542-
metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds())
543-
metrics.RecordCacheMiss()
544-
return nil, false, nil
545-
}
546-
547-
bestScore := searchResult[0].Scores[0]
548-
if bestScore < c.similarityThreshold {
549-
atomic.AddInt64(&c.missCount, 1)
550-
observability.Debugf("MilvusCache.FindSimilar: CACHE MISS - best_similarity=%.4f < threshold=%.4f",
551-
bestScore, c.similarityThreshold)
552-
observability.LogEvent("cache_miss", map[string]interface{}{
553-
"backend": "milvus",
554-
"best_similarity": bestScore,
555-
"threshold": c.similarityThreshold,
556-
"model": model,
557-
"collection": c.collectionName,
558-
})
559-
metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds())
560-
metrics.RecordCacheMiss()
561-
return nil, false, nil
562-
}
563-
564-
// Cache Hit
565-
var responseBody []byte
566-
responseBodyColumn, ok := searchResult[0].Fields[0].(*entity.ColumnVarChar)
567-
if ok && responseBodyColumn.Len() > 0 {
568-
responseBody = []byte(responseBodyColumn.Data()[0])
569-
}
570-
571-
if responseBody == nil {
572-
observability.Debugf("MilvusCache.FindSimilar: cache hit but response_body is missing or not a string")
573-
atomic.AddInt64(&c.missCount, 1)
574-
metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds())
575-
metrics.RecordCacheMiss()
576-
return nil, false, nil
577-
}
578-
579-
atomic.AddInt64(&c.hitCount, 1)
580-
observability.Debugf("MilvusCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes",
581-
bestScore, c.similarityThreshold, len(responseBody))
582-
observability.LogEvent("cache_hit", map[string]interface{}{
583-
"backend": "milvus",
584-
"similarity": bestScore,
585-
"threshold": c.similarityThreshold,
586-
"model": model,
587-
"collection": c.collectionName,
588-
})
589-
metrics.RecordCacheOperation("milvus", "find_similar", "hit", time.Since(start).Seconds())
590-
metrics.RecordCacheHit()
591-
return responseBody, true, nil
490+
return c.FindSimilarWithThreshold(model, query, c.similarityThreshold)
592491
}
593492

594493
// FindSimilarWithThreshold searches for semantically similar cached requests using a specific threshold

0 commit comments

Comments
 (0)