
Commit 0b32bb3

Author: Mug
Add instruction mode
1 parent f1615f0 commit 0b32bb3

1 file changed (+64, -37 lines)


examples/low_level_api_chatllama_cpp.py

Lines changed: 64 additions & 37 deletions
@@ -5,24 +5,26 @@
 * Input is always echoed if on, so it should be turned off when using "input()"
 * The first antiprompt should be the userprompt like "\nUser:",
    because its added when n_predict is reached (aka generation ended prematurely)
-* n_predict can be set to -1 for unlimited length responses
+* n_predict can be set to -1 for unlimited length responses (or just a really high value)
+* It's always in interactive mode, generation ends either by reaching an antiprompt
+   or running out of n_predict.
+* Instruction mode adds its own antiprompt
 """
 import llama_cpp
 
-def toIntArray(lst):
-    return [int(i) for i in lst]
-
 # A LLaMA interactive session
 class LLaMAInteract:
     def __init__(self,
         primer: str="",
         model: str="./models/30B/ggml-model-q4_0.bin",
+        instruct: bool=False,
         n_ctx: int=1024,
         seed: int=0,
         n_threads: int=8,
         antiprompt: list[str]=[],
         input_echo: bool=True,
         n_predict: int=20,
+        n_keep: int=0,
         n_batch: int=8,
         repeat_last_n: int=64,
         top_k: int=50,
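Two constructor parameters are new in the hunk above: instruct switches the session into instruction mode (which, per the updated docstring, registers its own antiprompt), and n_keep sets how many prompt tokens are retained when the context is reset. A minimal sketch of an instruction-mode construction, not part of this commit; the model path and numeric values are placeholders:

m = LLaMAInteract(
    instruct=True,                              # new: instruction mode
    model="./models/30B/ggml-model-q4_0.bin",   # placeholder path
    n_ctx=2048,
    n_predict=256,
    n_keep=0,   # new: clamped to the full primer length when instruct=True (see the clamp below)
)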
@@ -31,17 +33,17 @@ def __init__(self,
         repeat_penalty: float=1,
     ) -> None:
         # input args
+        self.instruct = instruct
         self.n_threads = n_threads
         self.input_echo = input_echo
         self.n_predict = n_predict
+        self.n_keep = n_keep
         self.n_batch = n_batch
         self.repeat_last_n = repeat_last_n
         self.top_k=top_k
         self.top_p=top_p
         self.temp=temp
         self.repeat_penalty=repeat_penalty
-        self.n_ctx = n_ctx
-        self.seed = seed
 
         # runtime args
         self.input_consumed = 0
@@ -54,38 +56,53 @@ def __init__(self,
 
         # model load
         self.lparams = llama_cpp.llama_context_default_params()
-        self.lparams.n_ctx = self.n_ctx
-        self.lparams.seed = self.seed
+        self.lparams.n_ctx = n_ctx
+        self.lparams.seed = seed
         self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)
 
         # determine the required inference memory per token:
         tmp = [0, 1, 2, 3]
         llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)
 
         # determine newline token
-        self.llama_token_newline = (llama_cpp.llama_token * 1)()
-        llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
-        self.llama_token_newline = toIntArray(self.llama_token_newline)
+        self.llama_token_newline = self._tokenize("\n", False)
+        self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
+        self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
+
+        # add instruction as antiprompt
+        if (self.instruct):
+            self.first_antiprompt.append(self.inp_prefix)
 
         # primer feed
         if (len(primer) > 0):
-            self.input(primer)
-            self.n_keep = len(self.embd_inp)
+            self.embd_inp += self._tokenize(primer)
+
+        # break immediately if using instruct
+        self.init_break = self.instruct
+
+        # number of tokens to keep when resetting context
+        if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
+            self.n_keep = len(self.embd_inp)
 
         # create internal context
-        self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
         self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices
 
         # determine antiprompt tokens
         for i in antiprompt:
-            d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
-            n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
-            self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))
+            self.first_antiprompt.append(self._tokenize(i, False))
+
+    # tokenize a prompt
+    def _tokenize(self, prompt, bos=True):
+        _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        return _arr[:_n]
 
     # if an antiprompt is present
     def use_antiprompt(self):
         return len(self.first_antiprompt) > 0
 
+    # generate tokens
     def generate(self):
         while self.remaining_tokens > 0 or self.use_antiprompt():
             # predict
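With _tokenize returning _arr[:_n] directly, the old toIntArray helper (removed above) and the int(...) casts dropped later in this commit are unnecessary: slicing a ctypes integer array such as a llama_token buffer already yields a plain Python list of ints. A quick standalone check of that ctypes behavior, using c_int as a stand-in for llama_token:

import ctypes

arr = (ctypes.c_int * 4)(101, 102, 103, 104)   # stand-in for a llama_token buffer
tokens = arr[:3]                               # slicing returns a plain Python list

print(tokens)            # [101, 102, 103]
print(type(tokens[0]))   # <class 'int'>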
@@ -125,16 +142,16 @@ def generate(self):
                 self.repeat_penalty,
             )
             self.last_n_tokens.pop(0)
-            self.last_n_tokens.append(int(id))
+            self.last_n_tokens.append(id)
 
             # replace end of text token with newline token when in interactive mode
-            if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
+            if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
                 id = self.llama_token_newline[0]
                 # tokenize and inject first reverse prompt
                 self.embd_inp += self.first_antiprompt[0]
 
             # add it to the context
-            self.embd.append(int(id))
+            self.embd.append(id)
 
             # echo this to console
             self.output_echo = True
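last_n_tokens acts as a fixed-size sliding window over the most recent context (the TODO above notes that a deque is avoided because the antiprompt check needs slices). The pop(0)/append pattern from this hunk in isolation, with illustrative values:

n_ctx = 8
last_n_tokens = [0] * n_ctx          # window starts zero-filled, as in __init__

for token_id in [11, 22, 33, 44, 55, 66, 77, 88, 99]:
    last_n_tokens.pop(0)             # drop the oldest entry...
    last_n_tokens.append(token_id)   # ...append the newest; length stays n_ctx

print(last_n_tokens)                 # [22, 33, 44, 55, 66, 77, 88, 99]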
@@ -147,9 +164,9 @@ def generate(self):
 
             # some user input remains from prompt or interaction, forward it to processing
             while len(self.embd_inp) > self.input_consumed:
-                self.embd.append(int(self.embd_inp[self.input_consumed]))
+                self.embd.append(self.embd_inp[self.input_consumed])
                 self.last_n_tokens.pop(0)
-                self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
+                self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                 self.input_consumed += 1
                 if len(self.embd) >= self.n_batch:
                     break
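Queued tokens are forwarded into the evaluation buffer at most n_batch at a time; once the buffer reaches n_batch the loop breaks so that batch can be evaluated before more input is consumed. The same chunking pattern in isolation, with illustrative values and names:

embd_inp = list(range(20))   # stand-in for the queued token ids
input_consumed = 0
n_batch = 8

while input_consumed < len(embd_inp):
    embd = []
    while len(embd_inp) > input_consumed and len(embd) < n_batch:
        embd.append(embd_inp[input_consumed])
        input_consumed += 1
    print(f"evaluating {len(embd)} tokens: {embd}")   # batches of 8, 8, 4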
@@ -159,11 +176,17 @@ def generate(self):
             for id in self.embd:
                 yield id
 
-            # if antiprompt is present, stop
-            if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
-                for i in self.first_antiprompt:
-                    if i == self.last_n_tokens[-len(i):]:
-                        return
+            if (len(self.embd_inp) <= self.input_consumed):
+                # if antiprompt is present, stop
+                if (self.use_antiprompt()):
+                    for i in self.first_antiprompt:
+                        if i == self.last_n_tokens[-len(i):]:
+                            return
+
+                # if we are using instruction mode, and we have processed the initial prompt
+                if (self.init_break):
+                    self.init_break = False
+                    break
 
             # if end of generation
             if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
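The stop condition above is a plain list-suffix comparison: generation ends when the rolling last_n_tokens window ends with the token sequence of any antiprompt, and the new init_break flag additionally hands control back to the caller once the initial prompt has been processed in instruction mode. The suffix check on its own, with made-up token ids:

last_n_tokens = [5, 9, 9, 9, 13, 2659, 29901]   # made-up rolling window of recent tokens
first_antiprompt = [[13, 2659, 29901]]          # made-up tokens standing in for "\nUser:"

for i in first_antiprompt:
    if i == last_n_tokens[-len(i):]:
        print("antiprompt matched; stop generating")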
@@ -174,15 +197,20 @@ def generate(self):
                 self.embd_inp += self.first_antiprompt[0]
                 break
 
+    # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
             yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
 
+    # write input
     def input(self, prompt: str):
-        embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
-        self.embd_inp += toIntArray(embd_arr[:n_of_tok])
+        if (self.instruct):
+            self.embd_inp += self.inp_prefix
+        self.embd_inp += self._tokenize(prompt + "\n")
+        if (self.instruct):
+            self.embd_inp += self.inp_suffix
 
+    # write output
     def output(self):
         self.remaining_tokens = self.n_predict
         for id in self.generate():
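In instruction mode, input() now wraps every user turn in the "### Instruction:" / "### Response:" markers tokenized in __init__, so each turn reaches the model as an instruction/response pair. A string-level sketch of what gets queued (the real code appends token lists, not text):

inp_prefix = "\n\n### Instruction:\n\n"   # also registered as an antiprompt in instruct mode
inp_suffix = "\n\n### Response:\n\n"
user_turn = "Summarize the previous answer."   # illustrative user input

queued_text = inp_prefix + user_turn + "\n" + inp_suffix
print(queued_text, end="")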
@@ -193,7 +221,7 @@ def output(self):
 
 USER_NAME="User"
 AI_NAME="ChatLLaMa"
-
+
 time_now = datetime.now()
 prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
 {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
@@ -214,7 +242,7 @@ def output(self):
 {USER_NAME}:"""
 
 print("Loading model...")
-ll = LLaMAInteract(prompt,
+m = LLaMAInteract(prompt,
         model="./models/30B/ggml-model-q4_0.bin",
         n_ctx=2048,
         antiprompt=[f"\n{USER_NAME}:"],
@@ -224,12 +252,11 @@ def output(self):
         )
 print("Loaded model!")
 
-for i in ll.output():
+for i in m.output():
     print(i,end="",flush=True)
-ll.input_echo = False
+m.input_echo = False
 
-inp = lambda x: f" {x}\n"
 while True:
-    ll.input(inp(input(' ')))
-    for i in ll.output():
+    m.input(" " + input('\n> ' if m.instruct else " "))
+    for i in m.output():
         print(i,end="",flush=True)
