
whisper : try to fix the parallel whisper_state functionality #1479

Merged
merged 3 commits on Nov 12, 2023
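
The fix gives every whisper_state its own ggml backend (state->backend), created by the new whisper_backend_init() helper and used for graph computation in the encoder and decoder, so multiple states can run in parallel on top of a single whisper_context. Below is a minimal usage sketch of that scenario, assuming only the public whisper.h API (whisper_init_state, whisper_full_with_state, whisper_free_state); the model path and the PCM buffers are placeholders, not part of this PR.

#include "whisper.h"

#include <functional>
#include <thread>
#include <vector>

// Minimal sketch (not part of this PR): each thread gets its own whisper_state,
// both states share one whisper_context and therefore one copy of the model.
static void transcribe(whisper_context * ctx, const std::vector<float> & pcm) {
    whisper_state * state = whisper_init_state(ctx); // per-thread state, owns its own backend after this change
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    whisper_full_with_state(ctx, state, wparams, pcm.data(), (int) pcm.size()); // compute runs on state->backend

    whisper_free_state(state); // also frees state->backend
}

int main() {
    whisper_context_params cparams = whisper_context_default_params();
    whisper_context * ctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams); // placeholder model path

    // placeholder audio: 5 seconds of silence at 16 kHz per stream
    std::vector<float> pcm0(16000 * 5, 0.0f);
    std::vector<float> pcm1(16000 * 5, 0.0f);

    std::thread t0(transcribe, ctx, std::cref(pcm0));
    std::thread t1(transcribe, ctx, std::cref(pcm1));
    t0.join();
    t1.join();

    whisper_free(ctx);
    return 0;
}

Each thread owns its state, while the model weights and the context backend stay shared and read-only.
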
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@
 .DS_Store
 
 build/
+build-coreml/
 build-em/
 build-debug/
 build-release/

2 changes: 1 addition & 1 deletion ggml-metal.h
@@ -26,7 +26,7 @@
 #include <stdbool.h>
 
 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;

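A plausible reason for the higher limit, inferred from this PR rather than stated in it: with a backend per whisper_state, every state maps its own compute and KV-cache buffers to the same Metal device, so a single process can exceed the old cap of 16 mapped buffers.
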
4 changes: 4 additions & 0 deletions ggml-metal.m
@@ -479,6 +479,10 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
 
     const int64_t tsize = ggml_nbytes(t);
 
+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

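With this change the Metal buffer lookup prefers the ggml_metal_context attached to the tensor's own backend buffer when one is available, so tensors allocated through a state's backend are resolved against the buffers that state actually mapped, rather than only those of the context passed in.
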
88 changes: 49 additions & 39 deletions whisper.cpp
@@ -649,7 +649,6 @@ static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
 static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
     auto & alloc = allocr.alloc;
     auto & meta = allocr.meta;
-    auto & buffer = allocr.buffer;
 
     alloc = ggml_allocr_new_measure_from_backend(backend);
 
@@ -659,6 +658,11 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
 }
 
 static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
     auto & alloc = allocr.alloc;
     auto & buffer = allocr.buffer;
 
@@ -702,6 +706,8 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
+    ggml_backend_t backend = nullptr;
+
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -881,6 +887,37 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
     }
 }
 
+static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef GGML_USE_CUBLAS
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+        backend_gpu = ggml_backend_cuda_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_METAL
+    if (params.use_gpu) {
+        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+        ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+        backend_gpu = ggml_backend_metal_init();
+        if (!backend_gpu) {
+            WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+        }
+    }
+#endif
+
+    if (backend_gpu) {
+        return backend_gpu;
+    }
+    return ggml_backend_cpu_init();
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -1299,38 +1336,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
         }
     }
 
-    // init backends
-    {
-        ggml_backend_t backend_gpu = NULL;
-
-        // initialize the backends
-#ifdef GGML_USE_CUBLAS
-        if (wctx.params.use_gpu) {
-            WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-            backend_gpu = ggml_backend_cuda_init();
-            if (!backend_gpu) {
-                WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
-            }
-        }
-#endif
-
-#ifdef GGML_USE_METAL
-        if (wctx.params.use_gpu) {
-            WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
-            ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
-            backend_gpu = ggml_backend_metal_init();
-            if (!backend_gpu) {
-                WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-            }
-        }
-#endif
-
-        if (backend_gpu) {
-            wctx.backend = backend_gpu;
-        } else {
-            wctx.backend = ggml_backend_cpu_init();
-        }
-    }
+    wctx.backend = whisper_backend_init(wctx.params);
 
     {
         size_t size_main = 0;
@@ -1964,7 +1970,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }

@@ -1978,7 +1984,7 @@
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // cross
@@ -1991,7 +1997,7 @@
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
@@ -2382,7 +2388,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // extract logits for all N tokens
@@ -2825,6 +2831,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
+    state->backend = whisper_backend_init(ctx->params);
+
     if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
@@ -2922,9 +2930,9 @@
         WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
     }
 
     whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
     state->rng = std::mt19937(0);
@@ -3178,6 +3186,8 @@ void whisper_free_state(struct whisper_state * state)
         whisper_allocr_free(state->alloc_cross);
         whisper_allocr_free(state->alloc_decode);
 
+        ggml_backend_free(state->backend);
+
         delete state;
     }
 }