Skip to content

Commit c23598e

Browse files
authored
talk-llama : add n_gpu_layers parameter (#1475)
1 parent 54a08bd commit c23598e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

examples/talk-llama/talk-llama.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ struct whisper_params {
5353
int32_t capture_id = -1;
5454
int32_t max_tokens = 32;
5555
int32_t audio_ctx = 0;
56+
int32_t n_gpu_layers = 0;
5657

5758
float vad_thold = 0.6f;
5859
float freq_thold = 100.0f;
@@ -90,6 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
9091
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
9192
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
9293
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
94+
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
9395
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
9496
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
9597
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
@@ -134,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
134136
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
135137
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
136138
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
139+
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7s] number of layers to store in VRAM\n", params.n_gpu_layers);
137140
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
138141
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
139142
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
@@ -268,6 +271,8 @@ int main(int argc, char ** argv) {
268271
auto lmparams = llama_model_default_params();
269272
if (!params.use_gpu) {
270273
lmparams.n_gpu_layers = 0;
274+
} else {
275+
lmparams.n_gpu_layers = params.n_gpu_layers;
271276
}
272277

273278
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);

0 commit comments

Comments
 (0)