diff --git a/Cargo.lock b/Cargo.lock index 89b011d6..c43b1884 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4494,6 +4494,7 @@ version = "1.7.0" dependencies = [ "clap", "nohash-hasher", + "serde", "thiserror 1.0.69", ] diff --git a/backends/candle/src/models/flash_modernbert.rs b/backends/candle/src/models/flash_modernbert.rs index 95b0cfb1..3c876c63 100644 --- a/backends/candle/src/models/flash_modernbert.rs +++ b/backends/candle/src/models/flash_modernbert.rs @@ -260,11 +260,7 @@ impl FlashModernBertModel { let (pool, classifier) = match model_type { ModelType::Classifier => { - let pool: Pool = config - .classifier_pooling - .as_deref() - .and_then(|s| Pool::from_str(s).ok()) - .unwrap_or(Pool::Cls); + let pool: Pool = config.classifier_pooling.clone().unwrap_or(Pool::Cls); let classifier: Box = Box::new(ModernBertClassificationHead::load(vb.clone(), config)?); diff --git a/backends/candle/src/models/modernbert.rs b/backends/candle/src/models/modernbert.rs index ba913560..046a1547 100644 --- a/backends/candle/src/models/modernbert.rs +++ b/backends/candle/src/models/modernbert.rs @@ -7,7 +7,6 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; use serde::Deserialize; use text_embeddings_backend_core::{Batch, ModelType, Pool}; -use std::str::FromStr; // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/configuration_modernbert.py #[derive(Debug, Clone, PartialEq, Deserialize)] @@ -38,7 +37,7 @@ pub struct ModernBertConfig { pub mlp_bias: Option, pub mlp_dropout: Option, pub decoder_bias: Option, - pub classifier_pooling: Option, + pub classifier_pooling: Option, pub classifier_dropout: Option, pub classifier_bias: Option, pub classifier_activation: HiddenAct, @@ -485,11 +484,7 @@ impl ModernBertModel { pub fn load(vb: VarBuilder, config: &ModernBertConfig, model_type: ModelType) -> Result { let (pool, classifier) = match model_type { ModelType::Classifier => { - let pool: Pool = config - .classifier_pooling - .as_deref() - .and_then(|s| Pool::from_str(s).ok()) - .unwrap_or(Pool::Cls); + let pool: Pool = config.classifier_pooling.clone().unwrap_or(Pool::Cls); let classifier: Box = Box::new(ModernBertClassificationHead::load(vb.clone(), config)?); diff --git a/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap.new b/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap similarity index 83% rename from backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap.new rename to backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap index 4167173d..ab58e0cf 100644 --- a/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap.new +++ b/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap @@ -1,7 +1,5 @@ --- source: backends/candle/tests/test_modernbert.rs -assertion_line: 229 expression: predictions_single --- - - -0.30617672 - diff --git a/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap b/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap index 20ef4d37..4674caa5 100644 --- a/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap +++ b/backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap @@ -2,4 +2,4 @@ source: backends/candle/tests/test_modernbert.rs expression: predictions_single --- -- - 2.2585099 +- - 2.13616 diff --git a/backends/candle/tests/test_modernbert.rs b/backends/candle/tests/test_modernbert.rs index e26ab61d..9419a2d9 100644 --- a/backends/candle/tests/test_modernbert.rs +++ b/backends/candle/tests/test_modernbert.rs @@ -233,4 +233,4 @@ fn test_modernbert_classification_mean_pooling() -> Result<()> { ); Ok(()) -} \ No newline at end of file +} diff --git a/backends/core/Cargo.toml b/backends/core/Cargo.toml index 754f9526..89800a85 100644 --- a/backends/core/Cargo.toml +++ b/backends/core/Cargo.toml @@ -9,6 +9,7 @@ homepage.workspace = true thiserror = { workspace = true } clap = { workspace = true, optional = true } nohash-hasher = { workspace = true } +serde = { workspace = true } [features] clap = ["dep:clap"] diff --git a/backends/core/src/lib.rs b/backends/core/src/lib.rs index 7b927787..8e134d2b 100644 --- a/backends/core/src/lib.rs +++ b/backends/core/src/lib.rs @@ -1,6 +1,7 @@ #[cfg(feature = "clap")] use clap::ValueEnum; use nohash_hasher::IntMap; +use serde::Deserialize; use std::fmt; use thiserror::Error; @@ -52,8 +53,9 @@ pub enum ModelType { Embedding(Pool), } -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Deserialize)] #[cfg_attr(feature = "clap", derive(ValueEnum))] +#[serde(rename_all = "snake_case")] pub enum Pool { /// Select the CLS token as embedding Cls, @@ -78,23 +80,6 @@ impl fmt::Display for Pool { } } -impl std::str::FromStr for Pool { - type Err = String; - - fn from_str(s: &str) -> Result { - match s.trim().to_lowercase().as_str() { - "cls" => Ok(Pool::Cls), - "mean" => Ok(Pool::Mean), - "splade" => Ok(Pool::Splade), - "last_token" => Ok(Pool::LastToken), - _ => Err(format!( - "Invalid pooling method '{}'. Valid options: cls, mean, splade, last_token", - s - )), - } - } -} - #[derive(Debug, Error, Clone)] pub enum BackendError { #[error("No backend found")]