
Commit b83cc3f

JoanFM and ggerganov authored
llama : add Jina Embeddings architecture (#6826)
* feat: first things to do
* feat: create tensors for Jina architecture
* fix: use other tensors
* feat: embedding gets results
* fix: fix usage of ALIBI
* fix: clean prints
* fix: do some cleanup unused vars
* fix: revert changes to Makefile and CMakeLists
* fix: revert some changes
* fix: fix small detail
* fix: fix convert formatting
* fix: fix linting and editor
* feat: set proper vocab settings
* fix: JinaBertForMaskedLM registration
* feat: support q_normalization and k_normalization in Jina arch
* feat: handle gpt2 tokenizer with Jina architecture
* feat: example comments in embedding
* feat: rename Jina Bert to Jina Bert V2
* fix: add some changes as per review
* feat: proper KQ_pos for Jina embeddings
* feat: add capacity to load models ES and DE for Spanish
* llama : fix pre-tokenizers
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* minor : clean-up
* embedding : add warning about missing SEP

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 9cb317f commit b83cc3f

File tree

6 files changed: +236 -41 lines changed


convert-hf-to-gguf-update.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -74,6 +74,9 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
     {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
     {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
+    {"name": "jina-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
+    {"name": "jina-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
+    {"name": "jina-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
 ]

 # make directory "models/tokenizers" if it doesn't exist
```
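For context on how the hash checks in convert-hf-to-gguf.py below are obtained: the update script downloads each listed tokenizer and fingerprints its pre-tokenization behaviour. A minimal sketch of that idea, assuming the same sha256-over-encoded-check-text scheme the script uses; the check text and the local path here are placeholders, not the real values:

```python
# Sketch only: fingerprint a tokenizer's pre-tokenization behaviour,
# in the spirit of convert-hf-to-gguf-update.py. The check text below is a
# stand-in; the real script defines its own, much longer chktxt.
from hashlib import sha256
from transformers import AutoTokenizer  # assumed available, as in the update script

chktxt = "Hello World!\n\n 123 4567 \u00e4\u00f6\u00fc \u6771\u4eac"  # placeholder check text

def tokenizer_fingerprint(model_dir: str) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    chktok = tokenizer.encode(chktxt)                 # ids depend on the pre-tokenizer rules
    return sha256(str(chktok).encode()).hexdigest()   # stable hash of that behaviour

# e.g. after the update script has downloaded the tokenizer files:
print(tokenizer_fingerprint("models/tokenizers/jina-en"))
```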

convert-hf-to-gguf.py

Lines changed: 47 additions & 1 deletion

```diff
@@ -404,8 +404,17 @@ def get_vocab_base_pre(self, tokenizer) -> str:
             # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
             res = "olmo"
         if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
-            # ref: https://huggingface.co/databricks/dbrx-instruct
+            # ref: https://huggingface.co/databricks/dbrx-base
             res = "dbrx"
+        if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
+            res = "jina-en"
+        if chkhsh == "171aeeedd6fb548d418a7461d053f11b6f1f1fc9b387bd66640d28a4b9f5c643":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-es
+            res = "jina-es"
+        if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6":
+            # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de
+            res = "jina-de"

         if res is None:
             logger.warning("\n")
@@ -2289,6 +2298,43 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


+@Model.register("JinaBertModel", "JinaBertForMaskedLM")
+class JinaBertV2Model(BertModel):
+    model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.intermediate_size = self.hparams["intermediate_size"]
+
+    def get_tensors(self):
+        for name, data in super().get_tensors():
+            if 'gated_layers' in name:
+                d1 = data[:self.intermediate_size, :]
+                name1 = name.replace('gated_layers', 'gated_layers_w')
+                d2 = data[self.intermediate_size:, :]
+                name2 = name.replace('gated_layers', 'gated_layers_v')
+                yield name1, d1
+                yield name2, d2
+                continue
+
+            yield name, data
+
+    def set_vocab(self, *args, **kwargs):
+        tokenizer_class = 'BertTokenizer'
+        with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
+            tokenizer_class = json.load(f)['tokenizer_class']
+
+        if tokenizer_class == 'BertTokenizer':
+            super().set_vocab()
+        elif tokenizer_class == 'RobertaTokenizer':
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_token_type_count(2)
+        else:
+            raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+
+
 ###### CONVERSION LOGIC ######
```
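The key step in JinaBertV2Model.get_tensors() is splitting the fused gated-FFN weight: the checkpoint stores both halves stacked along dim 0, and the converter slices them apart at intermediate_size. A small standalone sketch of that slicing with a toy tensor; the dimensions and the tensor name are illustrative only:

```python
# Sketch: split a fused "gated_layers" weight of shape (2 * intermediate, hidden)
# into the two tensors emitted by JinaBertV2Model.get_tensors().
import torch

intermediate_size, hidden_size = 4, 3                    # toy dimensions
name = "encoder.layer.0.mlp.gated_layers.weight"         # illustrative HF-style name
data = torch.randn(2 * intermediate_size, hidden_size)   # two halves stacked on dim 0

d1 = data[:intermediate_size, :]   # first half  -> gated_layers_w (mapped to FFN_GATE below)
d2 = data[intermediate_size:, :]   # second half -> gated_layers_v (mapped to FFN_UP below)

print(name.replace('gated_layers', 'gated_layers_w'), tuple(d1.shape))  # (4, 3)
print(name.replace('gated_layers', 'gated_layers_v'), tuple(d2.shape))  # (4, 3)
```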

examples/embedding/embedding.cpp

Lines changed: 10 additions & 2 deletions

```diff
@@ -49,6 +49,12 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }

         float * out = output + batch.seq_id[i][0] * n_embd;
+        //TODO: I would also add a parameter here to enable normalization or not.
+        /*fprintf(stdout, "unnormalized_embedding:");
+        for (int hh = 0; hh < n_embd; hh++) {
+            fprintf(stdout, "%9.6f ", embd[hh]);
+        }
+        fprintf(stdout, "\n");*/
         llama_embd_normalize(embd, out, n_embd);
     }
 }
@@ -123,10 +129,12 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }

-    // add SEP if not present
+    // check if the last token is SEP
+    // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
         if (inp.empty() || inp.back() != llama_token_sep(model)) {
-            inp.push_back(llama_token_sep(model));
+            fprintf(stderr, "%s: warning: last token in the prompt is not SEP\n", __func__);
+            fprintf(stderr, "%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }
```
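The batch_decode() change above keeps the call to llama_embd_normalize, which is assumed here to apply plain L2 normalization so that dot products between stored embeddings behave as cosine similarities. A minimal sketch of that operation; this is a Python stand-in, not the llama.cpp implementation:

```python
# Sketch: L2-normalize an embedding vector, assumed equivalent in effect to
# llama_embd_normalize(embd, out, n_embd) in the diff above.
import numpy as np

def embd_normalize(embd: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    norm = np.linalg.norm(embd)       # sqrt(sum(x_i^2))
    return embd / max(norm, eps)      # unit length; dot product == cosine similarity

v = np.array([0.3, -1.2, 0.5, 2.0])
u = embd_normalize(v)
print(float(np.linalg.norm(u)))       # ~1.0
```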

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -118,6 +118,7 @@ class MODEL_ARCH(IntEnum):
     REFACT       = auto()
     BERT         = auto()
     NOMIC_BERT   = auto()
+    JINA_BERT_V2 = auto()
     BLOOM        = auto()
     STABLELM     = auto()
     QWEN         = auto()
@@ -195,6 +196,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.REFACT:       "refact",
     MODEL_ARCH.BERT:         "bert",
     MODEL_ARCH.NOMIC_BERT:   "nomic-bert",
+    MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
     MODEL_ARCH.BLOOM:        "bloom",
     MODEL_ARCH.STABLELM:     "stablelm",
     MODEL_ARCH.QWEN:         "qwen",
@@ -380,6 +382,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.JINA_BERT_V2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```
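These constants tie the converter to the runtime: the new enum value gets a string name, and the per-architecture tensor list declares which tensors a jina-bert-v2 GGUF file is expected to contain. A reduced sketch of that relationship, using only names visible in this diff:

```python
# Sketch: reduced stand-in for the gguf-py constants touched above, showing how
# the enum value, its architecture string, and the lookup relate.
from enum import IntEnum, auto

class MODEL_ARCH(IntEnum):
    BERT         = auto()
    NOMIC_BERT   = auto()
    JINA_BERT_V2 = auto()

MODEL_ARCH_NAMES = {
    MODEL_ARCH.BERT:         "bert",
    MODEL_ARCH.NOMIC_BERT:   "nomic-bert",
    MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
}

# this string identifies the architecture of converted Jina models
print(MODEL_ARCH_NAMES[MODEL_ARCH.JINA_BERT_V2])  # jina-bert-v2
```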

gguf-py/gguf/tensor_mapping.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -243,6 +243,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w3",     # internlm2
             "encoder.layers.{bid}.mlp.fc11",          # nomic-bert
             "model.layers.{bid}.mlp.c_fc",            # starcoder2
+            "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -269,6 +270,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.gate_proj", # plamo
             "model.layers.{bid}.feed_forward.w1",      # internlm2
             "encoder.layers.{bid}.mlp.fc12",           # nomic-bert
+            "encoder.layer.{bid}.mlp.gated_layers_w",  # jina-bert-v2
             "transformer.h.{bid}.mlp.linear_1",        # refact
         ),
@@ -303,6 +305,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.w2", # internlm2
             "encoder.layers.{bid}.mlp.fc2",       # nomic-bert
             "model.layers.{bid}.mlp.c_proj",      # starcoder2
+            "encoder.layer.{bid}.mlp.wo",         # jina-bert-v2
         ),

         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -321,13 +324,15 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.q_norm",                            # cohere
             "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
+            "encoder.layer.{bid}.attention.self.layer_norm_q"                 # jina-bert-v2
         ),

         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.k_norm",                            # cohere
             "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
+            "encoder.layer.{bid}.attention.self.layer_norm_k"                 # jina-bert-v2
         ),

         MODEL_TENSOR.ROPE_FREQS: (
@@ -338,6 +343,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.LayerNorm",       # bert
             "encoder.layers.{bid}.norm2",                 # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
+            "encoder.layer.{bid}.mlp.layernorm",          # jina-bert-v2
         ),

         MODEL_TENSOR.SSM_IN: (
```
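For reference, the {bid} placeholder in these patterns is the per-layer block index; the converter expands each pattern once per block when building its lookup table. A reduced sketch of that expansion, using only the jina-bert-v2 rows from this diff and returning a tensor-kind label rather than the final GGUF tensor name:

```python
# Sketch: expand per-block name templates and resolve an HF tensor name,
# in the spirit of gguf-py's TensorNameMap (reduced to the rows added above).
n_blocks = 4  # illustrative block count

templates = {
    "encoder.layer.{bid}.mlp.gated_layers_w":          "FFN_GATE",
    "encoder.layer.{bid}.mlp.gated_layers_v":          "FFN_UP",
    "encoder.layer.{bid}.mlp.wo":                      "FFN_DOWN",
    "encoder.layer.{bid}.mlp.layernorm":               "LAYER_OUT_NORM",
    "encoder.layer.{bid}.attention.self.layer_norm_q": "ATTN_Q_NORM",
    "encoder.layer.{bid}.attention.self.layer_norm_k": "ATTN_K_NORM",
}

# expand once per block, roughly what the mapping constructor does
lookup = {
    tmpl.format(bid=bid): (kind, bid)
    for tmpl, kind in templates.items()
    for bid in range(n_blocks)
}

print(lookup["encoder.layer.2.mlp.gated_layers_v"])  # ('FFN_UP', 2)
```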
