@@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) {
125
125
// ggml helpers
126
126
//
127
127
128
- static void ggml_graph_compute_helper (std::vector<uint8_t > & buf, ggml_cgraph * graph, int n_threads) {
128
+ static void ggml_graph_compute_helper (
129
+ std::vector<uint8_t > & buf,
130
+ ggml_cgraph * graph,
131
+ int n_threads,
132
+ whisper_abort_callback abort_callback,
133
+ void * abort_callback_data) {
129
134
struct ggml_cplan plan = ggml_graph_plan (graph, n_threads);
130
135
136
+ plan.abort_callback = abort_callback;
137
+ plan.abort_callback_data = abort_callback_data;
138
+
131
139
if (plan.work_size > 0 ) {
132
140
buf.resize (plan.work_size );
133
141
plan.work_data = buf.data ();
@@ -1922,7 +1930,9 @@ static bool whisper_encode_internal(
1922
1930
whisper_context & wctx,
1923
1931
whisper_state & wstate,
1924
1932
const int mel_offset,
1925
- const int n_threads) {
1933
+ const int n_threads,
1934
+ whisper_abort_callback abort_callback,
1935
+ void * abort_callback_data) {
1926
1936
const int64_t t_start_us = ggml_time_us ();
1927
1937
1928
1938
// conv
@@ -1936,7 +1946,7 @@ static bool whisper_encode_internal(
1936
1946
ggml_allocr_alloc_graph (alloc, gf);
1937
1947
1938
1948
if (!whisper_encode_external (wstate)) {
1939
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
1949
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
1940
1950
}
1941
1951
}
1942
1952
@@ -1955,10 +1965,10 @@ static bool whisper_encode_internal(
1955
1965
ggml_metal_set_n_cb (wstate.ctx_metal , n_threads);
1956
1966
ggml_metal_graph_compute (wstate.ctx_metal , gf);
1957
1967
} else {
1958
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
1968
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
1959
1969
}
1960
1970
#else
1961
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
1971
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
1962
1972
#endif
1963
1973
}
1964
1974
@@ -1977,10 +1987,10 @@ static bool whisper_encode_internal(
1977
1987
ggml_metal_set_n_cb (wstate.ctx_metal , n_threads);
1978
1988
ggml_metal_graph_compute (wstate.ctx_metal , gf);
1979
1989
} else {
1980
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
1990
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
1981
1991
}
1982
1992
#else
1983
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
1993
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
1984
1994
#endif
1985
1995
}
1986
1996
@@ -2346,7 +2356,9 @@ static bool whisper_decode_internal(
2346
2356
const whisper_token * tokens,
2347
2357
const int n_tokens,
2348
2358
const int n_past,
2349
- const int n_threads) {
2359
+ const int n_threads,
2360
+ whisper_abort_callback abort_callback,
2361
+ void * abort_callback_data) {
2350
2362
const int64_t t_start_us = ggml_time_us ();
2351
2363
2352
2364
const auto & model = wctx.model ;
@@ -2375,10 +2387,10 @@ static bool whisper_decode_internal(
2375
2387
ggml_metal_set_n_cb (wstate.ctx_metal , n_threads);
2376
2388
ggml_metal_graph_compute (wstate.ctx_metal , gf);
2377
2389
} else {
2378
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
2390
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
2379
2391
}
2380
2392
#else
2381
- ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads);
2393
+ ggml_graph_compute_helper (wstate.work_buffer , gf, n_threads, abort_callback, abort_callback_data );
2382
2394
#endif
2383
2395
}
2384
2396
@@ -3290,7 +3302,7 @@ int whisper_set_mel(
3290
3302
}
3291
3303
3292
3304
int whisper_encode_with_state (struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
3293
- if (!whisper_encode_internal (*ctx, *state, offset, n_threads)) {
3305
+ if (!whisper_encode_internal (*ctx, *state, offset, n_threads, nullptr , nullptr )) {
3294
3306
log (" %s: failed to eval\n " , __func__);
3295
3307
return -1 ;
3296
3308
}
@@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
3299
3311
}
3300
3312
3301
3313
int whisper_encode (struct whisper_context * ctx, int offset, int n_threads) {
3302
- if (!whisper_encode_internal (*ctx, *ctx->state , offset, n_threads)) {
3314
+ if (!whisper_encode_internal (*ctx, *ctx->state , offset, n_threads, nullptr , nullptr )) {
3303
3315
log (" %s: failed to eval\n " , __func__);
3304
3316
return -1 ;
3305
3317
}
@@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
3310
3322
int whisper_decode_with_state (struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
3311
3323
const int selected_decoder_id = 0 ;
3312
3324
3313
- if (!whisper_decode_internal (*ctx, *state, state->decoders [selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3325
+ if (!whisper_decode_internal (*ctx, *state, state->decoders [selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr , nullptr )) {
3314
3326
log (" %s: failed to eval\n " , __func__);
3315
3327
return 1 ;
3316
3328
}
@@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
3327
3339
return false ;
3328
3340
}
3329
3341
3330
- if (!whisper_decode_internal (*ctx, *ctx->state , ctx->state ->decoders [selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
3342
+ if (!whisper_decode_internal (*ctx, *ctx->state , ctx->state ->decoders [selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr , nullptr )) {
3331
3343
log (" %s: failed to eval\n " , __func__);
3332
3344
return 1 ;
3333
3345
}
@@ -4594,7 +4606,7 @@ int whisper_full_with_state(
4594
4606
}
4595
4607
4596
4608
// encode audio features starting at offset seek
4597
- if (!whisper_encode_internal (*ctx, *state, seek, params.n_threads )) {
4609
+ if (!whisper_encode_internal (*ctx, *state, seek, params.n_threads , params. abort_callback , params. abort_callback_user_data )) {
4598
4610
log (" %s: failed to encode\n " , __func__);
4599
4611
return -6 ;
4600
4612
}
@@ -4677,7 +4689,7 @@ int whisper_full_with_state(
4677
4689
}
4678
4690
WHISPER_PRINT_DEBUG (" \n\n " );
4679
4691
4680
- if (!whisper_decode_internal (*ctx, *state, state->decoders [0 ], prompt.data (), prompt.size (), 0 , params.n_threads )) {
4692
+ if (!whisper_decode_internal (*ctx, *state, state->decoders [0 ], prompt.data (), prompt.size (), 0 , params.n_threads , params. abort_callback , params. abort_callback_user_data )) {
4681
4693
log (" %s: failed to decode\n " , __func__);
4682
4694
return -7 ;
4683
4695
}
@@ -4901,7 +4913,7 @@ int whisper_full_with_state(
4901
4913
4902
4914
// WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
4903
4915
4904
- if (!whisper_decode_internal (*ctx, *state, decoder, decoder.tokens_tmp .data (), decoder.tokens_tmp .size (), decoder.kv_self .n , params.n_threads )) {
4916
+ if (!whisper_decode_internal (*ctx, *state, decoder, decoder.tokens_tmp .data (), decoder.tokens_tmp .size (), decoder.kv_self .n , params.n_threads , params. abort_callback , params. abort_callback_user_data )) {
4905
4917
log (" %s: failed to decode\n " , __func__);
4906
4918
return -8 ;
4907
4919
}
@@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5473
5485
double tsum = 0.0 ;
5474
5486
5475
5487
// heat-up
5476
- ggml_graph_compute_helper (work, &gf, n_threads);
5488
+ ggml_graph_compute_helper (work, &gf, n_threads, nullptr , nullptr );
5477
5489
5478
5490
for (int i = 0 ; i < n_max; ++i) {
5479
5491
const int64_t t0 = ggml_time_us ();
5480
5492
5481
- ggml_graph_compute_helper (work, &gf, n_threads);
5493
+ ggml_graph_compute_helper (work, &gf, n_threads, nullptr , nullptr );
5482
5494
5483
5495
const int64_t t1 = ggml_time_us ();
5484
5496
0 commit comments