Commit 8b9a9dc

Move model loading back to main.cpp

Signed-off-by: Thiago Padilha <[email protected]>

1 parent: 734a858

3 files changed: +77, -61 lines

llama.cpp

Lines changed: 6 additions & 59 deletions

@@ -731,41 +731,12 @@ void sigint_handler(int signo) {
 }
 #endif
 
-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s = "";
-    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
-    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
-    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
-    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
-    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
-    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
-    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
-    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
-    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-
-    return s.c_str();
-}
-
-int llama_main(int argc, char ** argv) {
-    ggml_time_init();
-    const int64_t t_main_start_us = ggml_time_us();
-
-    gpt_params params;
-    params.model = "models/llama-7B/ggml-model.bin";
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
-    }
-
-    if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
-    }
+int llama_main(
+    gpt_params params,
+    gpt_vocab vocab,
+    llama_model model,
+    int64_t t_load_us,
+    int64_t t_main_start_us) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
@@ -781,30 +752,6 @@ int llama_main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
-    gpt_vocab vocab;
-    llama_model model;
-
-    // load the model
-    {
-        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
-        const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-    }
-
-    // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
     int n_past = 0;
 
     int64_t t_sample_us = 0;

llama.h

Lines changed: 8 additions & 1 deletion

@@ -6,6 +6,7 @@
 #include <string>
 
 #include "ggml.h"
+#include "utils.h"
 
 
 // default hparams (LLaMA 7B)
@@ -58,4 +59,10 @@ struct llama_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
-int llama_main(int argc, char ** argv);
+int llama_main(
+    gpt_params params,
+    gpt_vocab vocab,
+    llama_model model,
+    int64_t t_load_us,
+    int64_t t_main_start_us);
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type);
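With the declarations above, a frontend loads the vocabulary and model itself and only then hands the loaded state to llama_main, which is the split this commit introduces. The sketch below is illustrative only and is not part of this commit; the function name run_preloaded is made up for this example, it assumes the gpt_params, gpt_vocab, llama_model, gpt_params_parse and llama_model_load definitions from utils.h and llama.h as of this commit, and it passes the two timing arguments in the order declared above.

// Hypothetical alternative frontend (sketch, not part of this commit).
// It performs model loading itself and passes the loaded state to
// llama_main(), instead of llama_main() parsing arguments and loading.
#include "ggml.h"
#include "utils.h"
#include "llama.h"

int run_preloaded(int argc, char ** argv) {
    ggml_time_init();
    const int64_t t_main_start_us = ggml_time_us();

    gpt_params params;
    params.model = "models/llama-7B/ggml-model.bin";
    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    gpt_vocab vocab;
    llama_model model;

    // load the model up front; llama_main() no longer does this itself
    const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
    const int64_t t_start_us = ggml_time_us();
    if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
        return 1;
    }
    const int64_t t_load_us = ggml_time_us() - t_start_us;

    // hand the preloaded state to the shared generation loop
    return llama_main(params, vocab, model, t_load_us, t_main_start_us);
}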

main.cpp

Lines changed: 63 additions & 1 deletion

@@ -1,5 +1,67 @@
+#include "ggml.h"
+#include "utils.h"
 #include "llama.h"
 
+const char * llama_print_system_info(void) {
+    static std::string s;
+
+    s = "";
+    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
+    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
+    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
+    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
+    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
+    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
+    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
+    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
+    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
+    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
+    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
+
+    return s.c_str();
+}
+
 int main(int argc, char ** argv) {
-    return llama_main(argc, argv);
+
+    ggml_time_init();
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.n_ctx > 2048) {
+        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
+                "expect poor results\n", __func__, params.n_ctx);
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    llama_model model;
+
+    // load the model
+    {
+        const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
+        const int64_t t_start_us = ggml_time_us();
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    // print system information
+    {
+        fprintf(stderr, "\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+    }
+
+    return llama_main(params, vocab, model, t_main_start_us, t_load_us);
 }
