Skip to content

Commit 975d2ce

Browse files
anzz1ggerganov
andauthored
cmdline option for custom amount of model parts (--n_parts N) (abetlen#348)
* cmdline option for custom amount of model parts (--n_parts N) * Update main.cpp --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent e0ffc86 commit 975d2ce

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

main.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ struct llama_model {
9090
};
9191

9292
// load the model's weights from a file
93-
bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
93+
94+
bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
9495
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
9596

9697
std::vector<char> f_buf(1024*1024);
@@ -127,7 +128,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
127128
}
128129

129130
int n_ff = 0;
130-
int n_parts = 0;
131131

132132
// load hparams
133133
{
@@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
145145
hparams.n_ctx = n_ctx;
146146

147147
n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
148-
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
148+
149+
if (n_parts < 1) {
150+
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
151+
}
149152

150153
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
151154
fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
@@ -839,7 +842,7 @@ int main(int argc, char ** argv) {
839842
{
840843
const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
841844
const int64_t t_start_us = ggml_time_us();
842-
if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
845+
if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
843846
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
844847
return 1;
845848
}

utils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
7474
params.antiprompt.push_back(argv[++i]);
7575
} else if (arg == "--ignore-eos") {
7676
params.ignore_eos = true;
77+
} else if (arg == "--n_parts") {
78+
params.n_parts = std::stoi(argv[++i]);
7779
} else if (arg == "-h" || arg == "--help") {
7880
gpt_print_usage(argc, argv, params);
7981
exit(0);
@@ -116,6 +118,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
116118
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
117119
fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n");
118120
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
121+
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
119122
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
120123
fprintf(stderr, " -m FNAME, --model FNAME\n");
121124
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());

utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
//
1414

1515
struct gpt_params {
16-
int32_t seed = -1; // RNG seed
16+
int32_t seed = -1; // RNG seed
1717
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
1818
int32_t n_predict = 128; // new tokens to predict
1919
int32_t repeat_last_n = 64; // last n tokens to penalize
20+
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
2021
int32_t n_ctx = 512; //context size
2122

2223
// sampling parameters

0 commit comments

Comments
 (0)