1 change: 1 addition & 0 deletions README.md
@@ -81,6 +81,7 @@ Below are some examples of the currently supported models:
| 49 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
| N/A | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
| N/A | 0.4B | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
| N/A | 0.3B | NomicBert | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe) |
| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
37 changes: 25 additions & 12 deletions backends/candle/src/models/flash_nomic.rs
@@ -1,6 +1,8 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
use crate::models::nomic::{
NomicBertEmbeddings, NomicBertGatedMLP, NomicBertMLP, NomicMLP, NomicMoELayer,
};
use crate::models::{Model, NomicConfig};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
@@ -25,16 +27,25 @@ impl NomicAttention {
let attention_head_size = config.n_embd / config.n_head;
let hidden_size = config.n_embd;

let qkv_weight = vb.pp("Wqkv").get(
(3 * num_attention_heads * attention_head_size, hidden_size),
"weight",
)?;
let qkv_linear = Linear::new(qkv_weight, None, None);
let qkv_dim = 3 * num_attention_heads * attention_head_size;

let qkv_weight = vb.pp("Wqkv").get((qkv_dim, hidden_size), "weight")?;
let qkv_bias = if config.qkv_proj_bias {
Some(vb.pp("Wqkv").get((qkv_dim,), "bias")?)
} else {
None
};
let qkv_linear = Linear::new(qkv_weight, qkv_bias, None);

let out_proj_weight = vb
.pp("out_proj")
.get((hidden_size, hidden_size), "weight")?;
let out_proj = Linear::new(out_proj_weight, None, None);
let out_proj_bias = if config.qkv_proj_bias {
Some(vb.pp("out_proj").get((hidden_size,), "bias")?)
} else {
None
};
let out_proj = Linear::new(out_proj_weight, out_proj_bias, None);

let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;

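A quick aside on the optional-bias pattern this hunk introduces: when `config.qkv_proj_bias` is set, a bias vector is loaded next to each projection weight and handed to `Linear::new`; otherwise the projection stays bias-free. The standalone sketch below shows what that amounts to in the simple case (a matmul plus an optional broadcast add). The shapes are illustrative, and `candle` here means candle-core as this crate aliases it; the real `Linear` in `crate::layers` may take faster backend paths.

```rust
use candle::{DType, Device, Result, Tensor};

/// y = x · Wᵀ (+ b): the simple-case behaviour of `Linear::new(weight, bias, None)`.
fn project(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
    let y = x.matmul(&weight.t()?)?;
    match bias {
        Some(b) => y.broadcast_add(b),
        None => Ok(y),
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (hidden, qkv_dim) = (8usize, 24usize); // e.g. qkv_dim = 3 * n_head * head_size
    let x = Tensor::randn(0f32, 1f32, (2, hidden), &dev)?;
    let w = Tensor::randn(0f32, 1f32, (qkv_dim, hidden), &dev)?;
    // `qkv_proj_bias = true`: a (qkv_dim,) bias is loaded alongside the weight.
    let b = Tensor::zeros(qkv_dim, DType::F32, &dev)?;
    println!("{:?}", project(&x, &w, Some(&b))?.shape()); // [2, 24]
    println!("{:?}", project(&x, &w, None)?.shape()); // [2, 24]
    Ok(())
}
```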
@@ -93,17 +104,18 @@ impl NomicAttention {

struct NomicBertBlock {
attention: NomicAttention,
mlp: NomicBertGatedMLP,
mlp: NomicMLP,
post_attention_layer_norm: LayerNorm,
output_layer_norm: LayerNorm,

span: tracing::Span,
}

impl NomicBertBlock {
pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result<Self> {
let attention = NomicAttention::load(vb.pp("attn"), config)?;
let mlp = NomicBertGatedMLP::load(vb.pp("mlp"), config)?;

let mlp = NomicMLP::load(vb.pp("mlp"), index, config)?;

let post_attention_layer_norm =
LayerNorm::load(vb.pp("norm1"), config.n_embd, config.layer_norm_epsilon)?;
Expand Down Expand Up @@ -132,6 +144,7 @@ impl NomicBertBlock {
let attn_output = self
.attention
.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;

let hidden_states = self
.post_attention_layer_norm
.forward(&hidden_states, Some(&attn_output))?;
@@ -145,13 +158,14 @@

struct NomicBertEncoder {
layers: Vec<NomicBertBlock>,

span: tracing::Span,
}

impl NomicBertEncoder {
pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
let layers = (0..config.n_layer)
.map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), config))
.map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), index, config))
.collect::<Result<Vec<_>>>()?;

let span = tracing::span!(tracing::Level::TRACE, "encoder");
@@ -170,7 +184,6 @@ impl NomicBertEncoder {

let mut hidden_states = hidden_states.clone();

// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
hidden_states = layer.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?
}
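The reason `index` now threads through `NomicBertEncoder::load` → `NomicBertBlock::load` → `NomicMLP::load` is that the v2 MoE checkpoint only swaps the feed-forward in some layers. Below is a minimal sketch of that per-layer routing; the `moe_every_n_layers` field and the `index % n == 1` rule are assumptions about `nomic.rs` (based on how nomic-embed-text-v2-moe interleaves MoE blocks), not code visible in this diff.

```rust
/// Which feed-forward variant a given transformer block would load.
#[derive(Debug, PartialEq)]
enum MlpKind {
    GatedMlp, // dense NomicBertGatedMLP
    Mlp,      // plain NomicBertMLP
    Moe,      // NomicMoELayer (mixture of experts)
}

/// Hypothetical routing rule: every `n`-th layer (offset 1) becomes MoE,
/// everything else keeps the dense (gated or plain) MLP.
fn mlp_kind_for_layer(index: usize, moe_every_n_layers: Option<usize>, gated: bool) -> MlpKind {
    match moe_every_n_layers {
        Some(n) if n > 0 && index % n == 1 => MlpKind::Moe,
        _ if gated => MlpKind::GatedMlp,
        _ => MlpKind::Mlp,
    }
}

fn main() {
    // With moe_every_n_layers = 2: even layers stay gated, odd layers become MoE.
    let kinds: Vec<_> = (0..6).map(|i| mlp_kind_for_layer(i, Some(2), true)).collect();
    println!("{kinds:?}"); // [GatedMlp, Moe, GatedMlp, Moe, GatedMlp, Moe]
}
```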