8
8
* n_predict can be set to -1 for unlimited length responses (or just a really high value)
9
9
* It's always in interactive mode; generation ends either by reaching an antiprompt
10
10
or running out of n_predict.
11
- * Instruction mode adds its own antiprompt
11
+ * Instruction mode adds its own antiprompt.
12
+ You should still feed the model a "primer" prompt that
13
+ shows it the expected format.
12
14
"""
13
15
import llama_cpp
14
16
@@ -31,6 +33,8 @@ def __init__(self,
31
33
top_p : float = 1. ,
32
34
temp : float = 1.0 ,
33
35
repeat_penalty : float = 1 ,
36
+ instruct_inp_prefix : str = "\n \n ### Instruction:\n \n " ,
37
+ instruct_inp_suffix : str = "\n \n ### Response:\n \n " ,
34
38
) -> None :
35
39
# input args
36
40
self .instruct = instruct
@@ -66,12 +70,12 @@ def __init__(self,
66
70
67
71
# determine newline token
68
72
self .llama_token_newline = self ._tokenize ("\n " , False )
69
- self .inp_prefix = self ._tokenize (" \n \n ### Instruction: \n \n " )
70
- self .inp_suffix = self ._tokenize (" \n \n ### Response: \n \n " , False )
73
+ self .inp_prefix = self ._tokenize (instruct_inp_prefix )
74
+ self .inp_suffix = self ._tokenize (instruct_inp_suffix , False )
71
75
72
76
# add instruction as antiprompt
73
77
if (self .instruct ):
74
- self .first_antiprompt .append (self .inp_prefix )
78
+ self .first_antiprompt .append (self .inp_prefix . strip () )
75
79
76
80
# primer feed
77
81
if (len (primer ) > 0 ):
@@ -117,10 +121,9 @@ def generate(self):
117
121
118
122
# insert n_left/2 tokens at the start of embd from last_n_tokens
119
123
_insert = self .last_n_tokens [
120
- - ( int (n_left / 2 ) - len (self .embd ) ):- len (self .embd )
124
+ self . n_ctx - int (n_left / 2 ) - len (self .embd ):- len (self .embd )
121
125
]
122
- self .embd [:len (_insert )] = _insert
123
- #TODO: Still untested
126
+ self .embd = _insert + self .embd
124
127
125
128
if (llama_cpp .llama_eval (
126
129
self .ctx , (llama_cpp .llama_token * len (self .embd ))(* self .embd ), len (self .embd ), self .n_past , self .n_threads
@@ -197,6 +200,12 @@ def generate(self):
197
200
self .embd_inp += self .first_antiprompt [0 ]
198
201
break
199
202
203
+ def __enter__ (self ):
204
+ return self
205
+
206
+ def __exit__ (self , type , value , tb ):
207
+ llama_cpp .llama_free (self .ctx )
208
+
200
209
# return past text
201
210
def past (self ):
202
211
for id in self .last_n_tokens [- self .n_past :]:
@@ -206,7 +215,7 @@ def past(self):
206
215
def input (self , prompt : str ):
207
216
if (self .instruct ):
208
217
self .embd_inp += self .inp_prefix
209
- self .embd_inp += self ._tokenize (prompt + " \n " )
218
+ self .embd_inp += self ._tokenize (prompt )
210
219
if (self .instruct ):
211
220
self .embd_inp += self .inp_suffix
212
221
@@ -242,21 +251,38 @@ def output(self):
242
251
{ USER_NAME } :"""
243
252
244
253
print ("Loading model..." )
245
- m = LLaMAInteract (prompt ,
254
+ with LLaMAInteract (prompt ,
246
255
model = "./models/30B/ggml-model-q4_0.bin" ,
247
256
n_ctx = 2048 ,
248
257
antiprompt = [f"\n { USER_NAME } :" ],
249
258
repeat_last_n = 256 ,
250
259
n_predict = 2048 ,
251
260
temp = 0.7 , top_p = 0.5 , top_k = 40 , repeat_penalty = 1.17647
252
- )
253
- print ("Loaded model!" )
261
+ ) as m :
262
+ print ("Loaded model!" )
254
263
255
- for i in m .output ():
256
- print (i ,end = "" ,flush = True )
257
- m .input_echo = False
258
-
259
- while True :
260
- m .input (" " + input ('\n > ' if m .instruct else " " ))
261
264
for i in m .output ():
262
- print (i ,end = "" ,flush = True )
265
+ print (i ,end = "" ,flush = True )
266
+ m .input_echo = False
267
+
268
+ def inp ():
269
+ out = ""
270
+ while (t := input ()).endswith ("\\ " ):
271
+ out += t [:- 1 ] + "\n "
272
+ return out + t + "\n "
273
+
274
+ while True :
275
+ if (m .instruct ):
276
+ print ('\n > ' , end = "" )
277
+ m .input (inp ())
278
+ else :
279
+ print (f" " , end = "" )
280
+ m .input (f" { inp ()} { AI_NAME } :" )
281
+ print (f"{ AI_NAME } : " ,end = "" )
282
+
283
+ try :
284
+ for i in m .output ():
285
+ print (i ,end = "" ,flush = True )
286
+ except KeyboardInterrupt :
287
+ print (f"\n { USER_NAME } :" ,end = "" )
288
+ m .input (f"\n { USER_NAME } :" )
0 commit comments