Skip to content

Commit 2f668c3

Browse files
authored
whisper : add abort callback (#1335)
1 parent 08fa348 commit 2f668c3

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

whisper.cpp

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) {
125125
// ggml helpers
126126
//
127127

128-
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
128+
static void ggml_graph_compute_helper(
129+
std::vector<uint8_t> & buf,
130+
ggml_cgraph * graph,
131+
int n_threads,
132+
whisper_abort_callback abort_callback,
133+
void * abort_callback_data) {
129134
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
130135

136+
plan.abort_callback = abort_callback;
137+
plan.abort_callback_data = abort_callback_data;
138+
131139
if (plan.work_size > 0) {
132140
buf.resize(plan.work_size);
133141
plan.work_data = buf.data();
@@ -1922,7 +1930,9 @@ static bool whisper_encode_internal(
19221930
whisper_context & wctx,
19231931
whisper_state & wstate,
19241932
const int mel_offset,
1925-
const int n_threads) {
1933+
const int n_threads,
1934+
whisper_abort_callback abort_callback,
1935+
void * abort_callback_data) {
19261936
const int64_t t_start_us = ggml_time_us();
19271937

19281938
// conv
@@ -1936,7 +1946,7 @@ static bool whisper_encode_internal(
19361946
ggml_allocr_alloc_graph(alloc, gf);
19371947

19381948
if (!whisper_encode_external(wstate)) {
1939-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1949+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
19401950
}
19411951
}
19421952

@@ -1955,10 +1965,10 @@ static bool whisper_encode_internal(
19551965
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
19561966
ggml_metal_graph_compute(wstate.ctx_metal, gf);
19571967
} else {
1958-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1968+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
19591969
}
19601970
#else
1961-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1971+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
19621972
#endif
19631973
}
19641974

@@ -1977,10 +1987,10 @@ static bool whisper_encode_internal(
19771987
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
19781988
ggml_metal_graph_compute(wstate.ctx_metal, gf);
19791989
} else {
1980-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1990+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
19811991
}
19821992
#else
1983-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
1993+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
19841994
#endif
19851995
}
19861996

@@ -2346,7 +2356,9 @@ static bool whisper_decode_internal(
23462356
const whisper_token * tokens,
23472357
const int n_tokens,
23482358
const int n_past,
2349-
const int n_threads) {
2359+
const int n_threads,
2360+
whisper_abort_callback abort_callback,
2361+
void * abort_callback_data) {
23502362
const int64_t t_start_us = ggml_time_us();
23512363

23522364
const auto & model = wctx.model;
@@ -2375,10 +2387,10 @@ static bool whisper_decode_internal(
23752387
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
23762388
ggml_metal_graph_compute(wstate.ctx_metal, gf);
23772389
} else {
2378-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
2390+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
23792391
}
23802392
#else
2381-
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
2393+
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
23822394
#endif
23832395
}
23842396

@@ -3290,7 +3302,7 @@ int whisper_set_mel(
32903302
}
32913303

32923304
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3293-
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
3305+
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
32943306
log("%s: failed to eval\n", __func__);
32953307
return -1;
32963308
}
@@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
32993311
}
33003312

33013313
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3302-
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
3314+
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
33033315
log("%s: failed to eval\n", __func__);
33043316
return -1;
33053317
}
@@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
33103322
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
33113323
const int selected_decoder_id = 0;
33123324

3313-
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3325+
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
33143326
log("%s: failed to eval\n", __func__);
33153327
return 1;
33163328
}
@@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
33273339
return false;
33283340
}
33293341

3330-
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3342+
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
33313343
log("%s: failed to eval\n", __func__);
33323344
return 1;
33333345
}
@@ -4594,7 +4606,7 @@ int whisper_full_with_state(
45944606
}
45954607

45964608
// encode audio features starting at offset seek
4597-
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
4609+
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
45984610
log("%s: failed to encode\n", __func__);
45994611
return -6;
46004612
}
@@ -4677,7 +4689,7 @@ int whisper_full_with_state(
46774689
}
46784690
WHISPER_PRINT_DEBUG("\n\n");
46794691

4680-
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
4692+
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
46814693
log("%s: failed to decode\n", __func__);
46824694
return -7;
46834695
}
@@ -4901,7 +4913,7 @@ int whisper_full_with_state(
49014913

49024914
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
49034915

4904-
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
4916+
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
49054917
log("%s: failed to decode\n", __func__);
49064918
return -8;
49074919
}
@@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
54735485
double tsum = 0.0;
54745486

54755487
// heat-up
5476-
ggml_graph_compute_helper(work, &gf, n_threads);
5488+
ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
54775489

54785490
for (int i = 0; i < n_max; ++i) {
54795491
const int64_t t0 = ggml_time_us();
54805492

5481-
ggml_graph_compute_helper(work, &gf, n_threads);
5493+
ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
54825494

54835495
const int64_t t1 = ggml_time_us();
54845496

whisper.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,11 @@ extern "C" {
334334
// If it returns false, the computation is aborted
335335
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
336336

337+
// Abort callback
338+
// If not NULL, called before ggml computation
339+
// If it returns true, the computation is aborted
340+
typedef bool (*whisper_abort_callback)(void * user_data);
341+
337342
// Logits filter callback
338343
// Can be used to modify the logits before sampling
339344
// If not NULL, called after applying temperature to logits
@@ -428,6 +433,10 @@ extern "C" {
428433
whisper_encoder_begin_callback encoder_begin_callback;
429434
void * encoder_begin_callback_user_data;
430435

436+
// called each time before ggml computation starts
437+
whisper_abort_callback abort_callback;
438+
void * abort_callback_user_data;
439+
431440
// called by each decoder to filter obtained logits
432441
whisper_logits_filter_callback logits_filter_callback;
433442
void * logits_filter_callback_user_data;

0 commit comments

Comments
 (0)