/* Inference for Llama-2 Transformer model in pure C++ */
+#include <cstdint>
+#include <cstdlib>
#include <ctype.h>
+#include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <tokenizer.h>
+#include <string>
+

#ifdef DEBUG
#include <cassert>
@@ -485,27 +490,184 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.

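+// Multi-turn chat helpers: detect the model family from the tokenizer that was
+// loaded and render the matching chat template (llama2 [INST] blocks or llama3
+// header/eot special tokens) before encoding.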
+enum class ModelType {
+  unknown,
+  llama2,
+  llama3,
+};
+
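+// Infer the model family from the concrete tokenizer type: BPETokenizer
+// (sentencepiece) implies llama2, Tiktoken implies llama3.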
+ModelType get_model_type(Tokenizer* tokenizer) {
+  if (BPETokenizer* t = dynamic_cast<BPETokenizer*>(tokenizer)) {
+    return ModelType::llama2;
+  } else if (Tiktoken* t = dynamic_cast<Tiktoken*>(tokenizer)) {
+    return ModelType::llama3;
+  } else {
+    return ModelType::unknown;
+  }
+}
+
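+// Return the token id that ends an assistant turn: EOS for llama2, the
+// encoded <|eot_id|> special token for llama3.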
+uint64_t get_eot_token(Tokenizer* tokenizer) {
+  ModelType model_type = get_model_type(tokenizer);
+
+  if (model_type == ModelType::llama2) {
+    // llama2 uses EOS as EOT token
+    return tokenizer->eos_tok();
+  }
+
+  if (model_type == ModelType::llama3) {
+    auto tokens = tokenizer->encode("<|eot_id|>", 0, 0);
+    return tokens[0];
+  }
+
+  fprintf(stderr, "No chat template implementation for model type %d\n", static_cast<int>(model_type));
+  exit(EXIT_FAILURE);
+}
+
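+// Collect the (optional) system prompt and the first user prompt, either from
+// the CLI arguments or from stdin, render them with the chat template of the
+// detected model family, and return the encoded token ids.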
+std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+  char system_prompt[512];
+  char user_prompt[512];
+  char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+
+  if (cli_system_prompt != NULL) {
+    strcpy(system_prompt, cli_system_prompt);
+  } else {
+    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+  }
+
+  if (cli_user_prompt != NULL) {
+    strcpy(user_prompt, cli_user_prompt);
+  } else {
+    read_stdin("User: ", user_prompt, sizeof(user_prompt));
+  }
+
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
+
+  switch (model_type) {
+
+    case ModelType::llama2:
+      if (system_prompt[0] != '\0') {
+        snprintf(
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+            system_prompt,
+            user_prompt);
+      } else {
+        snprintf(
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "[INST] %s [/INST]",
+            user_prompt);
+      }
+
+      // We need to add the BOS token here and not in the template because the
+      // llama2 tokenizer does not pattern match special tokens.
+      tokens = tokenizer->encode(rendered_prompt, 1, 0);
+      break;
+
+    case ModelType::llama3:
+      if (system_prompt[0] != '\0') {
+        snprintf(
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            system_prompt,
+            user_prompt);
+      } else {
+        snprintf(
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            user_prompt);
+      }
+      tokens = tokenizer->encode(rendered_prompt, 0, 0);
+      break;
+
+    default:
+      fprintf(stderr, "No chat template implementation for model type %d\n", static_cast<int>(model_type));
+      exit(EXIT_FAILURE);
+  }
+
+#ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+#endif
+
+  return tokens;
+}
+
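+// Read the next user prompt from stdin, render it with the per-turn chat
+// template of the detected model family, and return the encoded token ids.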
+std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
+  char user_prompt[512];
+  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+
+  read_stdin("User: ", user_prompt, sizeof(user_prompt));
+
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;
+
+  switch (model_type) {
+
+    case ModelType::llama2:
+      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1, "[INST] %s [/INST]", user_prompt);
+
+      // We need to add the BOS token here and not in the template because the
+      // llama2 tokenizer does not pattern match special tokens.
+      tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
+      break;
+
+    case ModelType::llama3:
+      snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt) - 1,
+          "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          user_prompt);
+      tokens = tokenizer->encode(rendered_prompt, 0, 0);
+      break;
+
+    default:
+      fprintf(stderr, "No chat template implementation for model type %d\n", static_cast<int>(model_type));
+      exit(EXIT_FAILURE);
+  }
+
+#ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+#endif
+
+  return tokens;
+}
+
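+// chat() runs a simple multi-turn loop: render and encode the user's prompt,
+// feed it to the model token by token, then sample and print the assistant's
+// reply until the EOT token is produced.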
void chat(
    Transformer* transformer,
    Tokenizer* tokenizer,
    Sampler* sampler,
    const char* cli_user_prompt,
    const char* cli_system_prompt,
    int steps) {
-  // special tokens
-  const int SOS_TOKEN = tokenizer->bos_tok(); // token starts the assistant turn
-  const int EOS_TOKEN = tokenizer->eos_tok(); // token ends the assistant turn
-  const int SYSTEM_PROMPT_SIZE = 512;
-  const int USER_PROMPT_SIZE = 512;
-  const int RENDERED_PROMPT_SIZE = SYSTEM_PROMPT_SIZE + USER_PROMPT_SIZE + 128; // This is big enough to hold the expanded template
-
-

-  // buffers for reading the system prompt and user prompt from stdin
-  // you'll notice they are somewhat haphazardly and unsafely set atm
-  char system_prompt[SYSTEM_PROMPT_SIZE];
-  char user_prompt[USER_PROMPT_SIZE];
-  char rendered_prompt[RENDERED_PROMPT_SIZE];
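+  // Token id that ends the assistant's turn (EOS for llama2, <|eot_id|> for llama3).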
+  const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
  int num_prompt_tokens = 0;
  std::vector<uint64_t> prompt_tokens;
  int user_idx;
@@ -522,41 +684,10 @@ void chat(
    if (user_turn) {
      // get the (optional) system prompt at position 0
      if (pos == 0) {
-        // at position 0, the user can also contribute a system prompt
-        if (cli_system_prompt == NULL) {
-          // system prompt was not passed in, attempt to get it from stdin
-          read_stdin(
-              "Enter system prompt (optional): ",
-              system_prompt,
-              sizeof(system_prompt));
-        } else {
-          // system prompt was passed in, use it
-          strcpy(system_prompt, cli_system_prompt);
-        }
-      }
-      // get the user prompt
-      if (pos == 0 && cli_user_prompt != NULL) {
-        // user prompt for position 0 was passed in, use it
-        strcpy(user_prompt, cli_user_prompt);
+        prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
      } else {
-        // otherwise get user prompt from stdin
-        read_stdin("User: ", user_prompt, sizeof(user_prompt));
+        prompt_tokens = get_next_user_prompt_tokens(tokenizer);
      }
-      // render user/system prompts into the Llama 2 Chat schema
-      if (pos == 0 && system_prompt[0] != '\0') {
-        // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-        const char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
-        snprintf(
-            rendered_prompt, RENDERED_PROMPT_SIZE-1, system_template, system_prompt, user_prompt);
-      } else {
-        // Assistant should produce </s>, so we do not include it in template
-        // We do not add <s> because that is added by tokenizer->encode(x, 1, 0)
-        const char user_template[] = "[INST] %s [/INST]";
-        snprintf(rendered_prompt, RENDERED_PROMPT_SIZE-1, user_template, user_prompt);
-      }
-
-      // encode the rendered prompt into tokens
-      prompt_tokens = tokenizer->encode(rendered_prompt, 1, 0);
      num_prompt_tokens = prompt_tokens.size();

      user_idx = 0; // reset the user index
@@ -578,19 +709,21 @@ void chat(
    float* logits = forward(transformer, token, pos);
    next = sample(sampler, logits);

+    // std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;

-    if (token == EOS_TOKEN) {
+
+    // The assistant's turn ends once it emits the EOT token; the user_idx
+    // guard keeps EOT tokens that appear inside the prompt itself (e.g.
+    // llama3's <|eot_id|>) from ending the turn early.
+    if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
      user_turn = 1;
    }

-    if (user_idx >= num_prompt_tokens && token != EOS_TOKEN && next != EOS_TOKEN) {
+    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
      std::string piece = tokenizer->decode(token, next);
      safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                  // "unsafe" bytes
      fflush(stdout);
    }

-    if (next == EOS_TOKEN) {
+    if (next == EOT_TOKEN) {
      printf("\n");
    }
    pos++;
@@ -619,6 +752,7 @@ void error_usage() {
  fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
  fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
  fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
+  fprintf(stderr, "  -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
  exit(EXIT_FAILURE);
}

@@ -630,14 +764,17 @@ int main(int argc, char* argv[]) {
      1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
  float topp =
      0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-  int vocab_size = 32000;
+
  int steps = 256; // number of steps to run for
  const char* prompt = NULL; // prompt string
  unsigned long long rng_seed = 0; // seed rng with time by default
  const char* mode = "generate"; // generate|chat
  char* system_prompt =
      NULL; // the (optional) system prompt to use in chat mode

+  int vocab_size = -1;
+  int llama_ver = 2;
+
#if defined(ET_USE_ADPATIVE_THREADS)
  uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
  if (num_performant_cores > 0) {
@@ -682,7 +819,10 @@ int main(int argc, char* argv[]) {
      mode = argv[i + 1];
    } else if (argv[i][1] == 'y') {
      system_prompt = argv[i + 1];
-    } else {
+    } else if (argv[i][1] == 'l') {
+      llama_ver = atoi(argv[i + 1]);
+    } else {
      error_usage();
    }
  }
@@ -697,27 +837,35 @@ int main(int argc, char* argv[]) {
  if (steps < 0)
    steps = 0;

+
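+  // Default vocab size depends on the llama version: 32000 for llama2,
+  // 128256 for llama3.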
+  if (vocab_size == -1) {
+    if (llama_ver == 2) {
+      vocab_size = 32000;
+    } else {
+      vocab_size = 128256;
+    }
+  }
+
  // build the Transformer via the model .bin file
  Transformer transformer;
  build_transformer(&transformer, checkpoint_path, vocab_size, steps);

  // build the Tokenizer via the tokenizer .bin file
  Tokenizer* tokenizer = nullptr;

-  // Try to load using Tiktoken, if exception then switch to another tokenizer
-  try {
-    tokenizer =
-        new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-    tokenizer->load(tokenizer_path);
-  } catch (const std::invalid_argument&) {
-    fprintf(
-        stderr,
-        "Failed to load %s into a Tiktoken tokenizer. Trying sentencepiece tokenizer..\n",
-        tokenizer_path);
-    delete tokenizer;
-    tokenizer =
-        new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
-    tokenizer->load(tokenizer_path);
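+  // Choose the tokenizer from the requested llama version instead of probing
+  // by exception: sentencepiece/BPE for llama2, Tiktoken for llama3.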
+  switch (llama_ver) {
+    case 2:
+      tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer->load(tokenizer_path);
+      break;
+    case 3:
+      tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
+      tokenizer->load(tokenizer_path);
+      break;
+
+    default:
+      fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d\n", llama_ver);
+      exit(EXIT_FAILURE);
  }

  // build the Sampler