@@ -90,7 +90,8 @@ struct llama_model {
 };
 
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
+
+bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     std::vector<char> f_buf(1024*1024);
@@ -127,7 +128,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
     }
 
     int n_ff = 0;
-    int n_parts = 0;
 
     // load hparams
     {
@@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
         hparams.n_ctx = n_ctx;
 
         n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
-        n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+
+        if (n_parts < 1) {
+            n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+        }
 
         fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
@@ -839,7 +842,7 @@ int main(int argc, char ** argv) {
     {
         const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }
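For the new call to compile, params must now carry an n_parts field; the rest of the commit presumably adds it to the parameter struct and wires up a matching command-line flag, neither of which is visible in these hunks. A hypothetical sketch of that side of the change, with all names assumed:

// Hypothetical: the field implied by params.n_parts above. A default of -1
// preserves the old behaviour (derive the count from n_embd via the table).
struct params_sketch {
    int32_t n_parts = -1;
};
// A user could then force single-file loading with something like
// "--n_parts 1", which would set params.n_parts = 1 before this call.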