
Commit 9522284

Update llama.cpp

1 parent b4a3db3 commit 9522284

File tree

3 files changed: +72 -5 lines changed


llama_cpp/llama.py

Lines changed: 6 additions & 3 deletions
@@ -282,15 +282,18 @@ def __init__(
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

-        self.ctx = llama_cpp.llama_init_from_file(
+        self.model = llama_cpp.llama_load_model_from_file(
             self.model_path.encode("utf-8"), self.params
         )
+        assert self.model is not None
+
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.params)

         assert self.ctx is not None

         if self.lora_path:
-            if llama_cpp.llama_apply_lora_from_file(
-                self.ctx,
+            if llama_cpp.llama_model_apply_lora_from_file(
+                self.model,
                 llama_cpp.c_char_p(self.lora_path.encode("utf-8")),
                 llama_cpp.c_char_p(self.lora_base.encode("utf-8"))
                 if self.lora_base is not None
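In practice, the constructor now loads the weights once into an opaque llama_model handle, creates the inference context from that handle, and applies any LoRA adapter to the model rather than to the context. A minimal sketch of that flow against the updated bindings follows; the model and adapter paths, the thread count, and the use of the existing llama_context_default_params and llama_init_backend bindings are illustrative assumptions, not taken from this commit.

# Sketch: the new two-step load path (placeholder paths and settings).
import llama_cpp

llama_cpp.llama_init_backend()                      # one-time backend init
params = llama_cpp.llama_context_default_params()   # default context params

# 1) Load the weights into an opaque llama_model handle.
model = llama_cpp.llama_load_model_from_file(b"./models/7B/ggml-model.bin", params)
assert model is not None

# 2) Create the inference context from that model.
ctx = llama_cpp.llama_new_context_with_model(model, params)
assert ctx is not None

# 3) LoRA adapters are now applied to the model, not the context
#    (0 is the usual llama.cpp success return code).
if llama_cpp.llama_model_apply_lora_from_file(
    model, b"./lora/ggml-adapter.bin", None, 4
) != 0:
    raise RuntimeError("Failed to apply LoRA from file")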

llama_cpp/llama_cpp.py

Lines changed: 65 additions & 1 deletion
@@ -15,7 +15,7 @@
     c_size_t,
 )
 import pathlib
-from typing import List
+from typing import List, Union


 # Load the library
@@ -105,6 +105,9 @@ def _load_shared_library(lib_base_name: str):
 LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN
 LLAMA_SESSION_VERSION = c_int(1)

+# struct llama_model;
+llama_model_p = c_void_p
+
 # struct llama_context;
 llama_context_p = c_void_p

@@ -161,6 +164,7 @@ class llama_token_data_array(Structure):
 # // context pointer passed to the progress callback
 # void * progress_callback_user_data;

+
 # // Keep the booleans together to avoid misalignment during copy-by-value.
 # bool low_vram; // if true, reduce VRAM usage at the cost of performance
 # bool f16_kv; // use fp16 for KV cache
@@ -296,6 +300,41 @@ def llama_init_backend():
 _lib.llama_init_backend.restype = None


+# LLAMA_API struct llama_model * llama_load_model_from_file(
+#     const char * path_model,
+#     struct llama_context_params params);
+def llama_load_model_from_file(
+    path_model: bytes, params: llama_context_params
+) -> llama_model_p:
+    return _lib.llama_load_model_from_file(path_model, params)
+
+
+_lib.llama_load_model_from_file.argtypes = [c_char_p, llama_context_params]
+_lib.llama_load_model_from_file.restype = llama_model_p
+
+
+# LLAMA_API void llama_free_model(struct llama_model * model);
+def llama_free_model(model: llama_model_p):
+    return _lib.llama_free_model(model)
+
+
+_lib.llama_free_model.argtypes = [llama_model_p]
+_lib.llama_free_model.restype = None
+
+
+# LLAMA_API struct llama_context * llama_new_context_with_model(
+#     struct llama_model * model,
+#     struct llama_context_params params);
+def llama_new_context_with_model(
+    model: llama_model_p, params: llama_context_params
+) -> llama_context_p:
+    return _lib.llama_new_context_with_model(model, params)
+
+
+_lib.llama_new_context_with_model.argtypes = [llama_model_p, llama_context_params]
+_lib.llama_new_context_with_model.restype = llama_context_p
+
+
 # LLAMA_API int64_t llama_time_us();
 def llama_time_us() -> int:
     return _lib.llama_time_us()
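These bindings mirror the new C API: a model handle is loaded first and a context is created from it, so teardown has to release the context before the model. A small sketch of that lifecycle, assuming the llama_free and llama_context_default_params bindings defined elsewhere in llama_cpp.py; the model path is a placeholder.

# Sketch: load/teardown order implied by the split model/context API.
import llama_cpp

params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"/path/to/ggml-model.bin", params)
ctx = llama_cpp.llama_new_context_with_model(model, params)

try:
    pass  # ... run tokenization / evaluation against ctx here ...
finally:
    llama_cpp.llama_free(ctx)          # the context borrows the model: free it first
    llama_cpp.llama_free_model(model)  # then release the model weights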
@@ -376,6 +415,31 @@ def llama_apply_lora_from_file(
 _lib.llama_apply_lora_from_file.restype = c_int


+# LLAMA_API int llama_model_apply_lora_from_file(
+#     const struct llama_model * model,
+#     const char * path_lora,
+#     const char * path_base_model,
+#     int n_threads);
+def llama_model_apply_lora_from_file(
+    model: llama_model_p,
+    path_lora: Union[c_char_p, bytes],
+    path_base_model: Union[c_char_p, bytes],
+    n_threads: c_int,
+) -> int:
+    return _lib.llama_model_apply_lora_from_file(
+        model, path_lora, path_base_model, n_threads
+    )
+
+
+_lib.llama_model_apply_lora_from_file.argtypes = [
+    llama_model_p,
+    c_char_p,
+    c_char_p,
+    c_int,
+]
+_lib.llama_model_apply_lora_from_file.restype = c_int
+
+
 # Returns the number of tokens in the KV cache
 # LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
 def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
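Because the new LoRA entry point takes the model handle rather than a context, an adapter can be merged into the weights independently of any context. A hypothetical call sequence is sketched below; the adapter and base-model paths are placeholders, and passing None for path_base_model maps to a NULL pointer, i.e. the adapter is applied against the weights already loaded.

# Sketch: applying a LoRA adapter at the model level (placeholder paths).
import llama_cpp

params = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"/path/to/base-model.bin", params)

rc = llama_cpp.llama_model_apply_lora_from_file(
    model,
    b"/path/to/lora-adapter.bin",  # path_lora
    None,                          # path_base_model (NULL -> reuse loaded weights)
    4,                             # n_threads
)
if rc != 0:
    raise RuntimeError("Failed to apply LoRA adapter")

# Contexts created afterwards see the adapted weights.
ctx = llama_cpp.llama_new_context_with_model(model, params)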

vendor/llama.cpp

Lines changed: 1 addition & 1 deletion (submodule commit updated)
