1 file changed, 24 additions and 2 deletions.
Columns: original file line number, diff line number, diff line change.
@@ -681,9 +681,31 @@ def callback(
681681 # print(, end='', flush=True)
682682
683683 else :
684+ assert not generator_args .chat_mode
685+ buffer = [generator_args .prompt ]
686+ period_id = tokenizer .encode ("." )[0 ]
687+ done_generating = False
684688
685- def callback (x ):
686- return x
def callback(
    x, buffer=buffer, period_id=period_id, done_generating=done_generating
):
    """Stream one generated token to stdout, flushing in small batches.

    Decoded text is accumulated in ``buffer`` and printed once four pieces
    have been collected or an end-of-generation token is seen.

    Args:
        x: The newly generated token. ``.tolist()`` and ``.item()`` are
            called on it, so it is presumably a single-element tensor --
            TODO confirm against the caller.
        buffer: List of pending text pieces. The default is the list built
            in the enclosing scope and is mutated in place, so it is shared
            across calls.
        period_id: Token id of ``"."``, prepended before decoding.
        done_generating: When true, the token is ignored. Callers may pass
            it explicitly; otherwise the callback latches its own "finished"
            state after an end token.

    Returns:
        None. Output is written with ``print``.
    """
    # Assigning to the ``done_generating`` parameter only changes a local, so
    # the end-of-generation state is latched on the function object instead.
    # NOTE(review): this relies on the ``def`` being re-executed for every new
    # generation (it sits right after ``buffer`` is rebuilt) -- confirm.
    if done_generating or getattr(callback, "finished", False):
        return

    # Decode with a leading "." and drop the first character of the result.
    # Presumably this keeps the tokenizer from stripping the token's leading
    # whitespace -- TODO confirm.
    # NOTE(review): the original author suspected this drops the first output
    # token from the display; the slice removes exactly one character, which
    # is the "." only if that id decodes to a single character.
    buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])

    token_id = x.item()
    if token_id == tokenizer.eos_id():
        done_generating = True
    if is_llama3_model and token_id == tokenizer.special_tokens["<|eot_id|>"]:
        done_generating = True
        # Drop the eot_id text in place. Rebinding (``buffer = buffer[:-1]``)
        # would leave the shared list untouched and un-cleared.
        buffer.pop()

    # ``>=`` rather than ``==`` so a flush can never be skipped permanently.
    if len(buffer) >= 4 or done_generating:
        print("".join(buffer), end="", flush=True)
        buffer.clear()

    if done_generating:
        callback.finished = True
687709
688710 t0 = time .perf_counter ()
689711 import contextlib
You can’t perform that action at this time.
0 commit comments