
Commit 7e54562

Authored by Yehudit Kerido (yehudit1987)
fix skipped tests (#621)
Signed-off-by: Yehudit Kerido <[email protected]>
Co-authored-by: Yehudit Kerido <[email protected]>
1 parent aab9aa7 commit 7e54562

File tree: 5 files changed, +59 −10 lines

candle-binding/semantic-router_test.go

Lines changed: 2 additions & 2 deletions
@@ -1362,15 +1362,15 @@ func TestCandleBertTokensWithLabels(t *testing.T) {

 	success := InitCandleBertTokenClassifier(BertPIITokenClassifierModelPath, 9, true) // 9 PII classes
 	if !success {
-		t.Skipf("Candle BERT token classifier not available at path: %s", BertPIITokenClassifierModelPath)
+		t.Fatalf("Failed to initialize Candle BERT token classifier at path: %s. Model should be available in CI (included in LoRA model set).", BertPIITokenClassifierModelPath)
 	}

 	testText := "Contact Dr. Sarah Johnson at [email protected] for medical records"

 	result, err := ClassifyCandleBertTokensWithLabels(testText, id2labelJSON)
 	if err != nil {
 		if isModelInitializationError(err) {
-			t.Skipf("Skipping Candle BERT token classifier tests due to model initialization error: %v", err)
+			t.Fatalf("Candle BERT token classifier failed with model initialization error: %v. Model should be initialized earlier in test.", err)
 		}
 		t.Fatalf("Token classification with labels failed: %v", err)
 	}

candle-binding/src/classifiers/lora/token_lora.rs

Lines changed: 7 additions & 4 deletions
@@ -187,20 +187,23 @@ impl LoRATokenClassifier {
         for (i, (token, token_embedding)) in tokens.iter().enumerate() {
             // Use real BERT embedding from tokenization

+            // Add batch dimension: [hidden_size] -> [1, hidden_size]
+            let token_embedding_batched = token_embedding.unsqueeze(0)?;
+
             // Apply base classifier
-            let base_logits = self.base_classifier.forward(&token_embedding)?;
+            let base_logits = self.base_classifier.forward(&token_embedding_batched)?;

             // Apply LoRA adapters if available
             let enhanced_logits = if let Some(adapter) = self.adapters.get("token_classification") {
-                let adapter_output = adapter.forward(&token_embedding, false)?; // false = not training
+                let adapter_output = adapter.forward(&token_embedding_batched, false)?; // false = not training
                 (&base_logits + &adapter_output)?
             } else {
                 base_logits
             };

-            // Apply softmax to get probabilities
+            // Apply softmax to get probabilities and remove batch dimension
             let probabilities = candle_nn::ops::softmax(&enhanced_logits, 1)?;
-            let probs_vec = probabilities.to_vec1::<f32>()?;
+            let probs_vec = probabilities.squeeze(0)?.to_vec1::<f32>()?;

             // Find the class with highest probability
             let (predicted_id, confidence) = probs_vec

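Note on the change above: the per-token embedding produced by tokenization is a 1-D tensor of shape [hidden_size], while the linear classifier head and the LoRA adapter expect a batched 2-D input of shape [1, hidden_size]; hence the unsqueeze before the forward pass and the squeeze after softmax. The following is a minimal, self-contained sketch of that pattern with candle, using candle_nn::Linear as a stand-in for the project's base_classifier; the weights, sizes, and the function classify_one_token are invented for illustration and are not the project's code.

// Minimal sketch (assumptions noted above): why the unsqueeze/squeeze pair matters.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn classify_one_token(token_embedding: &Tensor, classifier: &Linear) -> Result<Vec<f32>> {
    // [hidden_size] -> [1, hidden_size]
    let batched = token_embedding.unsqueeze(0)?;
    // [1, hidden_size] -> [1, num_classes]
    let logits = classifier.forward(&batched)?;
    // Softmax over the class dimension, then drop the batch dimension again.
    let probs = candle_nn::ops::softmax(&logits, 1)?;
    probs.squeeze(0)?.to_vec1::<f32>()
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let hidden = 4;
    let classes = 3;
    // Toy classifier weights: [num_classes, hidden_size] plus a zero bias.
    let w = Tensor::rand(0f32, 1f32, (classes, hidden), &device)?;
    let b = Tensor::zeros(classes, DType::F32, &device)?;
    let classifier = Linear::new(w, Some(b));

    // A single token embedding of shape [hidden_size].
    let token_embedding = Tensor::rand(0f32, 1f32, hidden, &device)?;
    let probs = classify_one_token(&token_embedding, &classifier)?;
    println!("per-class probabilities: {probs:?}");
    Ok(())
}
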
candle-binding/src/ffi/classify.rs

Lines changed: 36 additions & 2 deletions
@@ -520,8 +520,40 @@ pub extern "C" fn classify_candle_bert_tokens_with_labels(
         }
     };

-    // Use TraditionalBertTokenClassifier for token-level classification with labels
+    // Intelligent routing: Check LoRA token classifier first, then fall back to traditional

+    // Try LoRA token classifier first
+    if let Some(classifier) = crate::ffi::init::LORA_TOKEN_CLASSIFIER.get() {
+        let classifier = classifier.clone();
+        match classifier.classify_tokens(text) {
+            Ok(token_results) => {
+                // Filter out "O" (Outside) labels - only return actual entities
+                let token_entities: Vec<(String, String, f32)> = token_results
+                    .iter()
+                    .filter(|result| result.label_name != "O" && result.label_id != 0)
+                    .map(|result| {
+                        (
+                            result.token.clone(),
+                            result.label_name.clone(),
+                            result.confidence,
+                        )
+                    })
+                    .collect();
+
+                let entities_ptr = unsafe { allocate_bert_token_entity_array(&token_entities) };
+
+                return BertTokenClassificationResult {
+                    entities: entities_ptr,
+                    num_entities: token_entities.len() as i32,
+                };
+            }
+            Err(_e) => {
+                // Fall through to traditional classifier
+            }
+        }
+    }
+
+    // Fall back to traditional BERT token classifier
     if let Some(classifier) = TRADITIONAL_BERT_TOKEN_CLASSIFIER.get() {
         let classifier = classifier.clone();
         match classifier.classify_tokens(text) {

@@ -581,17 +613,19 @@ pub extern "C" fn classify_candle_bert_tokens(
             let lora_classifier = lora_classifier.clone();
             match lora_classifier.classify_tokens(text) {
                 Ok(lora_results) => {
+                    // Filter out "O" (Outside) labels - only return actual entities
                     // Convert LoRA results to BertTokenEntity format
                     let token_entities: Vec<(String, String, f32)> = lora_results
                         .iter()
+                        .filter(|r| r.label_name != "O" && r.label_id != 0)
                         .map(|r| (r.token.clone(), r.label_name.clone(), r.confidence))
                         .collect();

                     let entities_ptr = unsafe { allocate_bert_token_entity_array(&token_entities) };

                     return BertTokenClassificationResult {
                         entities: entities_ptr,
-                        num_entities: lora_results.len() as i32,
+                        num_entities: token_entities.len() as i32,
                     };
                 }
                 Err(_e) => {

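Note on the changes above: classify_candle_bert_tokens_with_labels now tries the LoRA token classifier first and only falls back to the traditional BERT classifier when LoRA is unavailable or fails, and the LoRA result paths drop "O" (Outside) tokens before counting entities, so num_entities reflects the filtered list. The sketch below reduces that control flow to plain Rust for illustration; TokenResult, the closure parameters, and route_token_classification are stand-ins, while the real code goes through global classifier singletons and returns a C-compatible BertTokenClassificationResult.

// Illustrative sketch only: the LoRA-first routing and "O"-label filtering pattern.
#[derive(Debug, Clone)]
struct TokenResult {
    token: String,
    label_id: usize,
    label_name: String,
    confidence: f32,
}

fn route_token_classification(
    lora: Option<&dyn Fn(&str) -> Result<Vec<TokenResult>, String>>,
    traditional: &dyn Fn(&str) -> Result<Vec<TokenResult>, String>,
    text: &str,
) -> Vec<(String, String, f32)> {
    // Prefer the LoRA classifier when it is initialized and succeeds;
    // otherwise fall through to the traditional BERT classifier.
    let results = match lora.and_then(|f| f(text).ok()) {
        Some(r) => r,
        None => traditional(text).unwrap_or_default(),
    };

    // Drop "O" (Outside) tokens so only real entities are reported,
    // and count entities after filtering.
    results
        .into_iter()
        .filter(|r| r.label_name != "O" && r.label_id != 0)
        .map(|r| (r.token, r.label_name, r.confidence))
        .collect()
}

fn main() {
    // Toy "traditional" classifier returning one entity and one "O" token.
    let trad = |_: &str| -> Result<Vec<TokenResult>, String> {
        Ok(vec![
            TokenResult { token: "Sarah".into(), label_id: 2, label_name: "PERSON".into(), confidence: 0.97 },
            TokenResult { token: "at".into(), label_id: 0, label_name: "O".into(), confidence: 0.99 },
        ])
    };
    let entities = route_token_classification(None, &trad, "Contact Sarah at the clinic");
    println!("{entities:?}");
}

In this toy run only the PERSON token survives the filter, which is why the entity count is taken from the filtered token_entities rather than the raw result list.
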
candle-binding/src/ffi/init.rs

Lines changed: 13 additions & 0 deletions
@@ -594,6 +594,11 @@ pub extern "C" fn init_candle_bert_token_classifier(

     match model_type {
         ModelType::LoRA => {
+            // Check if already initialized
+            if LORA_TOKEN_CLASSIFIER.get().is_some() {
+                return true; // Already initialized, return success
+            }
+
             // Route to LoRA token classifier initialization
             match crate::classifiers::lora::token_lora::LoRATokenClassifier::new(
                 model_path, use_cpu,

@@ -606,6 +611,14 @@
             }
         }
         ModelType::Traditional => {
+            // Check if already initialized
+            if crate::model_architectures::traditional::bert::TRADITIONAL_BERT_TOKEN_CLASSIFIER
+                .get()
+                .is_some()
+            {
+                return true; // Already initialized, return success
+            }
+
             // Route to traditional BERT token classifier
             match crate::model_architectures::traditional::bert::TraditionalBertTokenClassifier::new(
                 model_path,

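Note on the changes above: init_candle_bert_token_classifier becomes idempotent — if the global classifier slot is already populated, it returns success instead of attempting a second initialization, which presumably is what previously made repeated init calls fail and forced the Go tests to skip. Below is a minimal sketch of that pattern using std::sync::OnceLock; the static, the TokenClassifier struct, and the model path are placeholders, not the crate's actual types.

// Minimal sketch of the idempotent-init pattern (placeholder types and names).
use std::sync::{Arc, OnceLock};

struct TokenClassifier {
    model_path: String,
}

static LORA_TOKEN_CLASSIFIER: OnceLock<Arc<TokenClassifier>> = OnceLock::new();

fn init_token_classifier(model_path: &str) -> bool {
    // Check if already initialized: second and later calls are no-ops.
    if LORA_TOKEN_CLASSIFIER.get().is_some() {
        return true; // Already initialized, return success
    }

    // Pretend to load the model; the real code constructs a LoRATokenClassifier here.
    let classifier = Arc::new(TokenClassifier {
        model_path: model_path.to_string(),
    });
    LORA_TOKEN_CLASSIFIER.set(classifier).is_ok()
}

fn main() {
    assert!(init_token_classifier("./models/pii-token-classifier"));
    // Calling again no longer reports failure just because the slot is taken.
    assert!(init_token_classifier("./models/pii-token-classifier"));
    println!(
        "initialized once, path = {}",
        LORA_TOKEN_CLASSIFIER.get().unwrap().model_path
    );
}

In this sketch the early get().is_some() check is just a fast path: OnceLock::set still guarantees that only the first value is ever stored if two callers race.
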
tools/make/rust.mk

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ test-binding-lora: $(if $(CI),rust-ci,rust) ## Run Go tests with LoRA and advanc
 	@echo "Running candle-binding tests with LoRA and advanced embedding models..."
 	@export LD_LIBRARY_PATH=${PWD}/candle-binding/target/release && \
 		cd candle-binding && CGO_ENABLED=1 go test -v -race \
-		-run "^Test(BertTokenClassification|BertSequenceClassification|CandleBertClassifier|CandleBertTokenClassifier|CandleBertTokensWithLabels|LoRAUnifiedClassifier|GetEmbeddingSmart|InitEmbeddingModels|GetEmbeddingWithDim|EmbeddingConsistency|EmbeddingPriorityRouting|EmbeddingConcurrency)$$" \
-		|| { echo "⚠️ Warning: Some LoRA/embedding tests failed (may be due to missing restricted models), continuing..."; $(if $(CI),true,exit 1); }
+		-run "^Test(BertTokenClassification|BertSequenceClassification|CandleBertClassifier|CandleBertTokenClassifier|CandleBertTokensWithLabels|LoRAUnifiedClassifier|GetEmbeddingSmart|InitEmbeddingModels|GetEmbeddingWithDim|EmbeddingConsistency|EmbeddingPriorityRouting|EmbeddingConcurrency)$$"

 # Test the Rust library - all tests (conditionally use rust-ci in CI environments)
 test-binding: $(if $(CI),rust-ci,rust) ## Run all Go tests with the Rust static library

Comments (0)