Skip to content

Commit 4d684b3

Browse files
ggerganovrock3125
authored andcommitted
whisper : reduce memory usage during inference (ggml-org#431)
* ggml : add "scratch" buffer support * ggml : support for scratch ring-buffer * ggml : bug fix in ggml_repeat() * ggml : error on scratch buffer overflow * whisper : use scratch buffers during inference (base model only) * whisper : update memory usage for all models * whisper : fix encoder memory usage * whisper : use whisper_context functions instead of macros * whisper : fix FF + remove it from README * ggml : reuse ggml_new_i32 * ggml : refactor the scratch buffer storage * whisper : reorder scratch buffers in the decoder * main : add option to disable temp fallback * Update README.md
1 parent 5a33c04 commit 4d684b3

File tree

7 files changed

+649
-419
lines changed

7 files changed

+649
-419
lines changed

README.md

Lines changed: 109 additions & 98 deletions
Large diffs are not rendered by default.

bindings/javascript/whisper.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/main/README.md

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,35 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
99
usage: ./main [options] file0.wav file1.wav ...
1010
1111
options:
12-
-h, --help [default] show this help message and exit
13-
-t N, --threads N [4 ] number of threads to use during computation
14-
-p N, --processors N [1 ] number of processors to use during computation
15-
-ot N, --offset-t N [0 ] time offset in milliseconds
16-
-on N, --offset-n N [0 ] segment index offset
17-
-d N, --duration N [0 ] duration of audio to process in milliseconds
18-
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
19-
-ml N, --max-len N [0 ] maximum segment length in characters
20-
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
21-
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
22-
-tr, --translate [false ] translate from source language to english
23-
-otxt, --output-txt [false ] output result in a text file
24-
-ovtt, --output-vtt [false ] output result in a vtt file
25-
-osrt, --output-srt [false ] output result in a srt file
26-
-owts, --output-words [false ] output script for generating karaoke video
27-
-ps, --print-special [false ] print special tokens
28-
-pc, --print-colors [false ] print colors
29-
-nt, --no-timestamps [true ] do not print timestamps
30-
-l LANG, --language LANG [en ] spoken language
31-
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
32-
-f FNAME, --file FNAME [ ] input WAV file path
12+
-h, --help [default] show this help message and exit
13+
-t N, --threads N [4 ] number of threads to use during computation
14+
-p N, --processors N [1 ] number of processors to use during computation
15+
-ot N, --offset-t N [0 ] time offset in milliseconds
16+
-on N, --offset-n N [0 ] segment index offset
17+
-d N, --duration N [0 ] duration of audio to process in milliseconds
18+
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
19+
-ml N, --max-len N [0 ] maximum segment length in characters
20+
-bo N, --best-of N [5 ] number of best candidates to keep
21+
-bs N, --beam-size N [-1 ] beam size for beam search
22+
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
23+
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
24+
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
25+
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
26+
-tr, --translate [false ] translate from source language to english
27+
-di, --diarize [false ] stereo audio diarization
28+
-nf, --no-fallback [false ] do not use temperature fallback while decoding
29+
-otxt, --output-txt [false ] output result in a text file
30+
-ovtt, --output-vtt [false ] output result in a vtt file
31+
-osrt, --output-srt [false ] output result in a srt file
32+
-owts, --output-words [false ] output script for generating karaoke video
33+
-ocsv, --output-csv [false ] output result in a CSV file
34+
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
35+
-ps, --print-special [false ] print special tokens
36+
-pc, --print-colors [false ] print colors
37+
-pp, --print-progress [false ] print progress
38+
-nt, --no-timestamps [true ] do not print timestamps
39+
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
40+
--prompt PROMPT [ ] initial prompt
41+
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
42+
-f FNAME, --file FNAME [ ] input WAV file path
3343
```

examples/main/main.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,23 @@ void replace_all(std::string & s, const std::string & search, const std::string
5353
// command-line parameters
5454
struct whisper_params {
5555
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
56-
int32_t n_processors = 1;
57-
int32_t offset_t_ms = 0;
58-
int32_t offset_n = 0;
59-
int32_t duration_ms = 0;
56+
int32_t n_processors = 1;
57+
int32_t offset_t_ms = 0;
58+
int32_t offset_n = 0;
59+
int32_t duration_ms = 0;
6060
int32_t max_context = -1;
61-
int32_t max_len = 0;
62-
int32_t best_of = 5;
61+
int32_t max_len = 0;
62+
int32_t best_of = 5;
6363
int32_t beam_size = -1;
6464

65-
float word_thold = 0.01f;
66-
float entropy_thold = 2.4f;
67-
float logprob_thold = -1.0f;
65+
float word_thold = 0.01f;
66+
float entropy_thold = 2.40f;
67+
float logprob_thold = -1.00f;
6868

6969
bool speed_up = false;
7070
bool translate = false;
7171
bool diarize = false;
72+
bool no_fallback = false;
7273
bool output_txt = false;
7374
bool output_vtt = false;
7475
bool output_srt = false;
@@ -117,6 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
117118
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
118119
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
119120
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
121+
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
120122
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
121123
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
122124
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
@@ -162,6 +164,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
162164
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
163165
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
164166
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
167+
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
165168
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
166169
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
167170
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
@@ -514,7 +517,7 @@ int main(int argc, char ** argv) {
514517

515518
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
516519
const auto fname_inp = params.fname_inp[f];
517-
const auto fname_outp = f < params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
520+
const auto fname_outp = f < (int) params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
518521

519522
std::vector<float> pcmf32; // mono-channel F32 PCM
520523
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@@ -647,17 +650,19 @@ int main(int argc, char ** argv) {
647650

648651
wparams.token_timestamps = params.output_wts || params.max_len > 0;
649652
wparams.thold_pt = params.word_thold;
650-
wparams.entropy_thold = params.entropy_thold;
651-
wparams.logprob_thold = params.logprob_thold;
652653
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
653654

654655
wparams.speed_up = params.speed_up;
655656

657+
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
658+
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
659+
656660
wparams.greedy.best_of = params.best_of;
657661
wparams.beam_search.beam_size = params.beam_size;
658662

659-
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
660-
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
663+
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
664+
wparams.entropy_thold = params.entropy_thold;
665+
wparams.logprob_thold = params.logprob_thold;
661666

662667
whisper_print_user_data user_data = { &params, &pcmf32s };
663668

ggml.c

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,7 +1258,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
12581258
//
12591259

12601260
struct ggml_object {
1261-
size_t offset;
1261+
size_t offs;
12621262
size_t size;
12631263

12641264
struct ggml_object * next;
@@ -1284,6 +1284,9 @@ struct ggml_context {
12841284

12851285
struct ggml_object * objects_begin;
12861286
struct ggml_object * objects_end;
1287+
1288+
struct ggml_scratch scratch;
1289+
struct ggml_scratch scratch_save;
12871290
};
12881291

12891292
struct ggml_context_container {
@@ -1346,7 +1349,7 @@ inline static void ggml_critical_section_end(void) {
13461349

13471350
void ggml_print_object(const struct ggml_object * obj) {
13481351
GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
1349-
obj->offset, obj->size, (const void *) obj->next);
1352+
obj->offs, obj->size, (const void *) obj->next);
13501353
}
13511354

13521355
void ggml_print_objects(const struct ggml_context * ctx) {
@@ -1542,12 +1545,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
15421545
}
15431546

15441547
*ctx = (struct ggml_context) {
1545-
.mem_size = params.mem_size,
1546-
.mem_buffer = params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
1547-
.mem_buffer_owned = params.mem_buffer ? false : true,
1548-
.n_objects = 0,
1549-
.objects_begin = NULL,
1550-
.objects_end = NULL,
1548+
/*.mem_size =*/ params.mem_size,
1549+
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
1550+
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
1551+
/*.n_objects =*/ 0,
1552+
/*.objects_begin =*/ NULL,
1553+
/*.objects_end =*/ NULL,
1554+
/*.scratch =*/ { 0, 0, NULL, },
1555+
/*.scratch_save =*/ { 0, 0, NULL, },
15511556
};
15521557

15531558
ggml_assert_aligned(ctx->mem_buffer);
@@ -1570,7 +1575,7 @@ void ggml_free(struct ggml_context * ctx) {
15701575
g_state.contexts[i].used = false;
15711576

15721577
GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
1573-
__func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size);
1578+
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
15741579

15751580
if (ctx->mem_buffer_owned) {
15761581
free(ctx->mem_buffer);
@@ -1589,7 +1594,15 @@ void ggml_free(struct ggml_context * ctx) {
15891594
}
15901595

15911596
size_t ggml_used_mem(const struct ggml_context * ctx) {
1592-
return ctx->objects_end->offset + ctx->objects_end->size;
1597+
return ctx->objects_end->offs + ctx->objects_end->size;
1598+
}
1599+
1600+
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
1601+
const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
1602+
1603+
ctx->scratch = scratch;
1604+
1605+
return result;
15931606
}
15941607

15951608
////////////////////////////////////////////////////////////////////////////////
@@ -1603,9 +1616,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
16031616
// always insert objects at the end of the context's memory pool
16041617
struct ggml_object * obj_cur = ctx->objects_end;
16051618

1606-
const size_t cur_offset = obj_cur == NULL ? 0 : obj_cur->offset;
1607-
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
1608-
const size_t cur_end = cur_offset + cur_size;
1619+
const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
1620+
const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
1621+
const size_t cur_end = cur_offs + cur_size;
16091622

16101623
size_t size_needed = 0;
16111624

@@ -1616,25 +1629,52 @@ struct ggml_tensor * ggml_new_tensor_impl(
16161629
}
16171630
// align to GGML_MEM_ALIGN
16181631
size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
1619-
1620-
}
1621-
size_needed += sizeof(struct ggml_tensor);
1622-
1623-
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
1624-
GGML_PRINT("%s: not enough space in the context's memory pool\n", __func__);
1625-
assert(false);
1626-
return NULL;
16271632
}
16281633

16291634
char * const mem_buffer = ctx->mem_buffer;
1630-
16311635
struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
16321636

1633-
*obj_new = (struct ggml_object) {
1634-
.offset = cur_end + GGML_OBJECT_SIZE,
1635-
.size = size_needed,
1636-
.next = NULL,
1637-
};
1637+
if (ctx->scratch.data == NULL || data != NULL) {
1638+
size_needed += sizeof(struct ggml_tensor);
1639+
1640+
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
1641+
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
1642+
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
1643+
assert(false);
1644+
return NULL;
1645+
}
1646+
1647+
*obj_new = (struct ggml_object) {
1648+
.offs = cur_end + GGML_OBJECT_SIZE,
1649+
.size = size_needed,
1650+
.next = NULL,
1651+
};
1652+
} else {
1653+
if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
1654+
GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
1655+
assert(false);
1656+
return NULL;
1657+
}
1658+
1659+
if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
1660+
GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
1661+
__func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
1662+
assert(false);
1663+
return NULL;
1664+
}
1665+
1666+
data = (char * const) ctx->scratch.data + ctx->scratch.offs;
1667+
1668+
*obj_new = (struct ggml_object) {
1669+
.offs = cur_end + GGML_OBJECT_SIZE,
1670+
.size = sizeof(struct ggml_tensor),
1671+
.next = NULL,
1672+
};
1673+
1674+
//printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
1675+
1676+
ctx->scratch.offs += size_needed;
1677+
}
16381678

16391679
if (obj_cur != NULL) {
16401680
obj_cur->next = obj_new;
@@ -1645,9 +1685,9 @@ struct ggml_tensor * ggml_new_tensor_impl(
16451685

16461686
ctx->objects_end = obj_new;
16471687

1648-
//GGML_PRINT_DEBUG("%s: inserted new object at %zu\n", __func__, cur_end);
1688+
//printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
16491689

1650-
struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offset);
1690+
struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
16511691

16521692
ggml_assert_aligned(result);
16531693

@@ -1690,7 +1730,7 @@ struct ggml_tensor * ggml_new_tensor(
16901730
struct ggml_context * ctx,
16911731
enum ggml_type type,
16921732
int n_dims,
1693-
const int* ne) {
1733+
const int * ne) {
16941734
return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
16951735
}
16961736

@@ -1732,16 +1772,26 @@ struct ggml_tensor * ggml_new_tensor_4d(
17321772
}
17331773

17341774
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
1775+
ctx->scratch_save = ctx->scratch;
1776+
ctx->scratch.data = NULL;
1777+
17351778
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
17361779

1780+
ctx->scratch = ctx->scratch_save;
1781+
17371782
ggml_set_i32(result, value);
17381783

17391784
return result;
17401785
}
17411786

17421787
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
1788+
ctx->scratch_save = ctx->scratch;
1789+
ctx->scratch.data = NULL;
1790+
17431791
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
17441792

1793+
ctx->scratch = ctx->scratch_save;
1794+
17451795
ggml_set_f32(result, value);
17461796

17471797
return result;
@@ -2350,7 +2400,7 @@ struct ggml_tensor * ggml_repeat(
23502400
result->op = GGML_OP_REPEAT;
23512401
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
23522402
result->src0 = a;
2353-
result->src1 = NULL;
2403+
result->src1 = b;
23542404

23552405
return result;
23562406
}
@@ -2966,9 +3016,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
29663016
// TODO: when implement backward, fix this:
29673017
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
29683018
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
2969-
2970-
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
2971-
((int32_t *) b->data)[0] = n_past;
3019+
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
29723020

29733021
result->op = GGML_OP_DIAG_MASK_INF;
29743022
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4300,7 +4348,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
43004348
const int ne1 = dst->ne[1];
43014349

43024350
// TODO: find the optimal values for these
4303-
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) {
4351+
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && (
4352+
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)
4353+
)) {
43044354
//printf("BLAS: %d %d %d\n", ne0, ne1, ne10);
43054355
return true;
43064356
}
@@ -7289,6 +7339,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
72897339
node->n_tasks = 1; // TODO: this actually is doing nothing
72907340
// the threads are still spinning
72917341
cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]);
7342+
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
7343+
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
7344+
//printf("cur = %zu\n", cur);
72927345
} else {
72937346
cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1);
72947347
}

0 commit comments

Comments
 (0)