Skip to content

Commit 0e17618

Browse files
author
Mug
committed
Added instruction mode, fixed infinite generation, and various other fixes
1 parent 0b32bb3 commit 0e17618

File tree

1 file changed

+44
-18
lines changed

1 file changed

+44
-18
lines changed

examples/low_level_api_chatllama_cpp.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
99
* It's always in interactive mode, generation ends either by reaching an antiprompt
1010
or running out of n_predict.
11-
* Instruction mode adds its own antiprompt
11+
* Instruction mode adds its own antiprompt.
12+
You should also still be feeding the model with a "primer" prompt that
13+
shows it the expected format.
1214
"""
1315
import llama_cpp
1416

@@ -31,6 +33,8 @@ def __init__(self,
3133
top_p: float=1.,
3234
temp: float=1.0,
3335
repeat_penalty: float=1,
36+
instruct_inp_prefix: str="\n\n### Instruction:\n\n",
37+
instruct_inp_suffix: str="\n\n### Response:\n\n",
3438
) -> None:
3539
# input args
3640
self.instruct = instruct
@@ -66,12 +70,12 @@ def __init__(self,
6670

6771
# determine newline token
6872
self.llama_token_newline = self._tokenize("\n", False)
69-
self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
70-
self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
73+
self.inp_prefix = self._tokenize(instruct_inp_prefix)
74+
self.inp_suffix = self._tokenize(instruct_inp_suffix, False)
7175

7276
# add instruction as antiprompt
7377
if (self.instruct):
74-
self.first_antiprompt.append(self.inp_prefix)
78+
self.first_antiprompt.append(self.inp_prefix.strip())
7579

7680
# primer feed
7781
if (len(primer) > 0):
@@ -117,10 +121,9 @@ def generate(self):
117121

118122
# insert n_left/2 tokens at the start of embd from last_n_tokens
119123
_insert = self.last_n_tokens[
120-
-(int(n_left/2) - len(self.embd)):-len(self.embd)
124+
self.n_ctx - int(n_left/2) - len(self.embd):-len(self.embd)
121125
]
122-
self.embd[:len(_insert)] = _insert
123-
#TODO: Still untested
126+
self.embd = _insert + self.embd
124127

125128
if (llama_cpp.llama_eval(
126129
self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.n_threads
@@ -197,6 +200,12 @@ def generate(self):
197200
self.embd_inp += self.first_antiprompt[0]
198201
break
199202

203+
def __enter__(self):
204+
return self
205+
206+
def __exit__(self, type, value, tb):
207+
llama_cpp.llama_free(self.ctx)
208+
200209
# return past text
201210
def past(self):
202211
for id in self.last_n_tokens[-self.n_past:]:
@@ -206,7 +215,7 @@ def past(self):
206215
def input(self, prompt: str):
207216
if (self.instruct):
208217
self.embd_inp += self.inp_prefix
209-
self.embd_inp += self._tokenize(prompt + "\n")
218+
self.embd_inp += self._tokenize(prompt)
210219
if (self.instruct):
211220
self.embd_inp += self.inp_suffix
212221

@@ -242,21 +251,38 @@ def output(self):
242251
{USER_NAME}:"""
243252

244253
print("Loading model...")
245-
m = LLaMAInteract(prompt,
254+
with LLaMAInteract(prompt,
246255
model="./models/30B/ggml-model-q4_0.bin",
247256
n_ctx=2048,
248257
antiprompt=[f"\n{USER_NAME}:"],
249258
repeat_last_n=256,
250259
n_predict=2048,
251260
temp=0.7, top_p=0.5, top_k=40, repeat_penalty=1.17647
252-
)
253-
print("Loaded model!")
261+
) as m:
262+
print("Loaded model!")
254263

255-
for i in m.output():
256-
print(i,end="",flush=True)
257-
m.input_echo = False
258-
259-
while True:
260-
m.input(" " + input('\n> ' if m.instruct else " "))
261264
for i in m.output():
262-
print(i,end="",flush=True)
265+
print(i,end="",flush=True)
266+
m.input_echo = False
267+
268+
def inp():
269+
out = ""
270+
while (t := input()).endswith("\\"):
271+
out += t[:-1] + "\n"
272+
return out + t + "\n"
273+
274+
while True:
275+
if (m.instruct):
276+
print('\n> ', end="")
277+
m.input(inp())
278+
else:
279+
print(f" ", end="")
280+
m.input(f" {inp()}{AI_NAME}:")
281+
print(f"{AI_NAME}: ",end="")
282+
283+
try:
284+
for i in m.output():
285+
print(i,end="",flush=True)
286+
except KeyboardInterrupt:
287+
print(f"\n{USER_NAME}:",end="")
288+
m.input(f"\n{USER_NAME}:")

0 commit comments

Comments
 (0)