
Commit 97bdd26

ngxson, slaren, and compilade authored
Refactor lora adapter support (#8332)
* lora: load to device buft
* add patch tensor function
* correct tensor patch
* llama_lora_adapter_apply
* correct ggml_backend_tensor_copy
* add llm_build_mm
* fix auto merge
* update based on review comments
* add convert script
* no more transpose A
* add f16 convert
* add metadata check
* add sanity check
* fix ftype
* add requirements
* fix requirements
* fix outfile
* conversion: only allow selected models
* fix types
* cuda : do not use dmmv if the tensor does not have enough cols
* llama : lora fixes
* do not disable mmap with lora
  Co-authored-by: slaren <[email protected]>
* llm_build_lora_mm_id
* convert_lora : MoE LoRA conversion support
* convert_lora : prefer safetensors, similarly to convert_hf
* convert_hf : simplify modify_tensors for InternLM2
* convert_lora : lazy conversion
* llama : load and use alpha from LoRA adapters
* llama : use llm_build_lora_mm in most model graphs
* auto scale
* Revert "auto scale"
  This reverts commit 42415a4.
* remove redundant params
* Apply suggestions from code review
  Co-authored-by: slaren <[email protected]>
* change kv metadata
* move add_type to __init__
* convert_hf : move add_type to main()
* convert_lora : use the GGUFWriter from Model instead of overwriting it

---------

Co-authored-by: slaren <[email protected]>
Co-authored-by: Francis Couture-Harpin <[email protected]>
1 parent 4db8f60 commit 97bdd26

12 files changed: +944, -511 lines
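Several items in the squashed commit list above ("add llm_build_mm", "llm_build_lora_mm_id", "llama : use llm_build_lora_mm in most model graphs", "llama : load and use alpha from LoRA adapters") describe routing matrix multiplications in the model graphs through a LoRA-aware helper. The sketch below only illustrates that idea for a single adapter, written against the public ggml API; the function and parameter names are made up and it is not the helper added by this commit.

```cpp
#include "ggml.h"

// Illustrative only: what a LoRA-aware matmul conceptually adds on top of the
// base weight. The real helper iterates over every adapter attached to the
// context; here a single adapter is hard-coded.
static struct ggml_tensor * lora_mul_mat_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * w,       // base weight matrix
        struct ggml_tensor  * cur,     // input activations
        struct ggml_tensor  * lora_a,  // low-rank factor A
        struct ggml_tensor  * lora_b,  // low-rank factor B
        float scale,                   // user-supplied adapter scale
        float alpha,                   // alpha read from the adapter metadata
        float rank) {                  // LoRA rank
    // base projection
    struct ggml_tensor * res = ggml_mul_mat(ctx, w, cur);
    // low-rank update: B * (A * x), scaled by scale * alpha / rank
    struct ggml_tensor * ab  = ggml_mul_mat(ctx, lora_b, ggml_mul_mat(ctx, lora_a, cur));
    ab = ggml_add(ctx, res, ggml_scale(ctx, ab, scale * alpha / rank));
    return ab;
}
```

The alpha / rank factor follows the standard LoRA scaling convention; the exact scaling used in the commit may differ.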

common/common.cpp

Lines changed: 3 additions & 10 deletions
@@ -685,15 +685,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--lora") {
         CHECK_ARG
         params.lora_adapter.emplace_back(argv[i], 1.0f);
-        params.use_mmap = false;
         return true;
     }
     if (arg == "--lora-scaled") {
         CHECK_ARG
         const char* lora_adapter = argv[i];
         CHECK_ARG
         params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
-        params.use_mmap = false;
         return true;
     }
     if (arg == "--lora-base") {
@@ -2089,19 +2087,14 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
         const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
         float lora_scale = std::get<1>(params.lora_adapter[i]);
-        int err = llama_model_apply_lora_from_file(model,
-                                                   lora_adapter.c_str(),
-                                                   lora_scale,
-                                                   ((i > 0) || params.lora_base.empty())
-                                                       ? NULL
-                                                       : params.lora_base.c_str(),
-                                                   params.n_threads);
-        if (err != 0) {
+        auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
+        if (adapter == nullptr) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             llama_free(lctx);
             llama_free_model(model);
             return std::make_tuple(nullptr, nullptr);
         }
+        llama_lora_adapter_set(lctx, adapter, lora_scale);
     }

     if (params.ignore_eos) {
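Taken together, the two hunks above swap the removed one-shot llama_model_apply_lora_from_file call for a two-step flow: load the adapter once against the model, then attach it to a context with a scale. A minimal standalone sketch of that flow follows; the attach_lora wrapper is hypothetical and error handling is reduced to a message.

```cpp
#include "llama.h"
#include <cstdio>

// Hypothetical helper mirroring the new flow from the diff above:
// llama_lora_adapter_init loads the adapter tensors against the model,
// llama_lora_adapter_set attaches the adapter to a context with a scale.
static bool attach_lora(struct llama_model * model, struct llama_context * lctx,
                        const char * lora_path, float lora_scale) {
    struct llama_lora_adapter * adapter = llama_lora_adapter_init(model, lora_path);
    if (adapter == nullptr) {
        fprintf(stderr, "failed to load lora adapter from %s\n", lora_path);
        return false;
    }
    llama_lora_adapter_set(lctx, adapter, lora_scale);
    return true;
}
```

Note that --lora no longer disables mmap for the base model: the adapter is loaded into its own buffers instead of being patched into the model weights, which is what the removed params.use_mmap = false lines were compensating for.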

convert_hf_to_gguf.py

Lines changed: 12 additions & 22 deletions
@@ -2264,13 +2264,6 @@ def set_vocab(self):

         special_vocab.add_to_gguf(self.gguf_writer)

-    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
-        if n_head_kv is not None and n_head != n_head_kv:
-            n_head = n_head_kv
-        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                .swapaxes(1, 2)
-                .reshape(weights.shape))
-
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@@ -2290,26 +2283,22 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
         num_kv_heads = self.hparams["num_key_value_heads"]
-        hidden_size = self.hparams["hidden_size"]
+        n_embd = self.hparams["hidden_size"]
         q_per_kv = num_heads // num_kv_heads
-        head_dim = hidden_size // num_heads
+        head_dim = n_embd // num_heads
         num_groups = num_heads // q_per_kv

-        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
-
-        if re.match(qkv_pattern, name):
-            bid = re.findall(qkv_pattern, name)[0]
+        if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
             qkv = data_torch
-            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
-            q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+
+            qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
+            q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
+
             # The model weights of q and k equire additional reshape.
-            # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
-            q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads)
-            # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
-            k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
-            # v = rearrange(v, " o g n i -> o (g n i)").T
-            v = v.reshape((v.shape[0], -1)).T
+            q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads)
+            k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads)
+            v = v.reshape((-1, v.shape[-1]))
+
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
@@ -3585,6 +3574,7 @@ def main() -> None:
                                          small_first_shard=args.no_tensor_first_split)

         logger.info("Set model parameters")
+        model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
         model_instance.set_gguf_parameters()

         logger.info("Set model tokenizer")
