
Commit c3972b6

Add basic tests. Closes ggml-org#24
1 parent 51dbcf2 commit c3972b6

File tree

3 files changed: +167 -1 lines changed


poetry.lock

Lines changed: 87 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ twine = "^4.0.2"
 mkdocs = "^1.4.2"
 mkdocstrings = {extras = ["python"], version = "^0.20.0"}
 mkdocs-material = "^9.1.4"
+pytest = "^7.2.2"
 
 [build-system]
 requires = [
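
pytest is added here as a development dependency, so the new suite can presumably be run from the repository root with poetry run pytest (or plain pytest once the environment is installed). The commit itself does not include run instructions or CI wiring, so treat that invocation as an assumption.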

tests/test_llama.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import llama_cpp
+
+MODEL = "./vendor/llama.cpp/models/ggml-vocab.bin"
+
+
+def test_llama():
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+
+    assert llama
+    assert llama.ctx is not None
+
+    text = b"Hello World"
+
+    assert llama.detokenize(llama.tokenize(text)) == text
+
+
+def test_llama_patch(monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+
+    ## Set up mock function
+    def mock_eval(*args, **kwargs):
+        return 0
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+
+    output_text = " jumps over the lazy dog."
+    output_tokens = llama.tokenize(output_text.encode("utf-8"))
+    token_eos = llama.token_eos()
+    n = 0
+
+    def mock_sample(*args, **kwargs):
+        nonlocal n
+        if n < len(output_tokens):
+            n += 1
+            return output_tokens[n - 1]
+        else:
+            return token_eos
+
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+
+    text = "The quick brown fox"
+
+    ## Test basic completion until eos
+    n = 0  # reset
+    completion = llama.create_completion(text, max_tokens=20)
+    assert completion["choices"][0]["text"] == output_text
+    assert completion["choices"][0]["finish_reason"] == "stop"
+
+    ## Test streaming completion until eos
+    n = 0  # reset
+    chunks = llama.create_completion(text, max_tokens=20, stream=True)
+    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
+    assert completion["choices"][0]["finish_reason"] == "stop"
+
+    ## Test basic completion until stop sequence
+    n = 0  # reset
+    completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
+    assert completion["choices"][0]["text"] == " jumps over the "
+    assert completion["choices"][0]["finish_reason"] == "stop"
+
+    ## Test streaming completion until stop sequence
+    n = 0  # reset
+    chunks = llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
+    assert (
+        "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
+    )
+    assert completion["choices"][0]["finish_reason"] == "stop"
+
+    ## Test basic completion until length
+    n = 0  # reset
+    completion = llama.create_completion(text, max_tokens=2)
+    assert completion["choices"][0]["text"] == " j"
+    assert completion["choices"][0]["finish_reason"] == "length"
+
+    ## Test streaming completion until length
+    n = 0  # reset
+    chunks = llama.create_completion(text, max_tokens=2, stream=True)
+    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " j"
+    assert completion["choices"][0]["finish_reason"] == "length"
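
For context, the monkeypatched functions make this test deterministic: llama_eval becomes a no-op and llama_sample_top_p_top_k replays the pre-tokenized " jumps over the lazy dog." sequence before falling back to the EOS token, so create_completion exercises the completion, streaming, stop-sequence, and max_tokens paths without loading real model weights. Below is a minimal sketch of invoking the new module through pytest's Python API; the -q flag and the hard-coded path are illustrative, not part of this commit.

# Minimal sketch: run the new test module programmatically via pytest's API.
# Assumes pytest is installed (it is added to the dev dependencies in this commit)
# and that the vocab-only model referenced by MODEL exists locally.
import sys

import pytest

if __name__ == "__main__":
    # pytest.main returns an exit code; -q keeps the output terse.
    sys.exit(pytest.main(["-q", "tests/test_llama.py"]))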
