Skip to content

Commit 0463028

Browse files
jhen0409ggerganov
andauthored
whisper : add context param to disable gpu (#1293)
* whisper : check state->ctx_metal not null * whisper : add whisper_context_params { use_gpu } * whisper : new API with params & deprecate old API * examples : use no-gpu param && whisper_init_from_file_with_params * whisper.objc : enable metal & disable on simulator * whisper.swiftui, metal : enable metal & support load default.metallib * whisper.android : use new API * bindings : use new API * addon.node : fix build & test * bindings : updata java binding * bindings : add missing whisper_context_default_params_by_ref WHISPER_API for java * metal : use SWIFTPM_MODULE_BUNDLE for GGML_SWIFT and reuse library load * metal : move bundle var into block * metal : use SWIFT_PACKAGE instead of GGML_SWIFT * style : minor updates --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 39cfad0 commit 0463028

File tree

29 files changed

+421
-170
lines changed

29 files changed

+421
-170
lines changed

bindings/go/whisper.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ var (
103103
func Whisper_init(path string) *Context {
104104
cPath := C.CString(path)
105105
defer C.free(unsafe.Pointer(cPath))
106-
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
106+
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
107107
return (*Context)(ctx)
108108
} else {
109109
return nil

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.sun.jna.ptr.PointerByReference;
55
import io.github.ggerganov.whispercpp.ggml.GgmlType;
66
import io.github.ggerganov.whispercpp.WhisperModel;
7+
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
78

89
import java.util.List;
910

@@ -23,8 +24,9 @@ public class WhisperContext extends Structure {
2324
public PointerByReference vocab;
2425
public PointerByReference state;
2526

26-
/** populated by whisper_init_from_file() */
27+
/** populated by whisper_init_from_file_with_params() */
2728
String path_model;
29+
WhisperContextParams params;
2830

2931
// public static class ByReference extends WhisperContext implements Structure.ByReference {
3032
// }

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.sun.jna.Native;
44
import com.sun.jna.Pointer;
5+
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
56
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
67
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
78

@@ -15,8 +16,9 @@
1516
public class WhisperCpp implements AutoCloseable {
1617
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
1718
private Pointer ctx = null;
18-
private Pointer greedyPointer = null;
19-
private Pointer beamPointer = null;
19+
private Pointer paramsPointer = null;
20+
private Pointer greedyParamsPointer = null;
21+
private Pointer beamParamsPointer = null;
2022

2123
public File modelDir() {
2224
String modelDirPath = System.getenv("XDG_CACHE_HOME");
@@ -31,6 +33,18 @@ public File modelDir() {
3133
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
3234
*/
3335
public void initContext(String modelPath) throws FileNotFoundException {
36+
initContextImpl(modelPath, getContextDefaultParams());
37+
}
38+
39+
/**
40+
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
41+
* @param params - params to use when initialising the context
42+
*/
43+
public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
44+
initContextImpl(modelPath, params);
45+
}
46+
47+
private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
3448
if (ctx != null) {
3549
lib.whisper_free(ctx);
3650
}
@@ -43,13 +57,26 @@ public void initContext(String modelPath) throws FileNotFoundException {
4357
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
4458
}
4559

46-
ctx = lib.whisper_init_from_file(modelPath);
60+
ctx = lib.whisper_init_from_file_with_params(modelPath, params);
4761

4862
if (ctx == null) {
4963
throw new FileNotFoundException(modelPath);
5064
}
5165
}
5266

67+
/**
68+
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
69+
* Because this function allocates memory for the params, the caller must call either:
70+
* - call `whisper_free_context_params()`
71+
* - `Native.free(Pointer.nativeValue(pointer));`
72+
*/
73+
public WhisperContextParams getContextDefaultParams() {
74+
paramsPointer = lib.whisper_context_default_params_by_ref();
75+
WhisperContextParams params = new WhisperContextParams(paramsPointer);
76+
params.read();
77+
return params;
78+
}
79+
5380
/**
5481
* Provides default params which can be used with `whisper_full()` etc.
5582
* Because this function allocates memory for the params, the caller must call either:
@@ -63,15 +90,15 @@ public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy)
6390

6491
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
6592
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
66-
if (greedyPointer == null) {
67-
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
93+
if (greedyParamsPointer == null) {
94+
greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
6895
}
69-
pointer = greedyPointer;
96+
pointer = greedyParamsPointer;
7097
} else {
71-
if (beamPointer == null) {
72-
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
98+
if (beamParamsPointer == null) {
99+
beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
73100
}
74-
pointer = beamPointer;
101+
pointer = beamParamsPointer;
75102
}
76103

77104
WhisperFullParams params = new WhisperFullParams(pointer);
@@ -93,13 +120,17 @@ private void freeContext() {
93120
}
94121

95122
private void freeParams() {
96-
if (greedyPointer != null) {
97-
Native.free(Pointer.nativeValue(greedyPointer));
98-
greedyPointer = null;
123+
if (paramsPointer != null) {
124+
Native.free(Pointer.nativeValue(paramsPointer));
125+
paramsPointer = null;
126+
}
127+
if (greedyParamsPointer != null) {
128+
Native.free(Pointer.nativeValue(greedyParamsPointer));
129+
greedyParamsPointer = null;
99130
}
100-
if (beamPointer != null) {
101-
Native.free(Pointer.nativeValue(beamPointer));
102-
beamPointer = null;
131+
if (beamParamsPointer != null) {
132+
Native.free(Pointer.nativeValue(beamParamsPointer));
133+
beamParamsPointer = null;
103134
}
104135
}
105136

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.sun.jna.Pointer;
66
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
77
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
8+
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
89
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
910

1011
public interface WhisperCppJnaLibrary extends Library {
@@ -13,12 +14,31 @@ public interface WhisperCppJnaLibrary extends Library {
1314
String whisper_print_system_info();
1415

1516
/**
16-
* Allocate (almost) all memory needed for the model by loading from a file.
17+
* DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file.
1718
*
1819
* @param path_model Path to the model file
1920
* @return Whisper context on success, null on failure
2021
*/
2122
Pointer whisper_init_from_file(String path_model);
23+
24+
/**
25+
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
26+
* Because this function allocates memory for the params, the caller must call either:
27+
* - call `whisper_free_context_params()`
28+
* - `Native.free(Pointer.nativeValue(pointer));`
29+
*/
30+
Pointer whisper_context_default_params_by_ref();
31+
32+
void whisper_free_context_params(Pointer params);
33+
34+
/**
35+
* Allocate (almost) all memory needed for the model by loading from a file.
36+
*
37+
* @param path_model Path to the model file
38+
* @param params Pointer to whisper_context_params
39+
* @return Whisper context on success, null on failure
40+
*/
41+
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);
2242

2343
/**
2444
* Allocate (almost) all memory needed for the model by loading from a buffer.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package io.github.ggerganov.whispercpp.params;
2+
3+
import com.sun.jna.*;
4+
5+
import java.util.Arrays;
6+
import java.util.List;
7+
8+
/**
9+
* Parameters for the whisper_init_from_file_with_params() function.
10+
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
11+
* whisper_context_default_params()
12+
*/
13+
public class WhisperContextParams extends Structure {
14+
15+
public WhisperContextParams(Pointer p) {
16+
super(p);
17+
}
18+
19+
/** Use GPU for inference Number (default = true) */
20+
public CBool use_gpu;
21+
22+
/** Use GPU for inference Number (default = true) */
23+
public void useGpu(boolean enable) {
24+
use_gpu = enable ? CBool.TRUE : CBool.FALSE;
25+
}
26+
27+
@Override
28+
protected List<String> getFieldOrder() {
29+
return Arrays.asList("use_gpu");
30+
}
31+
}

bindings/javascript/emscripten.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct whisper_context * g_context;
2020
EMSCRIPTEN_BINDINGS(whisper) {
2121
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
2222
if (g_context == nullptr) {
23-
g_context = whisper_init_from_file(path_model.c_str());
23+
g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
2424
if (g_context != nullptr) {
2525
return true;
2626
} else {

bindings/ruby/ext/ruby_whisper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
8787
if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) {
8888
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
8989
}
90-
rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path));
90+
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
9191
if (rw->context == nullptr) {
9292
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
9393
}

examples/addon.node/__test__/whisper.spec.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const whisperParamsMock = {
1111
language: "en",
1212
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
1313
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
14+
use_gpu: true,
1415
};
1516

1617
describe("Run whisper.node", () => {

examples/addon.node/addon.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ struct whisper_params {
3636
bool print_colors = false;
3737
bool print_progress = false;
3838
bool no_timestamps = false;
39+
bool use_gpu = true;
3940

4041
std::string language = "en";
4142
std::string prompt;
@@ -153,7 +154,9 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
153154

154155
// whisper init
155156

156-
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
157+
struct whisper_context_params cparams;
158+
cparams.use_gpu = params.use_gpu;
159+
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
157160

158161
if (ctx == nullptr) {
159162
fprintf(stderr, "error: failed to initialize whisper context\n");
@@ -315,10 +318,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
315318
std::string language = whisper_params.Get("language").As<Napi::String>();
316319
std::string model = whisper_params.Get("model").As<Napi::String>();
317320
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
321+
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
318322

319323
params.language = language;
320324
params.model = model;
321325
params.fname_inp.emplace_back(input);
326+
params.use_gpu = use_gpu;
322327

323328
Napi::Function callback = info[1].As<Napi::Function>();
324329
Worker* worker = new Worker(callback, params);

examples/addon.node/index.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const whisperParams = {
1111
language: "en",
1212
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
1313
fname_inp: "../../samples/jfk.wav",
14+
use_gpu: true,
1415
};
1516

1617
const arguments = process.argv.slice(2);

examples/bench.wasm/emscripten.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) {
5757
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
5858
for (size_t i = 0; i < g_contexts.size(); ++i) {
5959
if (g_contexts[i] == nullptr) {
60-
g_contexts[i] = whisper_init_from_file(path_model.c_str());
60+
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
6161
if (g_contexts[i] != nullptr) {
6262
if (g_worker.joinable()) {
6363
g_worker.join();

examples/bench/bench.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ struct whisper_params {
1111
int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat
1212

1313
std::string model = "models/ggml-base.en.bin";
14+
15+
bool use_gpu = true;
1416
};
1517

1618
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -23,9 +25,10 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
2325
whisper_print_usage(argc, argv, params);
2426
exit(0);
2527
}
26-
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
27-
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
28-
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
28+
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
29+
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
30+
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
31+
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
2932
else {
3033
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
3134
whisper_print_usage(argc, argv, params);
@@ -45,6 +48,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
4548
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
4649
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
4750
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
51+
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
4852
fprintf(stderr, " %-7s 0 - whisper\n", "");
4953
fprintf(stderr, " %-7s 1 - memcpy\n", "");
5054
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
@@ -54,7 +58,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
5458
int whisper_bench_full(const whisper_params & params) {
5559
// whisper init
5660

57-
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
61+
struct whisper_context_params cparams;
62+
cparams.use_gpu = params.use_gpu;
63+
64+
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
5865

5966
{
6067
fprintf(stderr, "\n");

examples/command.wasm/emscripten.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) {
243243
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
244244
for (size_t i = 0; i < g_contexts.size(); ++i) {
245245
if (g_contexts[i] == nullptr) {
246-
g_contexts[i] = whisper_init_from_file(path_model.c_str());
246+
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
247247
if (g_contexts[i] != nullptr) {
248248
g_running = true;
249249
if (g_worker.joinable()) {

examples/command/command.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct whisper_params {
3838
bool print_special = false;
3939
bool print_energy = false;
4040
bool no_timestamps = true;
41+
bool use_gpu = true;
4142

4243
std::string language = "en";
4344
std::string model = "models/ggml-base.en.bin";
@@ -68,6 +69,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
6869
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
6970
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
7071
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
72+
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
7173
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
7274
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
7375
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
@@ -101,6 +103,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
101103
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
102104
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
103105
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
106+
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
104107
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
105108
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
106109
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
@@ -610,7 +613,10 @@ int main(int argc, char ** argv) {
610613

611614
// whisper init
612615

613-
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
616+
struct whisper_context_params cparams;
617+
cparams.use_gpu = params.use_gpu;
618+
619+
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
614620

615621
// print some info about the processing
616622
{

0 commit comments

Comments
 (0)