 * Input is always echoed if on, so it should be turned off when using "input()"
 * The first antiprompt should be the userprompt like "\nUser:",
   because it's added when n_predict is reached (aka generation ended prematurely)
- * n_predict can be set to -1 for unlimited length responses
+ * n_predict can be set to -1 for unlimited length responses (or just a really high value)
+ * It's always in interactive mode, generation ends either by reaching an antiprompt
+   or running out of n_predict.
+ * Instruction mode adds its own antiprompt
"""
import llama_cpp

-def toIntArray(lst):
-    return [int(i) for i in lst]
-
# A LLaMA interactive session
class LLaMAInteract:
    def __init__(self,
        primer: str = "",
        model: str = "./models/30B/ggml-model-q4_0.bin",
+        instruct: bool = False,
        n_ctx: int = 1024,
        seed: int = 0,
        n_threads: int = 8,
        antiprompt: list[str] = [],
        input_echo: bool = True,
        n_predict: int = 20,
+        n_keep: int = 0,
        n_batch: int = 8,
        repeat_last_n: int = 64,
        top_k: int = 50,
@@ -31,17 +33,17 @@ def __init__(self,
        repeat_penalty: float = 1,
    ) -> None:
        # input args
+        self.instruct = instruct
        self.n_threads = n_threads
        self.input_echo = input_echo
        self.n_predict = n_predict
+        self.n_keep = n_keep
        self.n_batch = n_batch
        self.repeat_last_n = repeat_last_n
        self.top_k = top_k
        self.top_p = top_p
        self.temp = temp
        self.repeat_penalty = repeat_penalty
-        self.n_ctx = n_ctx
-        self.seed = seed

        # runtime args
        self.input_consumed = 0
@@ -54,38 +56,53 @@ def __init__(self,

        # model load
        self.lparams = llama_cpp.llama_context_default_params()
-        self.lparams.n_ctx = self.n_ctx
-        self.lparams.seed = self.seed
+        self.lparams.n_ctx = n_ctx
+        self.lparams.seed = seed
        self.ctx = llama_cpp.llama_init_from_file(model.encode("utf8"), self.lparams)

        # determine the required inference memory per token:
        tmp = [0, 1, 2, 3]
        llama_cpp.llama_eval(self.ctx, (llama_cpp.c_int * len(tmp))(*tmp), len(tmp), 0, self.n_threads)

        # determine newline token
-        self.llama_token_newline = (llama_cpp.llama_token * 1)()
-        llama_cpp.llama_tokenize(self.ctx, b"\n", self.llama_token_newline, len(self.llama_token_newline), False)
-        self.llama_token_newline = toIntArray(self.llama_token_newline)
+        self.llama_token_newline = self._tokenize("\n", False)
+        self.inp_prefix = self._tokenize("\n\n### Instruction:\n\n")
+        self.inp_suffix = self._tokenize("\n\n### Response:\n\n", False)
+
+        # add instruction as antiprompt
+        if (self.instruct):
+            self.first_antiprompt.append(self.inp_prefix)

        # primer feed
        if (len(primer) > 0):
-            self.input(primer)
-            self.n_keep = len(self.embd_inp)
+            self.embd_inp += self._tokenize(primer)
+
+        # break immediately if using instruct
+        self.init_break = self.instruct
+
+        # number of tokens to keep when resetting context
+        if (self.n_keep < 0 or self.n_keep > len(self.embd_inp) or self.instruct):
+            self.n_keep = len(self.embd_inp)

        # create internal context
-        self.n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        self.n_ctx = llama_cpp.llama_n_ctx(self.ctx)
        self.last_n_tokens = [0]*self.n_ctx #TODO: deque doesnt support slices

        # determine antiprompt tokens
        for i in antiprompt:
-            d_antiprompt = (llama_cpp.llama_token * (len(i) + 1))()
-            n_antiprompt = llama_cpp.llama_tokenize(self.ctx, i.encode("utf8"), d_antiprompt, len(d_antiprompt), False)
-            self.first_antiprompt.append(toIntArray(d_antiprompt[:n_antiprompt]))
+            self.first_antiprompt.append(self._tokenize(i, False))
+
+    # tokenize a prompt
+    def _tokenize(self, prompt, bos=True):
+        _arr = (llama_cpp.llama_token * (len(prompt) + 1))()
+        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+        return _arr[:_n]

    # if an antiprompt is present
    def use_antiprompt(self):
        return len(self.first_antiprompt) > 0

+    # generate tokens
    def generate(self):
        while self.remaining_tokens > 0 or self.use_antiprompt():
            # predict
@@ -125,16 +142,16 @@ def generate(self):
                self.repeat_penalty,
            )
            self.last_n_tokens.pop(0)
-            self.last_n_tokens.append(int(id))
+            self.last_n_tokens.append(id)

            # replace end of text token with newline token when in interactive mode
-            if (id == llama_cpp.llama_token_eos() and self.use_antiprompt()):
+            if (id == llama_cpp.llama_token_eos() and self.use_antiprompt() and not self.instruct):
                id = self.llama_token_newline[0]
                # tokenize and inject first reverse prompt
                self.embd_inp += self.first_antiprompt[0]

            # add it to the context
-            self.embd.append(int(id))
+            self.embd.append(id)

            # echo this to console
            self.output_echo = True
@@ -147,9 +164,9 @@ def generate(self):

            # some user input remains from prompt or interaction, forward it to processing
            while len(self.embd_inp) > self.input_consumed:
-                self.embd.append(int(self.embd_inp[self.input_consumed]))
+                self.embd.append(self.embd_inp[self.input_consumed])
                self.last_n_tokens.pop(0)
-                self.last_n_tokens.append(int(self.embd_inp[self.input_consumed]))
+                self.last_n_tokens.append(self.embd_inp[self.input_consumed])
                self.input_consumed += 1
                if len(self.embd) >= self.n_batch:
                    break
@@ -159,11 +176,17 @@ def generate(self):
            for id in self.embd:
                yield id

-            # if antiprompt is present, stop
-            if (self.use_antiprompt() and len(self.embd_inp) <= self.input_consumed):
-                for i in self.first_antiprompt:
-                    if i == self.last_n_tokens[-len(i):]:
-                        return
+            if (len(self.embd_inp) <= self.input_consumed):
+                # if antiprompt is present, stop
+                if (self.use_antiprompt()):
+                    for i in self.first_antiprompt:
+                        if i == self.last_n_tokens[-len(i):]:
+                            return
+
+                # if we are using instruction mode, and we have processed the initial prompt
+                if (self.init_break):
+                    self.init_break = False
+                    break

            # if end of generation
            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
@@ -174,15 +197,20 @@ def generate(self):
                self.embd_inp += self.first_antiprompt[0]
                break

+    # return past text
    def past(self):
        for id in self.last_n_tokens[-self.n_past:]:
            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")

+    # write input
    def input(self, prompt: str):
-        embd_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
-        n_of_tok = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), embd_arr, len(embd_arr), True)
-        self.embd_inp += toIntArray(embd_arr[:n_of_tok])
+        if (self.instruct):
+            self.embd_inp += self.inp_prefix
+        self.embd_inp += self._tokenize(prompt + "\n")
+        if (self.instruct):
+            self.embd_inp += self.inp_suffix

+    # write output
    def output(self):
        self.remaining_tokens = self.n_predict
        for id in self.generate():
@@ -193,7 +221,7 @@ def output(self):

USER_NAME = "User"
AI_NAME = "ChatLLaMa"
-
+
time_now = datetime.now()
prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
{AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
@@ -214,7 +242,7 @@ def output(self):
{USER_NAME}:"""

print("Loading model...")
-ll = LLaMAInteract(prompt,
+m = LLaMAInteract(prompt,
    model="./models/30B/ggml-model-q4_0.bin",
    n_ctx=2048,
    antiprompt=[f"\n{USER_NAME}:"],
@@ -224,12 +252,11 @@ def output(self):
)
print("Loaded model!")

-for i in ll.output():
+for i in m.output():
    print(i, end="", flush=True)
-ll.input_echo = False
+m.input_echo = False

-inp = lambda x: f" {x}\n"
while True:
-    ll.input(inp(input(' ')))
-    for i in ll.output():
+    m.input(" " + input('\n> ' if m.instruct else " "))
+    for i in m.output():
        print(i, end="", flush=True)
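
For reference, a minimal sketch of driving LLaMAInteract in the new instruction mode, assuming only the constructor arguments and the input()/output() methods shown in this diff; the primer text, model path and generation settings below are illustrative placeholders, not values taken from the commit:

print("Loading model...")
m = LLaMAInteract("Below is an instruction that describes a task. Write a response that appropriately completes the request.",
    model="./models/30B/ggml-model-q4_0.bin",  # placeholder path, same as the example above
    instruct=True,    # enables the "### Instruction:" / "### Response:" wrapping in input()
    n_ctx=2048,
    n_predict=256,
)
print("Loaded model!")

# consume the primer; with instruct set, init_break stops generation right after it
for i in m.output():
    print(i, end="", flush=True)
m.input_echo = False

while True:
    # in instruct mode, input() surrounds the text with inp_prefix/inp_suffix
    m.input(" " + input('\n> '))
    for i in m.output():
        print(i, end="", flush=True)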