 /* Inference for Llama-2 Transformer model in pure C++ */
-#include <cstdint>
-#include <cstdlib>
 #include <ctype.h>
-#include <iterator>
 #include <math.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 #include <time.h>
 #include <tokenizer.h>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
 #include <string>

-
 #ifdef DEBUG
 #include <cassert>
 #include <iostream>
@@ -167,22 +166,14 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
   std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};

-  torch::Tensor result = transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
+  torch::Tensor result =
+      transformer->runner->run(inputs)[0].to(torch::dtype(torch::kFloat32));
   auto logits = result[0].data_ptr();

 #else // __ET_MODEL__
   ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
-#ifndef __KV_CACHE__
-  // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
-  ManagedTensor tokens_managed(
-      &(s->toks[pos]),
-      /*ignored*/ sizeof(int64_t) * (pos + 1),
-      {1, 1},
-      ScalarType::Long);
-#else // __KV_CACHE__
   ManagedTensor tokens_managed(
       token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
-#endif
   std::vector<EValue> inputs;
   auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
   auto tmp2 = EValue(pos_managed.get_aliasing_tensor());
@@ -491,9 +482,9 @@ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
 // is not safely implemented, it's more a proof of concept atm.

 enum class ModelType {
-    unknown,
-    llama2,
-    llama3,
+  unknown,
+  llama2,
+  llama3,
 };

 ModelType get_model_type(Tokenizer* tokenizer) {
@@ -519,19 +510,27 @@ uint64_t get_eot_token(Tokenizer* tokenizer) {
     return tokens[0];
   }

-  fprintf(stderr, "No chat template implementation for model type %d", model_type);
+  fprintf(
+      stderr, "No chat template implementation for model type %d", model_type);
   exit(EXIT_FAILURE);
 }

-std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, const char* cli_user_prompt, Tokenizer* tokenizer) {
+std::vector<uint64_t> get_initial_prompt_tokens(
+    const char* cli_system_prompt,
+    const char* cli_user_prompt,
+    Tokenizer* tokenizer) {
   char system_prompt[512];
   char user_prompt[512];
-  char rendered_prompt[512*2 + 200]; // the prompt template is ~170 characters. We use 200 to be safe.
+  char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
+                                       // characters. We use 200 to be safe.

   if (cli_system_prompt != NULL) {
     strcpy(system_prompt, cli_system_prompt);
   } else {
-    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+    read_stdin(
+        "Enter system prompt (optional): ",
+        system_prompt,
+        sizeof(system_prompt));
   }

   if (cli_user_prompt != NULL) {
@@ -540,111 +539,114 @@ std::vector<uint64_t> get_initial_prompt_tokens(const char* cli_system_prompt, c
     read_stdin("User: ", user_prompt, sizeof(user_prompt));
   }

-    ModelType model_type = get_model_type(tokenizer);
-    std::vector<uint64_t> tokens;
-
-    switch (model_type) {
+  ModelType model_type = get_model_type(tokenizer);
+  std::vector<uint64_t> tokens;

+  switch (model_type) {
     case ModelType::llama2:
       if (system_prompt[0] != '\0') {
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
-          system_prompt,
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
+            system_prompt,
+            user_prompt);
       } else {
         // const char prompt_template[] = ;
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "[INST] %s [/INST]",
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "[INST] %s [/INST]",
+            user_prompt);
       }

-      // We need to add BOS token here and not in template because llama2 tokenizer
-      // does not pattern match special tokens
+      // We need to add BOS token here and not in template because llama2
+      // tokenizer does not pattern match special tokens
       tokens = tokenizer->encode(rendered_prompt, 1, 0);
       break;

     case ModelType::llama3:
       if (system_prompt[0] != '\0') {
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-          system_prompt,
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            system_prompt,
+            user_prompt);
       } else {
         snprintf(
-          rendered_prompt,
-          sizeof(rendered_prompt)-1,
-          "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-          user_prompt
-        );
+            rendered_prompt,
+            sizeof(rendered_prompt) - 1,
+            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            user_prompt);
       }
       tokens = tokenizer->encode(rendered_prompt, 0, 0);
       break;

     default:
-      fprintf(stderr, "No chat template implementation for model type %d", model_type);
+      fprintf(
+          stderr,
+          "No chat template implementation for model type %d",
+          model_type);
       exit(EXIT_FAILURE);
-    }
+  }

-  #ifdef DEBUG
-    std::cerr << "Start of rendered prompt:" << std::endl;
-    std::cerr << rendered_prompt;
-    std::cerr << "End of rendered prompt:" << std::endl;
-    std::cerr << "Encoded prompt: ";
-    for (int i = 0; i < tokens.size(); i++) {
-      std::cerr << tokens[i] << ", ";
-    }
-    std::cerr << std::endl << std::flush;
-  #endif
+#ifdef DEBUG
+  std::cerr << "Start of rendered prompt:" << std::endl;
+  std::cerr << rendered_prompt;
+  std::cerr << "End of rendered prompt:" << std::endl;
+  std::cerr << "Encoded prompt: ";
+  for (int i = 0; i < tokens.size(); i++) {
+    std::cerr << tokens[i] << ", ";
+  }
+  std::cerr << std::endl << std::flush;
+#endif

-    return tokens;
+  return tokens;
 }

 std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
   char user_prompt[512];
-  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We use 150 to be safe.
+  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
+                                   // use 150 to be safe.

   read_stdin("User: ", user_prompt, sizeof(user_prompt));

   ModelType model_type = get_model_type(tokenizer);
   std::vector<uint64_t> tokens;

   switch (model_type) {
-
     case ModelType::llama2:
       // const char prompt_template[] = ;
-      snprintf(rendered_prompt, sizeof(rendered_prompt)-1, "[INST] %s [/INST]", user_prompt);
+      snprintf(
+          rendered_prompt,
+          sizeof(rendered_prompt) - 1,
+          "[INST] %s [/INST]",
+          user_prompt);

-      // We need to add BOS token here and not in template because llama2 tokenizer
-      // does not pattern match special tokens
-      tokens = tokenizer->encode(rendered_prompt, /*bos*/1, /*eos*/0);
+      // We need to add BOS token here and not in template because llama2
+      // tokenizer does not pattern match special tokens
+      tokens = tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0);
       break;

     case ModelType::llama3:
       snprintf(
-        rendered_prompt,
-        sizeof(rendered_prompt)-1,
-        "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-        user_prompt
-      );
+          rendered_prompt,
+          sizeof(rendered_prompt) - 1,
+          "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+          user_prompt);
       tokens = tokenizer->encode(rendered_prompt, 0, 0);
       break;

     default:
-      fprintf(stderr, "No chat template implementation for model type %d", model_type);
+      fprintf(
+          stderr,
+          "No chat template implementation for model type %d",
+          model_type);
       exit(EXIT_FAILURE);
   }

-
-  #ifdef DEBUG
+#ifdef DEBUG
   std::cerr << "Start of rendered prompt:" << std::endl;
   std::cerr << rendered_prompt;
   std::cerr << "End of rendered prompt:" << std::endl;
@@ -653,20 +655,18 @@ std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer* tokenizer) {
     std::cerr << tokens[i] << ", ";
   }
   std::cerr << std::endl << std::flush;
-  #endif
+#endif

   return tokens;
 }

-
 void chat(
     Transformer* transformer,
     Tokenizer* tokenizer,
     Sampler* sampler,
     const char* cli_user_prompt,
     const char* cli_system_prompt,
     int steps) {
-
   const uint64_t EOT_TOKEN = get_eot_token(tokenizer);
   int num_prompt_tokens = 0;
   std::vector<uint64_t> prompt_tokens;
@@ -679,12 +679,12 @@ void chat(
   int prev_token;
   int pos = 0; // position in the sequence
   while (pos < steps) {
-
     // when it is the user's turn to contribute tokens to the dialog...
     if (user_turn) {
       // get the (optional) system prompt at position 0
       if (pos == 0) {
-        prompt_tokens = get_initial_prompt_tokens(cli_system_prompt, cli_user_prompt, tokenizer);
+        prompt_tokens = get_initial_prompt_tokens(
+            cli_system_prompt, cli_user_prompt, tokenizer);
       } else {
         prompt_tokens = get_next_user_prompt_tokens(tokenizer);
       }
@@ -711,12 +711,12 @@ void chat(

     // std::cout << "TOKEN: " << token << " NEXT: " << next << std::endl;

-
     if ((user_idx >= num_prompt_tokens) && (token == EOT_TOKEN)) {
       user_turn = 1;
     }

-    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN && next != EOT_TOKEN) {
+    if (user_idx >= num_prompt_tokens && token != EOT_TOKEN &&
+        next != EOT_TOKEN) {
       std::string piece = tokenizer->decode(token, next);
       safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                   // "unsafe" bytes
@@ -727,7 +727,6 @@ void chat(
       printf("\n");
     }
     pos++;
-
   }
   printf("\n");
 }
@@ -752,7 +751,9 @@ void error_usage() {
   fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
   fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
   fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
-  fprintf(stderr, "  -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
+  fprintf(
+      stderr,
+      "  -l <int> (optional) llama version (2 or 3). Defaults to 2.\n");
   exit(EXIT_FAILURE);
 }

@@ -776,7 +777,8 @@ int main(int argc, char* argv[]) {
   int llama_ver = 2;

 #if defined(ET_USE_ADPATIVE_THREADS)
-  uint32_t num_performant_cores = torch::executorch::cpuinfo::get_num_performant_cores();
+  uint32_t num_performant_cores =
+      torch::executorch::cpuinfo::get_num_performant_cores();
   if (num_performant_cores > 0) {
     torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
         num_performant_cores);
@@ -820,9 +822,8 @@ int main(int argc, char* argv[]) {
     } else if (argv[i][1] == 'y') {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
-      llama_ver = atoi(argv[i+1]);
-    }
-    else {
+      llama_ver = atoi(argv[i + 1]);
+    } else {
       error_usage();
     }
   }
@@ -837,7 +838,6 @@ int main(int argc, char* argv[]) {
   if (steps < 0)
     steps = 0;

-
   if (vocab_size == -1) {
     if (llama_ver == 2) {
       vocab_size = 32000;
@@ -855,16 +855,21 @@ int main(int argc, char* argv[]) {

   switch (llama_ver) {
     case 2:
-      tokenizer = new BPETokenizer(transformer.config.vocab_size, /*bos*/1, /*eos*/2);
+      tokenizer =
+          new BPETokenizer(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;
     case 3:
-      tokenizer = new Tiktoken(transformer.config.vocab_size, /*bos*/1, /*eos*/2);
+      tokenizer =
+          new Tiktoken(transformer.config.vocab_size, /*bos*/ 1, /*eos*/ 2);
       tokenizer->load(tokenizer_path);
       break;

     default:
-      fprintf(stderr, "Cannot load tokenizer for unrecognized llama version %d", llama_ver);
+      fprintf(
+          stderr,
+          "Cannot load tokenizer for unrecognized llama version %d",
+          llama_ver);
       exit(EXIT_FAILURE);
   }

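For readers skimming the diff, here is a minimal standalone sketch (not part of this commit) of the Llama 2 prompt rendering that the reformatted get_initial_prompt_tokens() performs. The template strings and the 512 * 2 + 200 buffer size are copied from the code above; render_llama2_prompt is a hypothetical helper introduced only for illustration, and a printf stands in for the tokenizer->encode() call made in run.cpp.

#include <cstdio>
#include <string>

// Hypothetical helper mirroring the llama2 branch of get_initial_prompt_tokens().
std::string render_llama2_prompt(const char* system_prompt, const char* user_prompt) {
  // Same sizing rationale as run.cpp: the template is ~170 characters.
  char rendered_prompt[512 * 2 + 200];
  if (system_prompt != nullptr && system_prompt[0] != '\0') {
    snprintf(
        rendered_prompt,
        sizeof(rendered_prompt) - 1,
        "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
        system_prompt,
        user_prompt);
  } else {
    snprintf(
        rendered_prompt,
        sizeof(rendered_prompt) - 1,
        "[INST] %s [/INST]",
        user_prompt);
  }
  return std::string(rendered_prompt);
}

int main() {
  // run.cpp passes the rendered string to tokenizer->encode(..., /*bos*/ 1, /*eos*/ 0);
  // this sketch only prints the two rendered variants.
  std::string with_system =
      render_llama2_prompt("You are a helpful assistant.", "Tell me a joke.");
  std::string without_system = render_llama2_prompt(nullptr, "Tell me a joke.");
  printf("%s\n---\n%s\n", with_system.c_str(), without_system.c_str());
  return 0;
}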