Skip to content

Commit 0847c4c

Browse files
committed
fix: use unified classifier in classification endpoints
Fix API endpoints returning placeholder responses instead of real results. Modified ClassifyIntent, DetectPII, CheckSecurity to prioritize unified classifier over legacy classifier. Signed-off-by: OneZero-Y <[email protected]>
1 parent e75fc0f commit 0847c4c

File tree

1 file changed

+137
-4
lines changed

1 file changed

+137
-4
lines changed

src/semantic-router/pkg/services/classification.go

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func GetGlobalClassificationService() *ClassificationService {
7171

7272
// HasClassifier returns true if the service has a real classifier (not placeholder)
7373
func (s *ClassificationService) HasClassifier() bool {
74-
return s.classifier != nil
74+
return s.unifiedClassifier != nil || s.classifier != nil
7575
}
7676

7777
// NewPlaceholderClassificationService creates a placeholder service for API-only mode
@@ -118,7 +118,12 @@ func (s *ClassificationService) ClassifyIntent(req IntentRequest) (*IntentRespon
118118
return nil, fmt.Errorf("text cannot be empty")
119119
}
120120

121-
// Check if classifier is available
121+
// Prioritize unified classifier if available
122+
if s.unifiedClassifier != nil {
123+
return s.ClassifyIntentUnified(req)
124+
}
125+
126+
// Check if legacy classifier is available
122127
if s.classifier == nil {
123128
// Return placeholder response
124129
processingTime := time.Since(start).Milliseconds()
@@ -210,7 +215,12 @@ func (s *ClassificationService) DetectPII(req PIIRequest) (*PIIResponse, error)
210215
return nil, fmt.Errorf("text cannot be empty")
211216
}
212217

213-
// Check if classifier is available
218+
// Prioritize unified classifier if available
219+
if s.unifiedClassifier != nil {
220+
return s.DetectPIIUnified(req)
221+
}
222+
223+
// Check if legacy classifier is available
214224
if s.classifier == nil {
215225
// Return placeholder response
216226
processingTime := time.Since(start).Milliseconds()
@@ -290,7 +300,12 @@ func (s *ClassificationService) CheckSecurity(req SecurityRequest) (*SecurityRes
290300
return nil, fmt.Errorf("text cannot be empty")
291301
}
292302

293-
// Check if classifier is available
303+
// Prioritize unified classifier if available
304+
if s.unifiedClassifier != nil {
305+
return s.CheckSecurityUnified(req)
306+
}
307+
308+
// Check if legacy classifier is available
294309
if s.classifier == nil {
295310
// Return placeholder response
296311
processingTime := time.Since(start).Milliseconds()
@@ -454,6 +469,59 @@ func (s *ClassificationService) ClassifyPIIUnified(texts []string) ([]classifica
454469
return results.PIIResults, nil
455470
}
456471

472+
// DetectPIIUnified performs PII detection using unified classifier and returns PIIResponse format
473+
func (s *ClassificationService) DetectPIIUnified(req PIIRequest) (*PIIResponse, error) {
474+
start := time.Now()
475+
476+
if req.Text == "" {
477+
return nil, fmt.Errorf("text cannot be empty")
478+
}
479+
480+
// Use unified classifier for PII detection
481+
piiResults, err := s.ClassifyPIIUnified([]string{req.Text})
482+
if err != nil {
483+
return nil, fmt.Errorf("PII detection failed: %w", err)
484+
}
485+
486+
processingTime := time.Since(start).Milliseconds()
487+
488+
// Convert PIIResult to PIIResponse format
489+
if len(piiResults) == 0 {
490+
return &PIIResponse{
491+
HasPII: false,
492+
Entities: []PIIEntity{},
493+
SecurityRecommendation: "allow",
494+
ProcessingTimeMs: processingTime,
495+
}, nil
496+
}
497+
498+
piiResult := piiResults[0]
499+
response := &PIIResponse{
500+
HasPII: piiResult.HasPII,
501+
Entities: []PIIEntity{},
502+
ProcessingTimeMs: processingTime,
503+
}
504+
505+
// Convert PII types to entities
506+
for _, piiType := range piiResult.PIITypes {
507+
entity := PIIEntity{
508+
Type: piiType,
509+
Value: "[DETECTED]", // Placeholder - unified classifier doesn't provide exact positions yet
510+
Confidence: float64(piiResult.Confidence),
511+
}
512+
response.Entities = append(response.Entities, entity)
513+
}
514+
515+
// Set security recommendation
516+
if response.HasPII {
517+
response.SecurityRecommendation = "block"
518+
} else {
519+
response.SecurityRecommendation = "allow"
520+
}
521+
522+
return response, nil
523+
}
524+
457525
// ClassifySecurityUnified performs security detection using unified classifier
458526
func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]classification.SecurityResult, error) {
459527
if s.unifiedClassifier == nil {
@@ -468,6 +536,71 @@ func (s *ClassificationService) ClassifySecurityUnified(texts []string) ([]class
468536
return results.SecurityResults, nil
469537
}
470538

539+
// CheckSecurityUnified performs security detection using unified classifier and returns SecurityResponse format
540+
func (s *ClassificationService) CheckSecurityUnified(req SecurityRequest) (*SecurityResponse, error) {
541+
start := time.Now()
542+
543+
if req.Text == "" {
544+
return nil, fmt.Errorf("text cannot be empty")
545+
}
546+
547+
// Use unified classifier for security detection
548+
securityResults, err := s.ClassifySecurityUnified([]string{req.Text})
549+
if err != nil {
550+
return nil, fmt.Errorf("security detection failed: %w", err)
551+
}
552+
553+
processingTime := time.Since(start).Milliseconds()
554+
555+
// Convert SecurityResult to SecurityResponse format
556+
if len(securityResults) == 0 {
557+
return &SecurityResponse{
558+
IsJailbreak: false,
559+
RiskScore: 0.1,
560+
DetectionTypes: []string{},
561+
Confidence: 0.9,
562+
Recommendation: "allow",
563+
PatternsDetected: []string{},
564+
ProcessingTimeMs: processingTime,
565+
}, nil
566+
}
567+
568+
securityResult := securityResults[0]
569+
response := &SecurityResponse{
570+
IsJailbreak: securityResult.IsJailbreak,
571+
RiskScore: float64(securityResult.Confidence),
572+
Confidence: float64(securityResult.Confidence),
573+
ProcessingTimeMs: processingTime,
574+
}
575+
576+
// Set detection types based on threat type
577+
if securityResult.ThreatType != "" {
578+
response.DetectionTypes = []string{securityResult.ThreatType}
579+
response.PatternsDetected = []string{securityResult.ThreatType}
580+
} else {
581+
response.DetectionTypes = []string{}
582+
response.PatternsDetected = []string{}
583+
}
584+
585+
// Set recommendation based on jailbreak detection
586+
if response.IsJailbreak {
587+
response.Recommendation = "block"
588+
} else {
589+
response.Recommendation = "allow"
590+
}
591+
592+
// Add reasoning if requested
593+
if req.Options != nil && req.Options.IncludeReasoning {
594+
if response.IsJailbreak {
595+
response.Reasoning = fmt.Sprintf("Detected %s with confidence %.2f", securityResult.ThreatType, securityResult.Confidence)
596+
} else {
597+
response.Reasoning = "No security threats detected"
598+
}
599+
}
600+
601+
return response, nil
602+
}
603+
471604
// HasUnifiedClassifier returns true if the service has a unified classifier
472605
func (s *ClassificationService) HasUnifiedClassifier() bool {
473606
return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()

0 commit comments

Comments
 (0)