Commit db6ceb6

Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

2 parents: 9bc912b + b615fc3

File tree: 9 files changed (+91, -60 lines)

README.md (+2, -1)

@@ -187,7 +187,8 @@ Below is a short example demonstrating how to use the low-level API to tokenize
 >>> import ctypes
 >>> params = llama_cpp.llama_context_default_params()
 # use bytes for char * params
->>> ctx = llama_cpp.llama_init_from_file(b"./models/7b/ggml-model.bin", params)
+>>> model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
+>>> ctx = llama_cpp.llama_new_context_with_model(model, params)
 >>> max_tokens = params.n_ctx
 # use ctypes arrays for array params
 >>> tokens = (llama_cpp.llama_token * int(max_tokens))()
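
For context, the updated README snippet can be run end to end roughly as below. This is a minimal sketch, not part of the commit: it assumes llama_tokenize is available with the (ctx, text, tokens, n_max_tokens, add_bos) signature used by the low-level bindings of this era, and that the model path exists.

    # Sketch only: two-step init (model, then context) followed by a tokenize call.
    import llama_cpp

    params = llama_cpp.llama_context_default_params()
    model = llama_cpp.llama_load_model_from_file(b"./models/7b/ggml-model.bin", params)
    ctx = llama_cpp.llama_new_context_with_model(model, params)

    max_tokens = params.n_ctx
    tokens = (llama_cpp.llama_token * int(max_tokens))()
    # Assumption: llama_tokenize(ctx, text, tokens, n_max_tokens, add_bos) -> count
    n_tokens = llama_cpp.llama_tokenize(
        ctx, b"Q: Name the planets in the solar system? A: ", tokens, max_tokens, True
    )
    print([tokens[i] for i in range(n_tokens)])

    llama_cpp.llama_free(ctx)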

examples/low_level_api/low_level_api_chat_cpp.py (+22, -8)

@@ -24,6 +24,10 @@ class LLaMAInteract:
     def __init__(self, params: GptParams) -> None:
         # input args
         self.params = params
+        if self.params.path_session is None:
+            self.params.path_session = ""
+        if self.params.antiprompt is None:
+            self.params.antiprompt = ""
 
         if (self.params.perplexity):
             raise NotImplementedError("""************
@@ -66,7 +70,9 @@ def __init__(self, params: GptParams) -> None:
         self.lparams.use_mlock = self.params.use_mlock
         self.lparams.use_mmap = self.params.use_mmap
 
-        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.params.model.encode("utf8"), self.lparams)
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")
 
@@ -181,12 +187,12 @@ def __init__(self, params: GptParams) -> None:
 number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
 
         for i in range(len(self.embd_inp)):
-            print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+            print(f"{self.embd_inp[i]} -> '{self.token_to_str(self.embd_inp[i])}'", file=sys.stderr)
 
         if (self.params.n_keep > 0):
             print("static prompt based on n_keep: '")
             for i in range(self.params.n_keep):
-                print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                print(self.token_to_str(self.embd_inp[i]), file=sys.stderr)
             print("'", file=sys.stderr)
             print(file=sys.stderr)
 
@@ -339,7 +345,7 @@ def generate(self):
             candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
 
             # Apply penalties
-            nl_logit = logits[llama_cpp.llama_token_nl()]
+            nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
             last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
 
             _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
@@ -380,7 +386,7 @@ def generate(self):
             self.last_n_tokens.append(id)
 
             # replace end of text token with newline token when in interactive mode
-            if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
+            if (id == llama_cpp.llama_token_eos(self.ctx) and self.params.interactive and not self.params.instruct):
                 id = self.llama_token_newline[0]
             self.embd.append(id)
             if (self.use_antiprompt()):
@@ -437,7 +443,7 @@ def generate(self):
                 break
 
             # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(self.ctx):
                 if (not self.params.instruct):
                     for i in self.llama_token_eot:
                         yield i
@@ -464,10 +470,18 @@ def exit(self):
         llama_cpp.llama_free(self.ctx)
         self.set_color(util.CONSOLE_COLOR_DEFAULT)
 
+    def token_to_str(self, token_id: int) -> bytes:
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece_with_model(
+            self.model, llama_cpp.llama_token(token_id), buffer, size)
+        assert n <= size
+        return bytes(buffer[:n])
+
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
+            yield self.token_to_str(id).decode("utf8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
@@ -481,7 +495,7 @@ def input(self, prompt: str):
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
+            cur_char = self.token_to_str(id)
 
             # Add remainder of missing bytes
             if None in self.multibyte_fix:
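
The new token_to_str helper replaces the removed llama_cpp.llama_token_to_str(ctx, id) calls: llama_token_to_piece_with_model writes a token's text into a caller-supplied buffer and returns the number of bytes written. A minimal standalone sketch of the same pattern follows; it is not part of the commit, and the detokenize name and 32-byte default buffer size are illustrative assumptions.

    import ctypes
    import llama_cpp

    def detokenize(model, token_ids, buf_size=32):
        # Decode each token id into bytes using the same pattern as
        # LLaMAInteract.token_to_str above, then join and decode to text.
        out = b""
        for token_id in token_ids:
            buf = (ctypes.c_char * buf_size)()
            n = llama_cpp.llama_token_to_piece_with_model(
                model, llama_cpp.llama_token(token_id), buf, buf_size)
            assert n <= buf_size
            out += bytes(buf[:n])
        return out.decode("utf-8", errors="ignore")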

examples/low_level_api/low_level_api_llama_cpp.py (+17, -9)

@@ -1,15 +1,17 @@
-import llama_cpp
-
+import ctypes
+import os
 import multiprocessing
 
 import llama_cpp
 
 N_THREADS = multiprocessing.cpu_count()
+MODEL_PATH = os.environ.get('MODEL', "../models/7B/ggml-model.bin")
 
 prompt = b"\n\n### Instruction:\nWhat is the capital of France?\n\n### Response:\n"
 
 lparams = llama_cpp.llama_context_default_params()
-ctx = llama_cpp.llama_init_from_file(b"../models/7B/ggml-model.bin", lparams)
+model = llama_cpp.llama_load_model_from_file(MODEL_PATH.encode('utf-8'), lparams)
+ctx = llama_cpp.llama_new_context_with_model(model, lparams)
 
 # determine the required inference memory per token:
 tmp = [0, 1, 2, 3]
@@ -58,7 +60,8 @@
         llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
         for token_id in range(n_vocab)
     ])
-    candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
+    candidates_p = llama_cpp.ctypes.pointer(
+        llama_cpp.llama_token_data_array(_arr, len(_arr), False))
 
     _arr = (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data)
     llama_cpp.llama_sample_repetition_penalty(ctx, candidates_p,
@@ -68,9 +71,9 @@
         _arr,
         last_n_repeat, frequency_penalty, presence_penalty)
 
-    llama_cpp.llama_sample_top_k(ctx, candidates_p, 40)
-    llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.8)
-    llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.2)
+    llama_cpp.llama_sample_top_k(ctx, candidates_p, k=40, min_keep=1)
+    llama_cpp.llama_sample_top_p(ctx, candidates_p, p=0.8, min_keep=1)
+    llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=0.2)
    id = llama_cpp.llama_sample_token(ctx, candidates_p)
 
    last_n_tokens_data = last_n_tokens_data[1:] + [id]
@@ -86,13 +89,18 @@
         break
     if not input_noecho:
         for id in embd:
+            size = 32
+            buffer = (ctypes.c_char * size)()
+            n = llama_cpp.llama_token_to_piece_with_model(
+                model, llama_cpp.llama_token(id), buffer, size)
+            assert n <= size
             print(
-                llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
+                buffer[:n].decode('utf-8'),
                 end="",
                 flush=True,
             )
 
-    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos():
+    if len(embd) > 0 and embd[-1] == llama_cpp.llama_token_eos(ctx):
         break
 
 print()
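
The sampling calls now pass explicit keyword arguments (k=, p=, temp=, min_keep=). As a sketch only, the chain from the example could be factored into a helper like the one below; llama_n_vocab(ctx) and llama_get_logits(ctx) are assumed to be available in the low-level bindings (the example uses their results but this diff does not touch them), and the helper name is illustrative.

    import llama_cpp

    def sample_next_token(ctx, k=40, p=0.8, temp=0.2):
        # Build the candidates array from the current logits, then apply
        # top-k, top-p and temperature before sampling, mirroring the example.
        n_vocab = llama_cpp.llama_n_vocab(ctx)
        logits = llama_cpp.llama_get_logits(ctx)
        _arr = (llama_cpp.llama_token_data * n_vocab)(*[
            llama_cpp.llama_token_data(token_id, logits[token_id], 0.0)
            for token_id in range(n_vocab)
        ])
        candidates_p = llama_cpp.ctypes.pointer(
            llama_cpp.llama_token_data_array(_arr, len(_arr), False))
        llama_cpp.llama_sample_top_k(ctx, candidates_p, k=k, min_keep=1)
        llama_cpp.llama_sample_top_p(ctx, candidates_p, p=p, min_keep=1)
        llama_cpp.llama_sample_temperature(ctx, candidates_p, temp=temp)
        return llama_cpp.llama_sample_token(ctx, candidates_p)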

llama_cpp/__init__.py (+2, -0)

@@ -1,2 +1,4 @@
 from .llama_cpp import *
 from .llama import *
+
+from .version import __version__

llama_cpp/version.py (+1, -0)

@@ -0,0 +1 @@
+__version__ = "0.1.84"
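
With __version__ re-exported from llama_cpp/__init__.py (see above), the installed package version can be read directly at runtime; a quick sketch:

    import llama_cpp

    print(llama_cpp.__version__)  # e.g. "0.1.84"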

poetry.lock (+38, -38)

(Generated lock file; diff not rendered.)

pyproject.toml (+3, -3)

@@ -23,13 +23,13 @@ sse-starlette = { version = ">=1.6.1", optional = true }
 pydantic-settings = { version = ">=2.0.1", optional = true }
 
 [tool.poetry.group.dev.dependencies]
-black = "^23.7.0"
+black = "^23.9.1"
 twine = "^4.0.2"
 mkdocs = "^1.5.2"
 mkdocstrings = {extras = ["python"], version = "^0.23.0"}
-mkdocs-material = "^9.2.8"
+mkdocs-material = "^9.3.1"
 pytest = "^7.4.2"
-httpx = "^0.24.1"
+httpx = "^0.25.0"
 scikit-build = "0.17.6"
 
 [tool.poetry.extras]

setup.py (+3, -1)

@@ -5,12 +5,14 @@
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text(encoding="utf-8")
 
+exec(open('llama_cpp/version.py').read())
+
 setup(
     name="llama_cpp_python",
     description="A Python wrapper for llama.cpp",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    version="0.1.84",
+    version=__version__,
     author="Andrei Betlen",
     author_email="[email protected]",
     license="MIT",
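
The exec(open('llama_cpp/version.py').read()) line single-sources the version: executing version.py defines __version__ in setup.py's namespace, so the value passed to setup() always matches the string the package exports. A minimal sketch of the same pattern (the explicit namespace dict is an illustrative variation, not what setup.py itself does):

    # Read the single-source version string the same way setup.py does,
    # but into an explicit namespace instead of the module globals.
    ns = {}
    with open("llama_cpp/version.py") as f:
        exec(f.read(), ns)
    print(ns["__version__"])  # "0.1.84"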

tests/test_llama.py (+3, -0)

@@ -181,3 +181,6 @@ def test_llama_server():
             }
         ],
     }
+
+def test_llama_cpp_version():
+    assert llama_cpp.__version__
