@@ -102,7 +102,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
102102#define WHISPER_PRINT_DEBUG (...)
103103#endif
104104
105- #define WHISPER_USE_FLASH_ATTN
105+ // #define WHISPER_USE_FLASH_ATTN
106106// #define WHISPER_USE_FLASH_FF
107107#define WHISPER_MAX_DECODERS 16
108108
@@ -224,11 +224,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
224224static const size_t MB = 1ull *1024 *1024 ;
225225
226226static const std::map<e_model, size_t > MEM_REQ_SCRATCH0 = {
227- { MODEL_TINY, 14ull *MB },
228- { MODEL_BASE, 18ull *MB },
229- { MODEL_SMALL, 28ull *MB },
230- { MODEL_MEDIUM, 36ull *MB },
231- { MODEL_LARGE, 44ull *MB },
227+ { MODEL_TINY, 62ull *MB },
228+ { MODEL_BASE, 80ull *MB },
229+ { MODEL_SMALL, 120ull *MB },
230+ { MODEL_MEDIUM, 158ull *MB },
231+ { MODEL_LARGE, 198ull *MB },
232232};
233233
234234static const std::map<e_model, size_t > MEM_REQ_SCRATCH1 = {
@@ -280,11 +280,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
280280};
281281
282282static const std::map<e_model, size_t > MEM_REQ_ENCODE = {
283- { MODEL_TINY, 6ull *MB },
284- { MODEL_BASE, 8ull *MB },
285- { MODEL_SMALL, 13ull *MB },
286- { MODEL_MEDIUM, 22ull *MB },
287- { MODEL_LARGE, 33ull *MB },
283+ { MODEL_TINY, 30ull *MB },
284+ { MODEL_BASE, 38ull *MB },
285+ { MODEL_SMALL, 56ull *MB },
286+ { MODEL_MEDIUM, 74ull *MB },
287+ { MODEL_LARGE, 94ull *MB },
288288};
289289
290290static const std::map<e_model, size_t > MEM_REQ_DECODE = {
@@ -1554,26 +1554,17 @@ static bool whisper_encode_internal(
15541554
15551555 struct ggml_tensor * KQ_soft_max = ggml_soft_max (ctx0, KQ_scaled);
15561556
1557- // struct ggml_tensor * V_trans =
1558- // ggml_permute(ctx0,
1559- // ggml_cpy(ctx0,
1560- // Vcur,
1561- // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1562- // 1, 2, 0, 3);
1563-
1564- // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1565-
15661557 struct ggml_tensor * V =
15671558 ggml_cpy (ctx0,
15681559 ggml_permute (ctx0,
15691560 ggml_reshape_3d (ctx0,
15701561 Vcur,
15711562 n_state/n_head, n_head, n_ctx),
1572- 0 , 2 , 1 , 3 ),
1573- ggml_new_tensor_3d (ctx0, wctx.wtype , n_state/n_head, n_ctx , n_head)
1563+ 1 , 2 , 0 , 3 ),
1564+ ggml_new_tensor_3d (ctx0, wctx.wtype , n_ctx, n_state/n_head, n_head)
15741565 );
15751566
1576- struct ggml_tensor * KQV = ggml_mul_mat (ctx0, ggml_transpose (ctx0, V) , KQ_soft_max);
1567+ struct ggml_tensor * KQV = ggml_mul_mat (ctx0, V , KQ_soft_max);
15771568#endif
15781569 struct ggml_tensor * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
15791570
0 commit comments