@@ -603,6 +603,8 @@ struct whisper_context {
603603 // [EXPERIMENTAL] speed-up techniques
604604 int32_t exp_n_audio_ctx; // 0 - use default
605605
606+ std::vector<float > audio_embd;
607+
606608 void use_buf (struct ggml_context * ctx, int i) {
607609#if defined(WHISPER_USE_SCRATCH)
608610 size_t last_size = 0 ;
@@ -1707,18 +1709,34 @@ static bool whisper_encode(
17071709 }
17081710
17091711 // cur
1710- // {
1711- // printf("ne0 = %d\n", cur->ne[0]);
1712- // printf("ne1 = %d\n", cur->ne[1]);
1713- // for (int i = 0; i < 10; ++i) {
1714- // printf("%8.4f ", ((float *)(cur->data))[i]);
1715- // }
1716- // printf("... ");
1717- // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1718- // printf("%8.4f ", ((float *)(cur->data))[i]);
1719- // }
1720- // printf("\n");
1721- // }
1712+ {
1713+ // printf("ne0 = %d\n", cur->ne[0]);
1714+ // printf("ne1 = %d\n", cur->ne[1]);
1715+ // for (int i = 0; i < 10; ++i) {
1716+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1717+ // }
1718+ // printf("... ");
1719+ // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
1720+ // printf("%8.4f ", ((float *)(cur->data))[i]);
1721+ // }
1722+ // printf("\n");
1723+ }
1724+
1725+ {
1726+ const int i0 = std::min (mel_offset, mel_inp.n_len );
1727+ const int i1 = std::min (mel_offset + 2 *n_ctx, mel_inp.n_len );
1728+
1729+ printf (" i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n " , i0, i1, i1 - i0, cur->ne [0 ]);
1730+
1731+ wctx.audio_embd .clear ();
1732+ wctx.audio_embd .resize (cur->ne [0 ], 0 .0f );
1733+ for (int j = 0 ; j < cur->ne [0 ]; ++j) {
1734+ for (int i = i0; i < i1; ++i) {
1735+ wctx.audio_embd [j] += ((float *)(cur->data ))[(i - i0)*cur->ne [0 ] + j];
1736+ }
1737+ wctx.audio_embd [j] /= (i1 - i0);
1738+ }
1739+ }
17221740
17231741 // pre-compute cross-attention memory
17241742 {
@@ -4806,3 +4824,129 @@ static void whisper_exp_compute_token_level_timestamps(
48064824 // }
48074825 // }
48084826}
4827+
4828+ //
4829+ // diarization stuff
4830+ //
4831+
4832+ void whisper_full_cluster_segments (struct whisper_context * ctx) {
4833+ const int n_segments = ctx->result_all .size ();
4834+ printf (" %s: clustering %d segments\n " , __func__, n_segments);
4835+
4836+ const auto mel_len_save = ctx->mel .n_len ;
4837+ printf (" %s: mel_len_save = %d\n " , __func__, mel_len_save);
4838+
4839+ std::vector<std::vector<float >> features (n_segments);
4840+
4841+ for (int i = 0 ; i < n_segments; ++i) {
4842+ const auto & segment_i = ctx->result_all [i];
4843+ printf (" %s: segment %d: t0 = %d, t1 = %d, text = %s\n " , __func__, i, (int ) segment_i.t0 , (int ) segment_i.t1 , segment_i.text .c_str ());
4844+
4845+ ctx->mel .n_len = segment_i.t1 ;
4846+ whisper_encode (ctx, segment_i.t0 , 4 );
4847+
4848+ features[i] = ctx->audio_embd ;
4849+ }
4850+
4851+ const int n_features = features[0 ].size ();
4852+
4853+ // fuzzy c-means clustering
4854+ const int n_clusters = 4 ;
4855+
4856+ std::vector<std::vector<float >> centroids (n_clusters, std::vector<float >(n_features, 0.0 ));
4857+ std::vector<std::vector<float >> membership (n_segments, std::vector<float >(n_clusters, 0.0 ));
4858+
4859+ // initialize the centroids
4860+ for (int i = 0 ; i < n_clusters; ++i) {
4861+ for (int j = 0 ; j < n_features; ++j) {
4862+ centroids[i][j] = features[i][j];
4863+ }
4864+ }
4865+
4866+ // initialize the membership
4867+ for (int i = 0 ; i < n_segments; ++i) {
4868+ membership[i][i % n_clusters] = 1.0 ;
4869+ }
4870+
4871+ // iterate
4872+ for (int i = 0 ; i < 100 ; ++i) {
4873+ // update the centroids
4874+ for (int j = 0 ; j < n_clusters; ++j) {
4875+ for (int k = 0 ; k < n_features; ++k) {
4876+ centroids[j][k] = 0.0 ;
4877+ }
4878+ }
4879+
4880+ for (int j = 0 ; j < n_segments; ++j) {
4881+ for (int k = 0 ; k < n_clusters; ++k) {
4882+ for (int l = 0 ; l < n_features; ++l) {
4883+ centroids[k][l] += membership[j][k]*features[j][l];
4884+ }
4885+ }
4886+ }
4887+
4888+ for (int j = 0 ; j < n_clusters; ++j) {
4889+ float sum = 0.0 ;
4890+ for (int k = 0 ; k < n_segments; ++k) {
4891+ sum += membership[k][j];
4892+ }
4893+
4894+ for (int k = 0 ; k < n_features; ++k) {
4895+ centroids[j][k] /= sum;
4896+ }
4897+ }
4898+
4899+ // update the membership
4900+ for (int j = 0 ; j < n_segments; ++j) {
4901+ for (int k = 0 ; k < n_clusters; ++k) {
4902+ float sum = 0.0 ;
4903+ for (int l = 0 ; l < n_clusters; ++l) {
4904+ // sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
4905+
4906+ // use the euclidean distance
4907+ double d0 = 0.0 ;
4908+ for (int m = 0 ; m < n_features; ++m) {
4909+ d0 += std::pow (features[j][m] - centroids[k][m], 2.0 );
4910+ }
4911+ d0 = std::sqrt (d0);
4912+
4913+ double d1 = 0.0 ;
4914+ for (int m = 0 ; m < n_features; ++m) {
4915+ d1 += std::pow (features[j][m] - centroids[l][m], 2.0 );
4916+ }
4917+ d1 = std::sqrt (d1);
4918+ if (d1 == 0.0 ) {
4919+ sum += 1.0 ;
4920+ } else {
4921+ sum += std::pow (d0/d1, 2.0 /(2.0 - 1.0 ));
4922+ }
4923+ }
4924+
4925+ membership[j][k] = 1.0 /sum;
4926+ }
4927+ }
4928+
4929+ // print the membership
4930+ for (int i = 0 ; i < n_segments; ++i) {
4931+ printf (" %s: membership %d: " , __func__, i);
4932+ for (int j = 0 ; j < n_clusters; ++j) {
4933+ printf (" %f " , membership[i][j]);
4934+ }
4935+ printf (" '%s'\n " , ctx->result_all [i].text .c_str ());
4936+ }
4937+ printf (" ----------------\n " );
4938+ }
4939+
4940+ // print the centroids
4941+ // for (int i = 0; i < n_clusters; ++i) {
4942+ // printf("%s: centroid %d: ", __func__, i);
4943+ // for (int j = 0; j < n_features; ++j) {
4944+ // printf("%f ", centroids[i][j]);
4945+ // }
4946+ // printf("\n");
4947+ // }
4948+
4949+ // restore the mel length
4950+ ctx->mel .n_len = mel_len_save;
4951+ }
4952+
0 commit comments