@@ -4418,6 +4418,9 @@ struct whisper_vad_model {
4418
4418
// Runtime state attached to a VAD context (reached as vctx->state).
//
//   backends - backend handles; filled from whisper_backend_init() in
//              whisper_vad_init_state().
//   h_state  - LSTM hidden state: 1-D F32 tensor with lstm_hidden_size
//              elements. The LSTM graph reads it as the input of the
//              hidden-to-hidden matmul and copies the new hidden output
//              back into it with ggml_cpy, so it carries over between
//              graph evaluations.
//   c_state  - LSTM cell state, same type and size as h_state; read when
//              computing the new cell state and overwritten with it via
//              ggml_cpy.
//   sched    - scheduler state; its meta buffer backs the ggml context used
//              when building the VAD graph.
//
// whisper_vad_detect_speech() zeroes h_state and c_state before processing.
//
// NOTE(review): in the visible code h_state/c_state are created in a ggml
// context with no_alloc = true; where their backing buffer is allocated is
// not visible here - confirm before relying on ggml_set_zero() on them.
struct whisper_vad_state {
4419
4419
std::vector<ggml_backend_t > backends;
4420
4420
4421
+ struct ggml_tensor * h_state;
4422
+ struct ggml_tensor * c_state;
4423
+
4421
4424
whisper_sched sched;
4422
4425
};
4423
4426
@@ -4588,22 +4591,12 @@ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
4588
4591
4589
4592
struct ggml_tensor * x_t = ggml_transpose (ctx0, cur);
4590
4593
4591
- // Hidden state from previous time step.
4592
- struct ggml_tensor * h_in = ggml_new_tensor_1d (ctx0, GGML_TYPE_F32, hdim);
4593
- ggml_set_name (h_in, " h_in" );
4594
- ggml_set_input (h_in);
4595
-
4596
- // Cell state from all previous time steps.
4597
- struct ggml_tensor * c_in = ggml_new_tensor_1d (ctx0, GGML_TYPE_F32, hdim);
4598
- ggml_set_name (c_in, " c_in" );
4599
- ggml_set_input (c_in);
4600
-
4601
4594
// Create operations using the input-to-hidden weights.
4602
4595
struct ggml_tensor * inp_gate = ggml_mul_mat (ctx0, model.lstm_ih_weight , x_t );
4603
4596
inp_gate = ggml_add (ctx0, inp_gate, model.lstm_ih_bias );
4604
4597
4605
4598
// Create operations using the hidden-to-hidden weights.
4606
- struct ggml_tensor * hid_gate = ggml_mul_mat (ctx0, model.lstm_hh_weight , h_in );
4599
+ struct ggml_tensor * hid_gate = ggml_mul_mat (ctx0, model.lstm_hh_weight , vctx. state -> h_state );
4607
4600
hid_gate = ggml_add (ctx0, hid_gate, model.lstm_hh_bias );
4608
4601
4609
4602
// Create add operation to get preactivations for all gates.
@@ -4624,26 +4617,22 @@ static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
4624
4617
4625
4618
// Update cell state
4626
4619
struct ggml_tensor * c_out = ggml_add (ctx0,
4627
- ggml_mul (ctx0, f_t , c_in ),
4620
+ ggml_mul (ctx0, f_t , vctx. state -> c_state ),
4628
4621
ggml_mul (ctx0, i_t , g_t ));
4629
- ggml_set_output (c_out);
4630
- ggml_set_name (c_out, " c_out" );
4631
- ggml_build_forward_expand (gf, c_out);
4622
+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, c_out, vctx.state ->c_state ));
4632
4623
4633
4624
// Update hidden state
4634
4625
struct ggml_tensor * out = ggml_mul (ctx0, o_t , ggml_tanh (ctx0, c_out));
4635
- ggml_set_output (out);
4636
- ggml_set_name (out, " h_out" );
4626
+ ggml_build_forward_expand (gf, ggml_cpy (ctx0, out, vctx.state ->h_state ));
4637
4627
return out;
4638
4628
}
4639
4629
4640
- static struct ggml_cgraph * whisper_vad_build_graph (whisper_vad_context & vctx,
4641
- whisper_vad_state & vstate) {
4630
+ static struct ggml_cgraph * whisper_vad_build_graph (whisper_vad_context & vctx) {
4642
4631
const auto & model = vctx.model ;
4643
4632
4644
4633
struct ggml_init_params params = {
4645
- /* .mem_size =*/ vstate. sched .meta .size (),
4646
- /* .mem_buffer =*/ vstate. sched .meta .data (),
4634
+ /* .mem_size =*/ vctx. state -> sched .meta .size (),
4635
+ /* .mem_buffer =*/ vctx. state -> sched .meta .data (),
4647
4636
/* .no_alloc =*/ true ,
4648
4637
};
4649
4638
@@ -4681,23 +4670,44 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
4681
4670
return gf;
4682
4671
}
4683
4672
4684
- struct whisper_vad_state * whisper_vad_init_state (whisper_vad_context * ctx ) {
4673
+ struct whisper_vad_state * whisper_vad_init_state (whisper_vad_context * vctx ) {
4685
4674
whisper_vad_state * state = new whisper_vad_state;
4675
+ vctx->state = state;
4686
4676
4687
4677
auto whisper_context_params = whisper_context_default_params ();
4688
- whisper_context_params.use_gpu = ctx ->params .use_gpu ;
4689
- whisper_context_params.gpu_device = ctx ->params .gpu_device ;
4678
+ whisper_context_params.use_gpu = vctx ->params .use_gpu ;
4679
+ whisper_context_params.gpu_device = vctx ->params .gpu_device ;
4690
4680
state->backends = whisper_backend_init (whisper_context_params);
4691
4681
if (state->backends .empty ()) {
4692
4682
WHISPER_LOG_ERROR (" %s: whisper_backend_init() failed\n " , __func__);
4693
4683
whisper_vad_free_state (state);
4694
4684
return nullptr ;
4695
4685
}
4696
4686
4687
+ int32_t lstm_hidden_size = vctx->model .hparams .lstm_hidden_size ;
4688
+ struct ggml_init_params params = {
4689
+ /* .mem_size =*/ size_t (2u *lstm_hidden_size*ggml_tensor_overhead ()),
4690
+ /* .mem_buffer =*/ NULL ,
4691
+ /* .no_alloc =*/ true ,
4692
+ };
4693
+ ggml_context * ctx = ggml_init (params);
4694
+ if (!ctx) {
4695
+ WHISPER_LOG_ERROR (" %s: failed to init LSTM state ggml context\n " , __func__);
4696
+ return nullptr ;
4697
+ }
4698
+
4699
+ // LSTM Hidden state
4700
+ state->h_state = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, lstm_hidden_size);
4701
+ ggml_set_name (state->h_state , " h_state" );
4702
+
4703
+ // LSTM Cell state
4704
+ state->c_state = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, lstm_hidden_size);
4705
+ ggml_set_name (state->c_state , " c_state" );
4706
+
4697
4707
{
4698
4708
bool ok = whisper_sched_graph_init (state->sched , state->backends ,
4699
4709
[&]() {
4700
- return whisper_vad_build_graph (*ctx, *state );
4710
+ return whisper_vad_build_graph (*vctx );
4701
4711
});
4702
4712
4703
4713
if (!ok) {
@@ -4719,7 +4729,7 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params(
4719
4729
return nullptr ;
4720
4730
}
4721
4731
4722
- ctx-> state = whisper_vad_init_state (ctx);
4732
+ whisper_vad_init_state (ctx);
4723
4733
if (!ctx->state ) {
4724
4734
whisper_vad_free (ctx);
4725
4735
return nullptr ;
@@ -5092,7 +5102,6 @@ struct whisper_vad_context * whisper_vad_init_with_params_no_state(struct whispe
5092
5102
struct whisper_vad_speech whisper_vad_detect_speech (struct whisper_vad_context * vctx,
5093
5103
const float * pcmf32,
5094
5104
int n_samples) {
5095
- const int hidden_dim = vctx->model .hparams .lstm_hidden_size ;
5096
5105
int n_chunks = n_samples / vctx->n_window ;
5097
5106
if (n_samples % vctx->n_window != 0 ) {
5098
5107
n_chunks += 1 ; // Add one more chunk for remaining samples.
@@ -5102,28 +5111,20 @@ struct whisper_vad_speech whisper_vad_detect_speech(struct whisper_vad_context *
5102
5111
WHISPER_LOG_INFO (" %s: detecting speech in %d samples\n " , __func__, n_samples);
5103
5112
WHISPER_LOG_INFO (" %s: n_chunks: %d\n " , __func__, n_chunks);
5104
5113
5105
- ggml_cgraph * gf = whisper_vad_build_graph (*vctx, *vctx-> state );
5114
+ ggml_cgraph * gf = whisper_vad_build_graph (*vctx);
5106
5115
5107
5116
if (!ggml_backend_sched_alloc_graph (sched, gf)) {
5108
5117
WHISPER_LOG_ERROR (" %s: failed to allocate the compute buffer\n " , __func__);
5109
5118
return {};
5110
5119
}
5111
5120
5112
5121
struct ggml_tensor * frame = ggml_graph_get_tensor (gf, " frame" );
5113
- struct ggml_tensor * c_out = ggml_graph_get_tensor (gf, " c_out" );
5114
- struct ggml_tensor * h_out = ggml_graph_get_tensor (gf, " h_out" );
5115
- struct ggml_tensor * c_in = ggml_graph_get_tensor (gf, " c_in" );
5116
- struct ggml_tensor * h_in = ggml_graph_get_tensor (gf, " h_in" );
5117
5122
struct ggml_tensor * prob = ggml_graph_get_tensor (gf, " prob" );
5118
-
5119
- ggml_set_zero (c_out);
5120
- ggml_set_zero (h_out);
5121
5123
ggml_set_zero (prob);
5122
- ggml_set_zero (c_in);
5123
- ggml_set_zero (h_in);
5124
5124
5125
- std::vector<float > h_state (hidden_dim, 0 .0f );
5126
- std::vector<float > c_state (hidden_dim, 0 .0f );
5125
+ // Reset LSTM hidden/cell states
5126
+ ggml_set_zero (vctx->state ->h_state );
5127
+ ggml_set_zero (vctx->state ->c_state );
5127
5128
5128
5129
float * probs= new float [n_chunks];
5129
5130
WHISPER_LOG_INFO (" %s: props size: %u\n " , __func__, n_chunks);
@@ -5156,18 +5157,11 @@ struct whisper_vad_speech whisper_vad_detect_speech(struct whisper_vad_context *
5156
5157
// Set the frame tensor data with the samples.
5157
5158
ggml_backend_tensor_set (frame, window.data (), 0 , ggml_nelements (frame) * sizeof (float ));
5158
5159
5159
- ggml_backend_tensor_set (h_in, h_state.data (), 0 , hidden_dim * sizeof (float ));
5160
- ggml_backend_tensor_set (c_in, c_state.data (), 0 , hidden_dim * sizeof (float ));
5161
-
5162
5160
if (!ggml_graph_compute_helper (sched, gf, vctx->n_threads )) {
5163
5161
WHISPER_LOG_ERROR (" %s: failed to compute VAD graph\n " , __func__);
5164
5162
break ;
5165
5163
}
5166
5164
5167
- // Update the LSTM states
5168
- ggml_backend_tensor_get (h_out, h_state.data (), 0 , hidden_dim * sizeof (float ));
5169
- ggml_backend_tensor_get (c_out, c_state.data (), 0 , hidden_dim * sizeof (float ));
5170
-
5171
5165
// Get the probability for this chunk.
5172
5166
ggml_backend_tensor_get (prob, &probs[i], 0 , sizeof (float ));
5173
5167
0 commit comments