@@ -24,6 +24,10 @@ class LLaMAInteract:
     def __init__(self, params: GptParams) -> None:
         # input args
         self.params = params
+        if self.params.path_session is None:
+            self.params.path_session = ""
+        if self.params.antiprompt is None:
+            self.params.antiprompt = ""
 
         if (self.params.perplexity):
             raise NotImplementedError("""************
@@ -66,7 +70,9 @@ def __init__(self, params: GptParams) -> None:
         self.lparams.use_mlock = self.params.use_mlock
         self.lparams.use_mmap = self.params.use_mmap
 
-        self.ctx = llama_cpp.llama_init_from_file(self.params.model.encode("utf8"), self.lparams)
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.params.model.encode("utf8"), self.lparams)
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")
 
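This hunk tracks the upstream API split: model weights are loaded once with `llama_load_model_from_file`, and contexts are created over them with `llama_new_context_with_model`, which also means the model handle needs its own cleanup. A minimal sketch of the pairing, assuming the same low-level `llama_cpp` bindings this file uses (the model path is a placeholder):

```python
import llama_cpp

lparams = llama_cpp.llama_context_default_params()
model = llama_cpp.llama_load_model_from_file(b"./models/7B/ggml-model.bin", lparams)
if not model:
    raise RuntimeError("failed to load model")
ctx = llama_cpp.llama_new_context_with_model(model, lparams)
if not ctx:
    llama_cpp.llama_free_model(model)
    raise RuntimeError("failed to create context")

# ... run inference ...

llama_cpp.llama_free(ctx)          # free the context first
llama_cpp.llama_free_model(model)  # then the weights it was built on
```

Since `exit()` further down frees only the context, a matching `llama_free_model` call would be needed somewhere to release the weights as well.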
@@ -181,12 +187,12 @@ def __init__(self, params: GptParams) -> None:
 number of tokens in prompt = {len(self.embd_inp)}""", file=sys.stderr)
 
         for i in range(len(self.embd_inp)):
-            print(f"{self.embd_inp[i]} -> '{llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i])}'", file=sys.stderr)
+            print(f"{self.embd_inp[i]} -> '{self.token_to_str(self.embd_inp[i])}'", file=sys.stderr)
 
         if (self.params.n_keep > 0):
             print("static prompt based on n_keep: '")
             for i in range(self.params.n_keep):
-                print(llama_cpp.llama_token_to_str(self.ctx, self.embd_inp[i]), file=sys.stderr)
+                print(self.token_to_str(self.embd_inp[i]), file=sys.stderr)
             print("'", file=sys.stderr)
             print(file=sys.stderr)
 
@@ -339,7 +345,7 @@ def generate(self):
                 candidates_p = llama_cpp.ctypes.pointer(llama_cpp.llama_token_data_array(_arr, len(_arr), False))
 
                 # Apply penalties
-                nl_logit = logits[llama_cpp.llama_token_nl()]
+                nl_logit = logits[llama_cpp.llama_token_nl(self.ctx)]
                 last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
 
                 _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
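Context for this hunk: special-token lookups such as `llama_token_nl` (and `llama_token_eos` below) are now scoped to the context instead of being free functions. The saved `nl_logit` is the usual save/restore idiom around the penalty step, mirroring upstream `main.cpp`; a sketch of the surrounding code, assuming the `penalize_nl` flag carried by `GptParams`:

```python
nl_id = llama_cpp.llama_token_nl(self.ctx)  # token id for "\n", now read via the context
nl_logit = logits[nl_id]                    # stash the unpenalized logit
# ... apply repetition / frequency / presence penalties to candidates_p ...
if not self.params.penalize_nl:
    logits[nl_id] = nl_logit                # put the newline logit back afterwards
```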
@@ -380,7 +386,7 @@ def generate(self):
                 self.last_n_tokens.append(id)
 
                 # replace end of text token with newline token when in interactive mode
-                if (id == llama_cpp.llama_token_eos() and self.params.interactive and not self.params.instruct):
+                if (id == llama_cpp.llama_token_eos(self.ctx) and self.params.interactive and not self.params.instruct):
                     id = self.llama_token_newline[0]
                 self.embd.append(id)
                 if (self.use_antiprompt()):
@@ -437,7 +443,7 @@ def generate(self):
                         break
 
             # end of text token
-            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos():
+            if len(self.embd) > 0 and self.embd[-1] == llama_cpp.llama_token_eos(self.ctx):
                 if (not self.params.instruct):
                     for i in self.llama_token_eot:
                         yield i
@@ -464,10 +470,18 @@ def exit(self):
         llama_cpp.llama_free(self.ctx)
         self.set_color(util.CONSOLE_COLOR_DEFAULT)
 
+    def token_to_str(self, token_id: int) -> bytes:
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece_with_model(
+            self.model, llama_cpp.llama_token(token_id), buffer, size)
+        assert n <= size
+        return bytes(buffer[:n])
+
     # return past text
     def past(self):
         for id in self.last_n_tokens[-self.n_past:]:
-            yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf8", errors="ignore")
+            yield self.token_to_str(id).decode("utf8", errors="ignore")
 
     # write input
     def input(self, prompt: str):
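The new `token_to_str` helper replaces the removed `llama_token_to_str(ctx, ...)` by detokenizing through the model. The fixed 32-byte buffer is plenty for typical vocabulary pieces, but the `assert` is the failure mode if a piece ever exceeded it. A hypothetical defensive variant (not part of this diff), assuming the upstream convention that the `*_to_piece` calls return a negative count whose magnitude is the required buffer size:

```python
import ctypes
import llama_cpp

def token_to_piece(model, token_id: int, size: int = 32) -> bytes:
    buffer = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece_with_model(
        model, llama_cpp.llama_token(token_id), buffer, size)
    if n < 0:
        # Assumed upstream convention: -n is the buffer size actually needed.
        return token_to_piece(model, token_id, -n)
    return bytes(buffer[:n])
```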
@@ -481,7 +495,7 @@ def input(self, prompt: str):
     def output(self):
         self.remaining_tokens = self.params.n_predict
         for id in self.generate():
-            cur_char = llama_cpp.llama_token_to_str(self.ctx, id)
+            cur_char = self.token_to_str(id)
 
             # Add remainder of missing bytes
             if None in self.multibyte_fix:
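`token_to_str` returns raw bytes, and a single token can end partway through a multibyte UTF-8 character; that is what the `multibyte_fix` buffering in `output()` deals with. For illustration, the same behavior can be had from Python's incremental decoder; `pieces_to_text` is an illustrative name, not part of this diff:

```python
import codecs

def pieces_to_text(pieces):
    """Incrementally decode an iterable of bytes (token pieces) to str."""
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    for piece in pieces:
        text = decoder.decode(piece)  # returns "" until a full character arrives
        if text:
            yield text
```

For example, `"".join(pieces_to_text([b"\xe6", b"\x97\xa5"]))` yields `"日"`, with the lone `b"\xe6"` held back until its continuation bytes arrive.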