Commit 9a39ccb

add lora embedding and loading (non-functional)
1 parent 8ab8d36 commit 9a39ccb

File tree

7 files changed: +85 -12 lines changed

convert_hf_to_gguf.py

Lines changed: 22 additions & 8 deletions

@@ -4185,22 +4185,36 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if self._position_offset is not None:
             data_torch = data_torch[self._position_offset:,:]
 
-        if name.endswith(".lora_A"):
-            # TODO: convert loras
-            return []
-
-        if name.endswith(".lora_B"):
-            # TODO: convert loras
-            return []
+        if name.endswith(".weight.0.lora_A") or name.endswith(".weight.0.lora_B"):
+            if name.startswith("pooler.dense"):
+                return
+
+            lora_name = self.hparams["lora_adaptations"]
+            num_loras = data_torch.size(0)
+            assert num_loras == len(lora_name)
+
+            # Split out each LoRA into its own named tensor.
+            # Remove "weight" from the name so as not to confuse quantize.
+            for i in range(num_loras):
+                data_lora = data_torch[i, :, :]
+                yield (self.map_tensor_name(name[:-16]) + name[-16:].lower().replace("weight.0.", f"<{lora_name[i]}>"), data_lora)
+            return
 
-        return super().modify_tensors(data_torch, name, bid)
+        yield from super().modify_tensors(data_torch, name, bid)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
         # jina-embeddings-v3
         if rotary_emb_base := self.hparams.get("rotary_emb_base"):
             self.gguf_writer.add_rope_freq_base(rotary_emb_base)
+        if lora_alpha := self.hparams.get("lora_alpha"):
+            self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha)
+        if lora_names := self.hparams.get("lora_adaptations"):
+            self.gguf_writer.add_array(gguf.Keys.Adapter.LORA_NAMES, lora_names)
+        if lora_prompt_prefixes := self.hparams.get("task_instructions"):
+            assert lora_names and all(lora_name in lora_prompt_prefixes for lora_name in lora_names)
+            self.gguf_writer.add_array(gguf.Keys.Adapter.LORA_PROMPT_PREFIXES, [lora_prompt_prefixes[lora_name] for lora_name in lora_names])
 
 
 @ModelBase.register("GemmaForCausalLM")
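
The split-and-rename above hinges on the suffix ".weight.0.lora_A" being exactly 16 characters, which is what the name[:-16] / name[-16:] slicing relies on. A minimal sketch of the resulting names, assuming a hypothetical input tensor and adaptation list (the real list comes from hparams["lora_adaptations"], and the real code additionally passes the base name through self.map_tensor_name()):

# Sketch of the rename performed by modify_tensors() above. The input name and
# the adaptation names are illustrative placeholders, not the model's values.
name = "encoder.layer.0.mixer.Wqkv.weight.0.lora_A"  # stacked HF LoRA tensor
lora_names = ["retrieval.query", "text-matching"]    # stand-in for hparams["lora_adaptations"]

suffix = name[-16:]  # ".weight.0.lora_A"
for lora_name in lora_names:
    # ".weight.0.lora_a" becomes ".<retrieval.query>lora_a", and so on
    print(name[:-16] + suffix.lower().replace("weight.0.", f"<{lora_name}>"))

Each slice data_torch[i, :, :] of the stacked tensor is then emitted under one of these per-adaptation names; dropping the "weight" substring keeps quantization tooling from treating the LoRA factors as ordinary weight matrices.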

gguf-py/gguf/constants.py

Lines changed: 4 additions & 2 deletions

@@ -227,8 +227,10 @@ class Tokenizer:
         MIDDLE_ID = "tokenizer.ggml.middle_token_id"
 
     class Adapter:
-        TYPE       = "adapter.type"
-        LORA_ALPHA = "adapter.lora.alpha"
+        TYPE                 = "adapter.type"
+        LORA_ALPHA           = "adapter.lora.alpha"
+        LORA_NAMES           = "adapter.lora.names"
+        LORA_PROMPT_PREFIXES = "adapter.lora.prompt_prefixes"
 
     class Clip:
         PROJECTOR_TYPE = "clip.projector_type"
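
The two new keys pair one prompt prefix with each adaptation name. As a rough sketch, the adapter metadata of a converted file could look like this (the alpha, names, and prefixes below are made-up placeholders, not jina-embeddings-v3's actual values):

# Illustrative key/value pairs written through gguf.Keys.Adapter; every value
# below is a placeholder chosen for the example.
metadata = {
    "adapter.lora.alpha": 4.0,
    "adapter.lora.names": ["retrieval.query", "text-matching"],
    "adapter.lora.prompt_prefixes": ["Represent the query: ", ""],
}
# names[i] pairs with prompt_prefixes[i]; the loader asserts equal lengths.
assert len(metadata["adapter.lora.names"]) == len(metadata["adapter.lora.prompt_prefixes"])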

src/llama-adapter.h

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ struct llama_adapter_lora {
     std::vector<ggml_backend_buffer_ptr> bufs;
 
     float alpha;
+    std::string prompt_prefix;
 
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;

src/llama-arch.cpp

Lines changed: 4 additions & 2 deletions

@@ -217,8 +217,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
     { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
 
-    { LLM_KV_ADAPTER_TYPE,       "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_TYPE,                 "adapter.type"                 },
+    { LLM_KV_ADAPTER_LORA_ALPHA,           "adapter.lora.alpha"           },
+    { LLM_KV_ADAPTER_LORA_NAMES,           "adapter.lora.names"           },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES, "adapter.lora.prompt_prefixes" },
 
     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },

src/llama-arch.h

Lines changed: 2 additions & 0 deletions

@@ -215,6 +215,8 @@ enum llm_kv {
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+    LLM_KV_ADAPTER_LORA_NAMES,
+    LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES,
 
     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,

src/llama-model.cpp

Lines changed: 48 additions & 0 deletions

@@ -1720,6 +1720,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
     ggml_backend_buffer_type_t first_moved_to_buft = nullptr;
 
+    auto add_lora_tensors = [&](const std::string & lora_name, const std::string & tensor_name) -> void {
+        std::string base_name = tensor_name.substr(0, tensor_name.size() - 6);
+
+        ggml_tensor * lora_a = ml.get_tensor_meta((base_name + "<" + lora_name + ">lora_a").c_str());
+        ggml_tensor * lora_b = ml.get_tensor_meta((base_name + "<" + lora_name + ">lora_b").c_str());
+        loras[lora_name]->ab_map[tensor_name] = llama_adapter_lora_weight(lora_a, lora_b);
+
+        ml.n_created += 2;
+    };
+
     auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
         ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
 
@@ -2246,6 +2256,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_NOMIC_BERT_MOE:
             case LLM_ARCH_JINA_BERT_V3:
                 {
+                    std::vector<std::string> lora_names;
+
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                     type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
 
@@ -2262,6 +2274,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
                     tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
 
+                    if (arch == LLM_ARCH_JINA_BERT_V3) {
+                        float lora_alpha = 1.0f;
+                        std::vector<std::string> lora_prompt_prefixes;
+
+                        ml.get_key(LLM_KV_ADAPTER_LORA_ALPHA, lora_alpha, false);
+                        ml.get_arr(LLM_KV_ADAPTER_LORA_NAMES, lora_names, false);
+                        ml.get_arr(LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES, lora_prompt_prefixes, false);
+                        GGML_ASSERT(lora_names.size() == lora_prompt_prefixes.size());
+
+                        for (size_t i = 0; i < lora_names.size(); ++i) {
+                            llama_adapter_lora * adapter = new llama_adapter_lora();
+                            std::string lora_name = lora_names[i];
+
+                            adapter->alpha = lora_alpha;
+                            adapter->prompt_prefix = lora_prompt_prefixes[i];
+                            loras[lora_name] = adapter;
+
+                            add_lora_tensors(lora_name, tok_embd->name);
+
+                            if (type_embd) {
+                                add_lora_tensors(lora_name, type_embd->name);
+                            }
+                        }
+                    }
+
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -2300,6 +2337,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             }
                         }
 
+                        if (arch == LLM_ARCH_JINA_BERT_V3) {
+                            GGML_ASSERT(layer.wqkv != nullptr);
+
+                            for (const auto & lora_name : lora_names) {
+                                add_lora_tensors(lora_name, layer.wqkv->name);
+                                add_lora_tensors(lora_name, layer.wo->name);
+                                add_lora_tensors(lora_name, layer.ffn_up->name);
+                                add_lora_tensors(lora_name, layer.ffn_down->name);
+                            }
+                        }
+
                         layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                         layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
                     }
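
The add_lora_tensors() helper reconstructs the converter's naming scheme from a base tensor's name: substr(0, size - 6) drops the trailing "weight" (six characters, keeping the dot), then the bracketed adaptation name plus "lora_a"/"lora_b" is appended. A quick sanity sketch of that string math, with an illustrative tensor and adaptation name:

# Mirrors the substr(0, size - 6) logic of add_lora_tensors(); the tensor and
# adaptation names are examples, not read from an actual model file.
tensor_name = "blk.0.attn_qkv.weight"  # e.g. what layer.wqkv->name would hold
lora_name = "retrieval.query"          # hypothetical adaptation name
base = tensor_name[:-6]                # "blk.0.attn_qkv." with the dot kept
assert base + f"<{lora_name}>lora_a" == "blk.0.attn_qkv.<retrieval.query>lora_a"

Because these tensors are fetched with get_tensor_meta() rather than through create_tensor(), the helper bumps ml.n_created by two per pair so the loader's used-tensor accounting still balances.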

src/llama-model.h

Lines changed: 4 additions & 0 deletions

@@ -7,6 +7,7 @@
 #include "llama-memory.h"
 #include "llama-vocab.h"
 
+#include <map>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -383,6 +384,9 @@ struct llama_model {
 
     llama_model_params params;
 
+    // built-in LoRAs
+    std::map<std::string, llama_adapter_lora *> loras;
+
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
