diff --git a/README.md b/README.md
index 49d38d8c..7712990e 100644
--- a/README.md
+++ b/README.md
@@ -66,22 +66,24 @@ Ember, GTE and E5. TEI implements many features such as:
 
 #### Text Embeddings
 
 Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
-model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, MPNet, and ModernBERT.
+model with Alibi positions and Mistral, Alibaba GTE, Qwen2, and Qwen3 models with Rope positions, MPNet, and ModernBERT.
 
 Below are some examples of the currently supported models:
 
 | MTEB Rank | Model Size          | Model Type  | Model ID                                                                                         |
 |-----------|---------------------|-------------|--------------------------------------------------------------------------------------------------|
-| 3         | 7B (Very Expensive) | Qwen2       | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct)             |
-| 11        | 1.5B (Expensive)    | Qwen2       | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)         |
-| 14        | 7B (Very Expensive) | Mistral     | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R)                       |
-| 20        | 0.3B                | Bert        | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1)                                   |
-| 31        | 0.5B                | XLM-RoBERTa | [Snowflake/snowflake-arctic-embed-l-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-l-v2.0) |
-| 37        | 0.3B                | Alibaba GTE | [Snowflake/snowflake-arctic-embed-m-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-m-v2.0) |
-| 49        | 0.5B                | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
+| 2         | 8B (Very Expensive) | Qwen3       | [Qwen/Qwen3-Embedding-8B](https://hf.co/Qwen/Qwen3-Embedding-8B)                                 |
+| 4         | 0.6B                | Qwen3       | [Qwen/Qwen3-Embedding-0.6B](https://hf.co/Qwen/Qwen3-Embedding-0.6B)                             |
+| 6         | 7B (Very Expensive) | Qwen2       | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct)             |
+| 7         | 0.5B                | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
+| 14        | 1.5B (Expensive)    | Qwen2       | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)         |
+| 17        | 7B (Very Expensive) | Mistral     | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R)                       |
+| 34        | 0.5B                | XLM-RoBERTa | [Snowflake/snowflake-arctic-embed-l-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-l-v2.0) |
+| 40        | 0.3B                | Alibaba GTE | [Snowflake/snowflake-arctic-embed-m-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-m-v2.0) |
+| 51        | 0.3B                | Bert        | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1)                                   |
 | N/A       | 0.4B                | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5)                     |
-| N/A       | 0.4B                | ModernBERT  | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large)                       |
-| N/A       | 0.3B                | NomicBert   | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe)               |
+| N/A       | 0.4B                | ModernBERT  | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large)                       |
+| N/A       | 0.3B                | NomicBert   | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe)               |
 | N/A       | 0.1B                | NomicBert   | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1)                       |
 | N/A       | 0.1B                | NomicBert   | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5)                   |
 | N/A       | 0.1B                | JinaBERT    | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en)             |
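Note: the README hunk above only documents the new models; the serving API is unchanged. For reference, a minimal client-side sketch of fetching embeddings from a TEI instance serving one of the Qwen3 models over the `/embed` route. It assumes a locally running server started with `--model-id Qwen/Qwen3-Embedding-0.6B`, the `reqwest` crate with its `blocking` and `json` features, and `serde_json`; the port and input string are illustrative:

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // POST /embed with {"inputs": ...}; TEI answers with one vector per input.
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:8080/embed")
        .json(&json!({ "inputs": "What is Deep Learning?" }))
        .send()?
        .error_for_status()?;

    let embeddings: Vec<Vec<f32>> = response.json()?;
    println!("embedding dimensions: {}", embeddings[0].len());
    Ok(())
}
```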
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 30cf5762..882cdb8a 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -278,7 +278,7 @@ impl CandleBackend {
             (Config::Qwen3(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting Qwen3 model on {:?}", device);
                 Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
-            },
+            }
             (Config::MPNet(config), _) => {
                 tracing::info!("Starting MPNet model on {:?}", device);
                 Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
@@ -457,12 +457,14 @@ impl CandleBackend {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                    tracing::info!("Starting Qwen3 model on {:?}", device);
+                    Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
+                } else {
+                    tracing::info!("Starting FlashQwen3 model on {:?}", device);
+                    Ok(Box::new(
+                        FlashQwen3Model::load(vb, &config, model_type).s()?,
+                    ))
                 }
-                tracing::info!("Starting FlashQwen3 model on {:?}", device);
-                Ok(Box::new(
-                    FlashQwen3Model::load(vb, &config, model_type).s()?,
-                ))
             }
         };
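Note: the `lib.rs` hunk above turns a hard startup failure into a fallback. Flash attention is used only when the model runs in fp16 with a flash-attn feature compiled in; every other CUDA configuration now loads the dense `Qwen3Model` instead of erroring. A condensed sketch of the decision table, using hypothetical stand-in types rather than the backend's actual API:

```rust
#[derive(Debug, PartialEq)]
enum Qwen3Kernel {
    Flash, // FlashQwen3Model: fp16 + `flash-attn` / `flash-attn-v1` feature
    Dense, // Qwen3Model: every other CUDA configuration
}

fn select_qwen3_kernel(is_f16: bool, flash_compiled: bool) -> Qwen3Kernel {
    if is_f16 && flash_compiled {
        Qwen3Kernel::Flash
    } else {
        Qwen3Kernel::Dense
    }
}

fn main() {
    assert_eq!(select_qwen3_kernel(true, true), Qwen3Kernel::Flash);
    // Previously rejected with "Qwen3 is only supported on Cuda devices in
    // fp16 with flash attention v2 enabled"; now served by the dense path.
    assert_eq!(select_qwen3_kernel(false, true), Qwen3Kernel::Dense);
}
```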
diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs
index be8a421e..014976df 100644
--- a/backends/candle/src/models/qwen3.rs
+++ b/backends/candle/src/models/qwen3.rs
@@ -1,4 +1,6 @@
-use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
+use crate::layers::{
+    apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
+};
 use crate::models::Model;
 use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
@@ -58,7 +60,10 @@ impl Qwen3Attention {
             "weight",
         )?;
         let query_bias = if config.attention_bias {
-            Some(vb.pp("q_proj").get(num_attention_heads * attention_head_size, "bias")?)
+            Some(
+                vb.pp("q_proj")
+                    .get(num_attention_heads * attention_head_size, "bias")?,
+            )
         } else {
             None
         };
@@ -127,6 +132,8 @@ impl Qwen3Attention {
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
 
+        let device = hidden_states.device();
+
         let q = self.q_proj.forward(hidden_states)?;
         let k = self.k_proj.forward(hidden_states)?;
         let v = self.v_proj.forward(hidden_states)?;
@@ -157,16 +164,13 @@ impl Qwen3Attention {
             .concat(),
         )?;
 
-        // Apply q_norm and k_norm
         let (q, _res) = self.q_norm.forward(&q, None)?;
         let (k, _res) = self.k_norm.forward(&k, None)?;
 
-        // Transpose to [batch, heads, seq_len, head_dim] for compatibility with apply_rotary
         let q = q.transpose(1, 2)?;
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
 
-        // Apply rotary embeddings
         let q = apply_rotary(&q, cos, sin, self.attention_head_size)?;
         let k = apply_rotary(&k, cos, sin, self.attention_head_size)?;
@@ -179,6 +183,7 @@
         } else {
             k
         };
+
         let v = if self.num_key_value_heads != self.num_attention_heads {
             let repeat_factor = self.num_attention_heads / self.num_key_value_heads;
             let (b, h, s, d) = v.shape().dims4()?;
@@ -188,17 +193,67 @@ impl Qwen3Attention {
             v
         };
 
-        let attention_scores = q.matmul(&k.t()?)?;
-        let mut attention_scores = (attention_scores * self.softmax_scale)?;
+        #[allow(unused_variables)]
+        let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
+            (device, get_cublas_lt_wrapper())
+        {
+            #[cfg(feature = "cuda")]
+            {
+                let (batch_size, _, seq_len, _) = k.shape().dims4()?;
+                let q = q.flatten(0, 1)?;
+                let k = k.flatten(0, 1)?;
+                let v = v.flatten(0, 1)?;
+                let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;
+
+                let beta = match attention_bias.is_some() {
+                    true => Some(1.0),
+                    false => None,
+                };
-        if let Some(attention_bias) = attention_bias {
-            attention_scores = attention_scores.add(attention_bias)?;
-        }
+
+                let attention_scores = cublaslt.batch_matmul(
+                    &k,
+                    &q,
+                    attention_bias.as_ref(),
+                    Some(self.softmax_scale as f32),
+                    beta,
+                    None,
+                    None,
+                )?;
+                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                let context_layer = cublaslt.batch_matmul(
+                    &v.t()?.contiguous()?,
+                    &attention_probs,
+                    Some(&q),
+                    None,
+                    None,
+                    None,
+                    None,
+                )?;
+
+                context_layer.reshape((
+                    batch_size,
+                    self.num_attention_heads,
+                    seq_len,
+                    self.attention_head_size,
+                ))
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                candle::bail!("`cuda` feature is not enabled")
+            }
+        } else {
+            let attn_weights = q.matmul(&k.t()?)?;
+            let mut attn_weights = (attn_weights * self.softmax_scale)?;
+
+            if let Some(attention_bias) = attention_bias {
+                attn_weights = attn_weights.add(attention_bias)?;
+            }
-        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
-        let context_layer = attention_probs.matmul(&v.contiguous()?)?;
+
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            attn_weights.matmul(&v.contiguous()?)
+        }?;
 
-        // Transpose back and flatten: [batch, heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
         let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
 
         self.o_proj.forward(&context_layer)
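Context for the attention hunk above: Qwen3 uses grouped-query attention, so `k` and `v` carry only `num_key_value_heads` heads and must be expanded to `num_attention_heads` before the matmuls. A standalone sketch of that expansion, assuming the `candle-core` crate (the in-tree code follows the same unsqueeze/expand/reshape shape manipulation; `repeat_kv` is an illustrative name, not a function from the PR):

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Duplicate each KV head `repeat_factor` times:
/// [batch, kv_heads, seq, head_dim] -> [batch, kv_heads * repeat_factor, seq, head_dim].
fn repeat_kv(x: Tensor, repeat_factor: usize) -> Result<Tensor> {
    let (b, h, s, d) = x.shape().dims4()?;
    x.unsqueeze(2)?
        .expand((b, h, repeat_factor, s, d))?
        .contiguous()? // materialize the broadcast before folding it into the head axis
        .reshape((b, h * repeat_factor, s, d))
}

fn main() -> Result<()> {
    // e.g. 8 query heads attending over 4 KV heads => repeat_factor = 2
    let v = Tensor::zeros((2, 4, 7, 16), DType::F32, &Device::Cpu)?;
    let v = repeat_kv(v, 8 / 4)?;
    assert_eq!(v.shape().dims4()?, (2, 8, 7, 16));
    Ok(())
}
```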
@@ -209,6 +264,7 @@
 struct Qwen3MLP {
     gate_up_proj: Linear,
     down_proj: Linear,
+    activation: HiddenAct,
     intermediate_size: usize,
 
     span: tracing::Span,
@@ -237,6 +293,7 @@
         Ok(Self {
             gate_up_proj,
             down_proj,
+            activation: config.hidden_act.clone(),
             intermediate_size,
             span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
@@ -247,9 +304,10 @@
         let gate_up_states = self.gate_up_proj.forward(hidden_states)?;
         let gate_states = gate_up_states.narrow(D::Minus1, 0, self.intermediate_size)?;
-        let up_states = gate_up_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
+        let up_states =
+            gate_up_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
 
-        let gate_states = gate_states.silu()?;
+        let gate_states = self.activation.forward(&gate_states)?;
         self.down_proj.forward(&(gate_states * up_states)?)
     }
 }
@@ -298,11 +356,14 @@
         let _enter = self.span.enter();
 
         let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, None)?;
-        let attn_output = self.attention.forward(&normed_hidden_states, attention_bias, cos, sin)?;
-        let (normed_attn_res_output, attn_res) = self.post_attention_layer_norm.forward(&attn_output, Some(&res))?;
+        let attn_output =
+            self.attention
+                .forward(&normed_hidden_states, attention_bias, cos, sin)?;
+        let (normed_attn_res_output, attn_res) = self
+            .post_attention_layer_norm
+            .forward(&attn_output, Some(&res))?;
         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
 
-        // Add residual connections
         let output = (&mlp_output + &attn_res)?;
         Ok(output)
     }
@@ -330,7 +391,8 @@
             ModelType::Embedding(pool) => pool,
         };
 
-        // Handle potential "model" prefix for reranker models
+        // The Qwen3-Reranker models contain the `model` key
+        // https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
         let vb = if vb.contains_tensor("model.embed_tokens.weight") {
             vb.pp("model")
         } else {
@@ -353,19 +415,10 @@
             .head_dim
             .unwrap_or(config.hidden_size / config.num_attention_heads);
 
-        let inv_freqs = get_inv_freqs(
-            rotary_dim,
-            config.rope_theta,
-            vb.device(),
-            None,
-        )?;
+        let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?;
 
-        let rotary_cache = get_cos_sin(
-            config.max_position_embeddings,
-            &inv_freqs,
-            vb.dtype(),
-            true,
-        )?;
+        let rotary_cache =
+            get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
 
         Ok(Self {
             embeddings,
@@ -452,11 +505,24 @@
                 None
             };
 
-            (input_ids, position_ids, input_lengths, attention_bias, attention_mask)
+            (
+                input_ids,
+                position_ids,
+                input_lengths,
+                attention_bias,
+                attention_mask,
+            )
         } else {
-            // Single sequence
-            let input_ids = Tensor::from_vec(batch.input_ids.clone(), (1, batch.input_ids.len()), &self.device)?;
-            let position_ids = Tensor::from_vec(batch.position_ids.clone(), (1, batch.position_ids.len()), &self.device)?;
+            let input_ids = Tensor::from_vec(
+                batch.input_ids.clone(),
+                (1, batch.input_ids.len()),
+                &self.device,
+            )?;
+            let position_ids = Tensor::from_vec(
+                batch.position_ids.clone(),
+                (1, batch.position_ids.len()),
+                &self.device,
+            )?;
             let input_lengths = vec![batch.input_ids.len()];
 
             (input_ids, position_ids, input_lengths, None, None)
@@ -464,8 +530,14 @@
         let mut hidden_states = self.embeddings.forward(&input_ids)?;
 
-        let cos = self.rotary_cache.0.index_select(&position_ids.flatten_all()?, 0)?;
-        let sin = self.rotary_cache.1.index_select(&position_ids.flatten_all()?, 0)?;
+        let cos = self
+            .rotary_cache
+            .0
+            .index_select(&position_ids.flatten_all()?, 0)?;
+        let sin = self
+            .rotary_cache
+            .1
+            .index_select(&position_ids.flatten_all()?, 0)?;
 
         let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
         let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
@@ -490,7 +562,11 @@
         )?;
 
         let cls_indices = if has_raw_requests {
-            Tensor::zeros(batch.pooled_indices.len(), candle::DType::U32, &self.device)?
+            Tensor::zeros(
+                batch.pooled_indices.len(),
+                candle::DType::U32,
+                &self.device,
+            )?
         } else {
             Tensor::arange(0u32, batch_size as u32, &self.device)?
         };
@@ -548,14 +624,6 @@
             None
         };
 
-        // Note: L2 normalization removed to match flash implementation behavior
-        // let pooled_embeddings = if let Some(embeddings) = pooled_embeddings {
-        //     let norm = embeddings.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
-        //     Some(embeddings.broadcast_div(&norm)?)
-        // } else {
-        //     None
-        // };
-
         let raw_embeddings = if has_raw_requests {
             if batch_size > 1 && has_pooling_requests {
                 let mut final_embeddings = Vec::new();
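Two behavioral notes on the `qwen3.rs` hunks above: the MLP now applies the activation from `config.hidden_act` through `HiddenAct::forward` instead of hard-coding SiLU, and the commented-out L2 normalization is deleted outright, so pooled embeddings leave the model un-normalized, matching the flash implementation. A minimal sketch of the gated-MLP step, with `silu` standing in for the configured activation (assumes `candle-core` only; this is not the PR's actual `Linear`/`HiddenAct` plumbing):

```rust
use candle_core::{Device, Result, Tensor, D};

/// Split the fused gate_up projection output in half along the last axis,
/// activate the gate half, and multiply by the up half; the caller would
/// then feed the product through down_proj.
fn gated_mlp(gate_up_states: &Tensor, intermediate_size: usize) -> Result<Tensor> {
    let gate = gate_up_states.narrow(D::Minus1, 0, intermediate_size)?;
    let up = gate_up_states.narrow(D::Minus1, intermediate_size, intermediate_size)?;
    gate.silu()? * up
}

fn main() -> Result<()> {
    // Fused projection width is 2 * intermediate_size (here 2 * 3).
    let x = Tensor::randn(0f32, 1f32, (1, 4, 2 * 3), &Device::Cpu)?;
    let y = gated_mlp(&x, 3)?;
    assert_eq!(y.dims(), &[1, 4, 3]);
    Ok(())
}
```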
diff --git a/backends/candle/tests/test_qwen3.rs b/backends/candle/tests/test_qwen3.rs
index e425d667..1f173886 100644
--- a/backends/candle/tests/test_qwen3.rs
+++ b/backends/candle/tests/test_qwen3.rs
@@ -48,4 +48,4 @@ fn test_qwen3() -> Result<()> {
     assert_eq!(embeddings_batch[2], embeddings_single[0]);
 
     Ok(())
-}
\ No newline at end of file
+}
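The test hunk above is whitespace-only (it adds the missing trailing newline). One practical consequence of the removed model-level L2 normalization: if a consumer needs unit-length vectors and no other layer provides them (TEI's HTTP layer can normalize on request), the operation has to happen client-side. A dependency-free sketch, not part of the PR:

```rust
/// Scale a vector to unit L2 norm in place; leaves the zero vector untouched.
fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}

fn main() {
    let mut e = vec![3.0f32, 4.0];
    l2_normalize(&mut e);
    assert!((e[0] - 0.6).abs() < 1e-6);
    assert!((e[1] - 0.8).abs() < 1e-6);
}
```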