Skip to content

Commit 5d38768

Browse files
committed
vad : add h_state and c_state to whisper_vad_state
This commit adds two new members, h_state and c_state, to the whisper_vad_state structure. These members are used to store the hidden and cell states to avoid having to get and set the LSTM states in the processing. Refs: ggml-org#3065 (comment)
1 parent 26d7435 commit 5d38768

File tree

1 file changed

+40
-46
lines changed

1 file changed

+40
-46
lines changed

src/whisper.cpp

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4418,6 +4418,9 @@ struct whisper_vad_model {
44184418
struct whisper_vad_state {
44194419
std::vector<ggml_backend_t> backends;
44204420

4421+
struct ggml_tensor * h_state;
4422+
struct ggml_tensor * c_state;
4423+
44214424
whisper_sched sched;
44224425
};
44234426

@@ -4588,22 +4591,12 @@ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
45884591

45894592
struct ggml_tensor * x_t = ggml_transpose(ctx0, cur);
45904593

4591-
// Hidden state from previous time step.
4592-
struct ggml_tensor * h_in = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hdim);
4593-
ggml_set_name(h_in, "h_in");
4594-
ggml_set_input(h_in);
4595-
4596-
// Cell state from all previous time steps.
4597-
struct ggml_tensor * c_in = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hdim);
4598-
ggml_set_name(c_in, "c_in");
4599-
ggml_set_input(c_in);
4600-
46014594
// Create operations using the input-to-hidden weights.
46024595
struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
46034596
inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias);
46044597

46054598
// Create operations using the hidden-to-hidden weights.
4606-
struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, h_in);
4599+
struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.state->h_state);
46074600
hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias);
46084601

46094602
// Create add operation to get preactivations for all gates.
@@ -4624,26 +4617,22 @@ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
46244617

46254618
// Update cell state
46264619
struct ggml_tensor * c_out = ggml_add(ctx0,
4627-
ggml_mul(ctx0, f_t, c_in),
4620+
ggml_mul(ctx0, f_t, vctx.state->c_state),
46284621
ggml_mul(ctx0, i_t, g_t));
4629-
ggml_set_output(c_out);
4630-
ggml_set_name(c_out, "c_out");
4631-
ggml_build_forward_expand(gf, c_out);
4622+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.state->c_state));
46324623

46334624
// Update hidden state
46344625
struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
4635-
ggml_set_output(out);
4636-
ggml_set_name(out, "h_out");
4626+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.state->h_state));
46374627
return out;
46384628
}
46394629

4640-
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
4641-
whisper_vad_state & vstate) {
4630+
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
46424631
const auto & model = vctx.model;
46434632

46444633
struct ggml_init_params params = {
4645-
/*.mem_size =*/ vstate.sched.meta.size(),
4646-
/*.mem_buffer =*/ vstate.sched.meta.data(),
4634+
/*.mem_size =*/ vctx.state->sched.meta.size(),
4635+
/*.mem_buffer =*/ vctx.state->sched.meta.data(),
46474636
/*.no_alloc =*/ true,
46484637
};
46494638

@@ -4681,23 +4670,44 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
46814670
return gf;
46824671
}
46834672

4684-
struct whisper_vad_state * whisper_vad_init_state(whisper_vad_context * ctx) {
4673+
struct whisper_vad_state * whisper_vad_init_state(whisper_vad_context * vctx) {
46854674
whisper_vad_state * state = new whisper_vad_state;
4675+
vctx->state = state;
46864676

46874677
auto whisper_context_params = whisper_context_default_params();
4688-
whisper_context_params.use_gpu = ctx->params.use_gpu;
4689-
whisper_context_params.gpu_device = ctx->params.gpu_device;
4678+
whisper_context_params.use_gpu = vctx->params.use_gpu;
4679+
whisper_context_params.gpu_device = vctx->params.gpu_device;
46904680
state->backends = whisper_backend_init(whisper_context_params);
46914681
if (state->backends.empty()) {
46924682
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
46934683
whisper_vad_free_state(state);
46944684
return nullptr;
46954685
}
46964686

4687+
int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;
4688+
struct ggml_init_params params = {
4689+
/*.mem_size =*/ size_t(2u*lstm_hidden_size*ggml_tensor_overhead()),
4690+
/*.mem_buffer =*/ NULL,
4691+
/*.no_alloc =*/ true,
4692+
};
4693+
ggml_context * ctx = ggml_init(params);
4694+
if (!ctx) {
4695+
WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
4696+
return nullptr;
4697+
}
4698+
4699+
// LSTM Hidden state
4700+
state->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4701+
ggml_set_name(state->h_state, "h_state");
4702+
4703+
// LSTM Cell state
4704+
state->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
4705+
ggml_set_name(state->c_state, "c_state");
4706+
46974707
{
46984708
bool ok = whisper_sched_graph_init(state->sched, state->backends,
46994709
[&]() {
4700-
return whisper_vad_build_graph(*ctx, *state);
4710+
return whisper_vad_build_graph(*vctx);
47014711
});
47024712

47034713
if (!ok) {
@@ -4719,7 +4729,7 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
47194729
return nullptr;
47204730
}
47214731

4722-
ctx->state = whisper_vad_init_state(ctx);
4732+
whisper_vad_init_state(ctx);
47234733
if (!ctx->state) {
47244734
whisper_vad_free(ctx);
47254735
return nullptr;
@@ -5092,7 +5102,6 @@ struct whisper_vad_context * whisper_vad_init_with_params_no_state(struct whispe
50925102
struct whisper_vad_speech whisper_vad_detect_speech(struct whisper_vad_context * vctx,
50935103
const float * pcmf32,
50945104
int n_samples) {
5095-
const int hidden_dim = vctx->model.hparams.lstm_hidden_size;
50965105
int n_chunks = n_samples / vctx->n_window;
50975106
if (n_samples % vctx->n_window != 0) {
50985107
n_chunks += 1; // Add one more chunk for remaining samples.
@@ -5102,28 +5111,20 @@ struct whisper_vad_speech whisper_vad_detect_speech(struct whisper_vad_context *
51025111
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
51035112
WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
51045113

5105-
ggml_cgraph * gf = whisper_vad_build_graph(*vctx, *vctx->state);
5114+
ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
51065115

51075116
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
51085117
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
51095118
return {};
51105119
}
51115120

51125121
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
5113-
struct ggml_tensor * c_out = ggml_graph_get_tensor(gf, "c_out");
5114-
struct ggml_tensor * h_out = ggml_graph_get_tensor(gf, "h_out");
5115-
struct ggml_tensor * c_in = ggml_graph_get_tensor(gf, "c_in");
5116-
struct ggml_tensor * h_in = ggml_graph_get_tensor(gf, "h_in");
51175122
struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
5118-
5119-
ggml_set_zero(c_out);
5120-
ggml_set_zero(h_out);
51215123
ggml_set_zero(prob);
5122-
ggml_set_zero(c_in);
5123-
ggml_set_zero(h_in);
51245124

5125-
std::vector<float> h_state(hidden_dim, 0.0f);
5126-
std::vector<float> c_state(hidden_dim, 0.0f);
5125+
// Reset LSTM hidden/cell states
5126+
ggml_set_zero(vctx->state->h_state);
5127+
ggml_set_zero(vctx->state->c_state);
51275128

51285129
float * probs= new float[n_chunks];
51295130
WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
@@ -5156,18 +5157,11 @@ struct whisper_vad_speech whisper_vad_detect_speech(struct whisper_vad_context *
51565157
// Set the frame tensor data with the samples.
51575158
ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
51585159

5159-
ggml_backend_tensor_set(h_in, h_state.data(), 0, hidden_dim * sizeof(float));
5160-
ggml_backend_tensor_set(c_in, c_state.data(), 0, hidden_dim * sizeof(float));
5161-
51625160
if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads)) {
51635161
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
51645162
break;
51655163
}
51665164

5167-
// Update the LSTM states
5168-
ggml_backend_tensor_get(h_out, h_state.data(), 0, hidden_dim * sizeof(float));
5169-
ggml_backend_tensor_get(c_out, c_state.data(), 0, hidden_dim * sizeof(float));
5170-
51715165
// Get the probability for this chunk.
51725166
ggml_backend_tensor_get(prob, &probs[i], 0, sizeof(float));
51735167

0 commit comments

Comments
 (0)