22 changes: 12 additions & 10 deletions README.md
@@ -66,22 +66,24 @@ Ember, GTE and E5. TEI implements many features such as:
#### Text Embeddings

Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, MPNet, and ModernBERT.
model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, MPNet, ModernBERT, and Qwen3.

Below are some examples of the currently supported models:

| MTEB Rank | Model Size | Model Type | Model ID |
|-----------|---------------------|-------------|--------------------------------------------------------------------------------------------------|
| 3 | 7B (Very Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct) |
| 11 | 1.5B (Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) |
| 14 | 7B (Very Expensive) | Mistral | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R) |
| 20 | 0.3B | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) |
| 31 | 0.5B | XLM-RoBERTa | [Snowflake/snowflake-arctic-embed-l-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-l-v2.0) |
| 37 | 0.3B | Alibaba GTE | [Snowflake/snowflake-arctic-embed-m-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-m-v2.0) |
| 49 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
| 2 | 8B (Very Expensive) | Qwen3 | [Qwen/Qwen3-Embedding-8B](https://hf.co/Qwen/Qwen3-Embedding-8B) |
| 4 | 0.6B | Qwen3 | [Qwen/Qwen3-Embedding-0.6B](https://hf.co/Qwen/Qwen3-Embedding-0.6B) |
| 6 | 7B (Very Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct) |
| 7 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
| 14 | 1.5B (Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) |
| 17 | 7B (Very Expensive) | Mistral | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R) |
| 34 | 0.5B | XLM-RoBERTa | [Snowflake/snowflake-arctic-embed-l-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-l-v2.0) |
| 40 | 0.3B | Alibaba GTE | [Snowflake/snowflake-arctic-embed-m-v2.0](https://hf.co/Snowflake/snowflake-arctic-embed-m-v2.0) |
| 51 | 0.3B | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) |
| N/A | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
| N/A | 0.4B | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
| N/A | 0.3B | NomicBert | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe) |
| N/A | 0.4B | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
| N/A | 0.3B | NomicBert | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe) |
| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
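
Once a TEI instance is serving one of the models above, embeddings are obtained over the HTTP API. Below is a minimal client sketch in Rust; it assumes a server is already running on `localhost:8080` (for example with `Qwen/Qwen3-Embedding-0.6B`) and that the `reqwest` (with the `blocking` and `json` features) and `serde_json` crates are available. The `/embed` route and payload shape follow the TEI API, but check the OpenAPI docs for your version.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::blocking::Client::new();

    // POST /embed with a single input string; TEI also accepts a list of strings.
    let response = client
        .post("http://localhost:8080/embed")
        .json(&json!({ "inputs": "What is Deep Learning?" }))
        .send()?
        .error_for_status()?;

    // The response is one embedding vector per input.
    let embeddings: Vec<Vec<f32>> = response.json()?;
    println!("embedding dimension = {}", embeddings[0].len());
    Ok(())
}
```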
14 changes: 8 additions & 6 deletions backends/candle/src/lib.rs
@@ -278,7 +278,7 @@ impl CandleBackend {
(Config::Qwen3(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting Qwen3 model on {:?}", device);
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
},
}
(Config::MPNet(config), _) => {
tracing::info!("Starting MPNet model on {:?}", device);
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
@@ -457,12 +457,14 @@ impl CandleBackend {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
{
return Err(BackendError::Start("Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
tracing::info!("Starting Qwen3 model on {:?}", device);
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting FlashQwen3 model on {:?}", device);
Ok(Box::new(
FlashQwen3Model::load(vb, &config, model_type).s()?,
))
}
tracing::info!("Starting FlashQwen3 model on {:?}", device);
Ok(Box::new(
FlashQwen3Model::load(vb, &config, model_type).s()?,
))
}
};
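
The second hunk replaces the hard failure on CUDA (when the dtype is not fp16 or flash attention is not compiled in) with a fallback to the plain `Qwen3Model`. Because added and removed lines are interleaved above, here is the match arm reassembled from the added lines only; the enclosing `(Config::Qwen3(config), Device::Cuda(_))` pattern is collapsed in this view and is shown here as an assumption:

```rust
(Config::Qwen3(config), Device::Cuda(_)) => {
    if dtype != DType::F16
        || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
    {
        // Fall back to the non-flash implementation instead of returning a BackendError.
        tracing::info!("Starting Qwen3 model on {:?}", device);
        Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
    } else {
        tracing::info!("Starting FlashQwen3 model on {:?}", device);
        Ok(Box::new(
            FlashQwen3Model::load(vb, &config, model_type).s()?,
        ))
    }
}
```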

160 changes: 114 additions & 46 deletions backends/candle/src/models/qwen3.rs
@@ -1,4 +1,6 @@
use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::layers::{
apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, Linear, RMSNorm,
};
use crate::models::Model;
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
@@ -58,7 +60,10 @@ impl Qwen3Attention {
"weight",
)?;
let query_bias = if config.attention_bias {
Some(vb.pp("q_proj").get(num_attention_heads * attention_head_size, "bias")?)
Some(
vb.pp("q_proj")
.get(num_attention_heads * attention_head_size, "bias")?,
)
} else {
None
};
@@ -127,6 +132,8 @@ impl Qwen3Attention {
) -> Result<Tensor> {
let _enter = self.span.enter();

let device = hidden_states.device();

let q = self.q_proj.forward(hidden_states)?;
let k = self.k_proj.forward(hidden_states)?;
let v = self.v_proj.forward(hidden_states)?;
@@ -157,16 +164,13 @@ impl Qwen3Attention {
.concat(),
)?;

// Apply q_norm and k_norm
let (q, _res) = self.q_norm.forward(&q, None)?;
let (k, _res) = self.k_norm.forward(&k, None)?;

// Transpose to [batch, heads, seq_len, head_dim] for compatibility with apply_rotary
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;

// Apply rotary embeddings
let q = apply_rotary(&q, cos, sin, self.attention_head_size)?;
let k = apply_rotary(&k, cos, sin, self.attention_head_size)?;

@@ -179,6 +183,7 @@ impl Qwen3Attention {
} else {
k
};

let v = if self.num_key_value_heads != self.num_attention_heads {
let repeat_factor = self.num_attention_heads / self.num_key_value_heads;
let (b, h, s, d) = v.shape().dims4()?;
@@ -188,17 +193,67 @@ impl Qwen3Attention {
v
};

let attention_scores = q.matmul(&k.t()?)?;
let mut attention_scores = (attention_scores * self.softmax_scale)?;
#[allow(unused_variables)]
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
(device, get_cublas_lt_wrapper())
{
#[cfg(feature = "cuda")]
{
let (batch_size, _, seq_len, _) = k.shape().dims4()?;
let q = q.flatten(0, 1)?;
let k = k.flatten(0, 1)?;
let v = v.flatten(0, 1)?;
let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;

let beta = match attention_bias.is_some() {
true => Some(1.0),
false => None,
};

if let Some(attention_bias) = attention_bias {
attention_scores = attention_scores.add(attention_bias)?;
}
let attention_scores = cublaslt.batch_matmul(
&k,
&q,
attention_bias.as_ref(),
Some(self.softmax_scale as f32),
beta,
None,
None,
)?;
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;

let context_layer = cublaslt.batch_matmul(
&v.t()?.contiguous()?,
&attention_probs,
Some(&q),
None,
None,
None,
None,
)?;

context_layer.reshape((
batch_size,
self.num_attention_heads,
seq_len,
self.attention_head_size,
))
}
#[cfg(not(feature = "cuda"))]
{
candle::bail!("`cuda` feature is not enabled")
}
} else {
let attn_weights = q.matmul(&k.t()?)?;
let mut attn_weights = (attn_weights * self.softmax_scale)?;

if let Some(attention_bias) = attention_bias {
attn_weights = attn_weights.add(attention_bias)?;
}

let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let context_layer = attention_probs.matmul(&v.contiguous()?)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&v.contiguous()?)
}?;

// Transpose back and flatten: [batch, heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;

self.o_proj.forward(&context_layer)
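
The rewritten attention forward now has two branches: on CUDA with the `cuda` feature, the score and context products go through the cuBLASLt wrapper after flattening the batch and head dimensions; everywhere else it keeps the plain `matmul` + `softmax_last_dim` path. The grouped-query key/value expansion is collapsed in this view, so the sketch below shows one common way to do it together with the fallback scaled-dot-product computation; it mirrors the added fallback branch but is not necessarily the exact code hidden above, and it assumes the `candle`/`candle-nn` crates already used by this backend.

```rust
use candle::{Result, Tensor};

/// Repeat each KV head `repeat` times so K/V line up with the query heads
/// (grouped-query attention). One straightforward expansion; the collapsed
/// code in the diff may differ.
fn repeat_kv(x: &Tensor, repeat: usize) -> Result<Tensor> {
    if repeat == 1 {
        return Ok(x.clone());
    }
    let (b, h_kv, s, d) = x.shape().dims4()?;
    x.unsqueeze(2)?
        .expand((b, h_kv, repeat, s, d))?
        .contiguous()? // materialize the broadcast before reshaping
        .reshape((b, h_kv * repeat, s, d))
}

/// Plain attention path, mirroring the non-cuBLASLt fallback branch above.
fn sdpa(
    q: &Tensor,                      // (batch, heads, seq, head_dim)
    k: &Tensor,                      // same shape as q, after repeat_kv
    v: &Tensor,                      // same shape as q, after repeat_kv
    attention_bias: Option<&Tensor>, // additive mask (batch, heads, seq, seq)
    softmax_scale: f64,
) -> Result<Tensor> {
    let mut attn_weights = (q.matmul(&k.t()?)? * softmax_scale)?;
    if let Some(bias) = attention_bias {
        attn_weights = attn_weights.add(bias)?;
    }
    let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
    attn_probs.matmul(&v.contiguous()?) // (batch, heads, seq, head_dim)
}
```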
@@ -209,6 +264,7 @@ struct Qwen3MLP {
gate_up_proj: Linear,
down_proj: Linear,

activation: HiddenAct,
intermediate_size: usize,

span: tracing::Span,
@@ -237,6 +293,7 @@ impl Qwen3MLP {
Ok(Self {
gate_up_proj,
down_proj,
activation: config.hidden_act.clone(),
intermediate_size,
span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
@@ -247,9 +304,10 @@

let gate_up_states = self.gate_up_proj.forward(hidden_states)?;
let gate_states = gate_up_states.narrow(D::Minus1, 0, self.intermediate_size)?;
let up_states = gate_up_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
let up_states =
gate_up_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;

let gate_states = gate_states.silu()?;
let gate_states = self.activation.forward(&gate_states)?;
self.down_proj.forward(&(gate_states * up_states)?)
}
}
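
The MLP change keeps the fused `gate_up_proj` layout (the first `intermediate_size` columns are the gate, the rest the up projection) but applies the activation named in the config (`config.hidden_act`) instead of hard-coding SiLU. For Qwen3 checkpoints that activation is SiLU; as a quick reference, SiLU on a scalar is `x * sigmoid(x)`:

```rust
// SiLU (a.k.a. swish), shown on plain scalars for reference only;
// the model applies it elementwise to the gate half of gate_up_proj.
fn silu(x: f32) -> f32 {
    x * (1.0 / (1.0 + (-x).exp()))
}

fn main() {
    // silu(1.0) ≈ 0.7311, silu(-1.0) ≈ -0.2689
    println!("{:.4} {:.4}", silu(1.0), silu(-1.0));
}
```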
@@ -298,11 +356,14 @@ impl Qwen3Layer {
let _enter = self.span.enter();

let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, None)?;
let attn_output = self.attention.forward(&normed_hidden_states, attention_bias, cos, sin)?;
let (normed_attn_res_output, attn_res) = self.post_attention_layer_norm.forward(&attn_output, Some(&res))?;
let attn_output =
self.attention
.forward(&normed_hidden_states, attention_bias, cos, sin)?;
let (normed_attn_res_output, attn_res) = self
.post_attention_layer_norm
.forward(&attn_output, Some(&res))?;
let mlp_output = self.mlp.forward(&normed_attn_res_output)?;

// Add residual connections
let output = (&mlp_output + &attn_res)?;
Ok(output)
}
@@ -330,7 +391,8 @@ impl Qwen3Model {
ModelType::Embedding(pool) => pool,
};

// Handle potential "model" prefix for reranker models
// The Qwen3-Reranker models contain the `model` key
// https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
let vb = if vb.contains_tensor("model.embed_tokens.weight") {
vb.pp("model")
} else {
@@ -353,19 +415,10 @@
.head_dim
.unwrap_or(config.hidden_size / config.num_attention_heads);

let inv_freqs = get_inv_freqs(
rotary_dim,
config.rope_theta,
vb.device(),
None,
)?;
let inv_freqs = get_inv_freqs(rotary_dim, config.rope_theta, vb.device(), None)?;

let rotary_cache = get_cos_sin(
config.max_position_embeddings,
&inv_freqs,
vb.dtype(),
true,
)?;
let rotary_cache =
get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

Ok(Self {
embeddings,
@@ -452,20 +505,39 @@ impl Qwen3Model {
None
};

(input_ids, position_ids, input_lengths, attention_bias, attention_mask)
(
input_ids,
position_ids,
input_lengths,
attention_bias,
attention_mask,
)
} else {
// Single sequence
let input_ids = Tensor::from_vec(batch.input_ids.clone(), (1, batch.input_ids.len()), &self.device)?;
let position_ids = Tensor::from_vec(batch.position_ids.clone(), (1, batch.position_ids.len()), &self.device)?;
let input_ids = Tensor::from_vec(
batch.input_ids.clone(),
(1, batch.input_ids.len()),
&self.device,
)?;
let position_ids = Tensor::from_vec(
batch.position_ids.clone(),
(1, batch.position_ids.len()),
&self.device,
)?;
let input_lengths = vec![batch.input_ids.len()];

(input_ids, position_ids, input_lengths, None, None)
};

let mut hidden_states = self.embeddings.forward(&input_ids)?;

let cos = self.rotary_cache.0.index_select(&position_ids.flatten_all()?, 0)?;
let sin = self.rotary_cache.1.index_select(&position_ids.flatten_all()?, 0)?;
let cos = self
.rotary_cache
.0
.index_select(&position_ids.flatten_all()?, 0)?;
let sin = self
.rotary_cache
.1
.index_select(&position_ids.flatten_all()?, 0)?;

let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
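
The `cos`/`sin` tables are built once in `Qwen3Model::load` (`get_inv_freqs` + `get_cos_sin` over `max_position_embeddings`) and gathered per token here with `index_select`, then reshaped to `(batch, 1, max_length, rotary_dim)`. For reference, here is a dependency-free sketch of the standard RoPE table construction — TEI's helpers may differ in details such as dtype or whether the half-dimension is duplicated:

```rust
/// inv_freq[i] = 1 / theta^(2i / rotary_dim) for i in 0..rotary_dim/2;
/// cos/sin tables are indexed by absolute position (illustrative only).
fn rope_tables(
    max_positions: usize,
    rotary_dim: usize,
    theta: f32,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let half = rotary_dim / 2;
    let inv_freqs: Vec<f32> = (0..half)
        .map(|i| 1.0 / theta.powf(2.0 * i as f32 / rotary_dim as f32))
        .collect();
    let (mut cos, mut sin) = (Vec::new(), Vec::new());
    for pos in 0..max_positions {
        let angles: Vec<f32> = inv_freqs.iter().map(|f| pos as f32 * f).collect();
        cos.push(angles.iter().map(|a| a.cos()).collect());
        sin.push(angles.iter().map(|a| a.sin()).collect());
    }
    (cos, sin)
}

fn main() {
    let (cos, sin) = rope_tables(4, 8, 10_000.0);
    assert_eq!((cos.len(), cos[0].len(), sin[0].len()), (4, 4, 4));
}
```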
@@ -490,7 +562,11 @@
)?;

let cls_indices = if has_raw_requests {
Tensor::zeros(batch.pooled_indices.len(), candle::DType::U32, &self.device)?
Tensor::zeros(
batch.pooled_indices.len(),
candle::DType::U32,
&self.device,
)?
} else {
Tensor::arange(0u32, batch_size as u32, &self.device)?
};
@@ -548,14 +624,6 @@ impl Qwen3Model {
None
};

// Note: L2 normalization removed to match flash implementation behavior
// let pooled_embeddings = if let Some(embeddings) = pooled_embeddings {
// let norm = embeddings.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
// Some(embeddings.broadcast_div(&norm)?)
// } else {
// None
// };

let raw_embeddings = if has_raw_requests {
if batch_size > 1 && has_pooling_requests {
let mut final_embeddings = Vec::new();
2 changes: 1 addition & 1 deletion backends/candle/tests/test_qwen3.rs
@@ -48,4 +48,4 @@ fn test_qwen3() -> Result<()> {
assert_eq!(embeddings_batch[2], embeddings_single[0]);

Ok(())
}
}