@@ -53,6 +53,7 @@ struct whisper_params {
     int32_t capture_id   = -1;
     int32_t max_tokens   = 32;
     int32_t audio_ctx    = 0;
+    int32_t n_gpu_layers = 0;
 
     float vad_thold  = 0.6f;
     float freq_thold = 100.0f;
@@ -90,6 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
     else if (arg == "-c"   || arg == "--capture")      { params.capture_id   = std::stoi(argv[++i]); }
     else if (arg == "-mt"  || arg == "--max-tokens")   { params.max_tokens   = std::stoi(argv[++i]); }
     else if (arg == "-ac"  || arg == "--audio-ctx")    { params.audio_ctx    = std::stoi(argv[++i]); }
+    else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
     else if (arg == "-vth" || arg == "--vad-thold")    { params.vad_thold    = std::stof(argv[++i]); }
     else if (arg == "-fth" || arg == "--freq-thold")   { params.freq_thold   = std::stof(argv[++i]); }
     else if (arg == "-su"  || arg == "--speed-up")     { params.speed_up     = true; }
@@ -134,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -c ID,    --capture ID     [%-7d] capture device ID\n",                        params.capture_id);
     fprintf(stderr, "  -mt N,    --max-tokens N   [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
     fprintf(stderr, "  -ac N,    --audio-ctx N    [%-7d] audio context size (0 - all)\n",             params.audio_ctx);
+    fprintf(stderr, "  -ngl N,   --n-gpu-layers N [%-7d] number of layers to store in VRAM\n",        params.n_gpu_layers);
     fprintf(stderr, "  -vth N,   --vad-thold N    [%-7.2f] voice activity detection threshold\n",     params.vad_thold);
     fprintf(stderr, "  -fth N,   --freq-thold N   [%-7.2f] high-pass frequency cutoff\n",             params.freq_thold);
     fprintf(stderr, "  -su,      --speed-up       [%-7s] speed up audio by x2 (reduced accuracy)\n",  params.speed_up ? "true" : "false");
@@ -268,6 +271,8 @@ int main(int argc, char ** argv) {
     auto lmparams = llama_model_default_params();
     if (!params.use_gpu) {
         lmparams.n_gpu_layers = 0;
+    } else {
+        lmparams.n_gpu_layers = params.n_gpu_layers;
     }
 
     struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
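For context (not part of the commit): the value parsed from -ngl N lands in llama_model_params.n_gpu_layers, which tells llama.cpp how many model layers to offload to the GPU backend, with 0 keeping everything on the CPU. A minimal sketch of that flow is shown below, using only the llama.cpp calls already visible in this diff; the helper name and model path are hypothetical placeholders.

    // Illustrative sketch only -- not part of the commit.
    #include "llama.h"

    // load_with_offload is a hypothetical helper; the model path is a placeholder.
    static struct llama_model * load_with_offload(int n_gpu_layers) {
        auto lmparams = llama_model_default_params();
        lmparams.n_gpu_layers = n_gpu_layers; // e.g. the value parsed from -ngl N
        return llama_load_model_from_file("models/llama-7b-q4_0.gguf", lmparams);
    }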