@@ -102,7 +102,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
102102#define  WHISPER_PRINT_DEBUG (...)
103103#endif 
104104
105- #define  WHISPER_USE_FLASH_ATTN 
105+ // #define WHISPER_USE_FLASH_ATTN
106106// #define WHISPER_USE_FLASH_FF
107107#define  WHISPER_MAX_DECODERS  16 
108108
@@ -224,11 +224,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
224224static  const  size_t  MB = 1ull *1024 *1024 ;
225225
226226static  const  std::map<e_model, size_t > MEM_REQ_SCRATCH0 = {
227-     { MODEL_TINY,     14ull *MB },
228-     { MODEL_BASE,     18ull *MB },
229-     { MODEL_SMALL,     28ull *MB },
230-     { MODEL_MEDIUM,    36ull *MB },
231-     { MODEL_LARGE,     44ull *MB },
227+     { MODEL_TINY,     62ull *MB },
228+     { MODEL_BASE,     80ull *MB },
229+     { MODEL_SMALL,   120ull *MB },
230+     { MODEL_MEDIUM,  158ull *MB },
231+     { MODEL_LARGE,   198ull *MB },
232232};
233233
234234static  const  std::map<e_model, size_t > MEM_REQ_SCRATCH1 = {
@@ -280,11 +280,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
280280};
281281
282282static  const  std::map<e_model, size_t > MEM_REQ_ENCODE = {
283-     { MODEL_TINY,       6ull *MB },
284-     { MODEL_BASE,       8ull *MB },
285-     { MODEL_SMALL,    13ull *MB },
286-     { MODEL_MEDIUM,   22ull *MB },
287-     { MODEL_LARGE,    33ull *MB },
283+     { MODEL_TINY,     30ull *MB },
284+     { MODEL_BASE,     38ull *MB },
285+     { MODEL_SMALL,    56ull *MB },
286+     { MODEL_MEDIUM,   74ull *MB },
287+     { MODEL_LARGE,    94ull *MB },
288288};
289289
290290static  const  std::map<e_model, size_t > MEM_REQ_DECODE = {
@@ -1554,26 +1554,17 @@ static bool whisper_encode_internal(
15541554
15551555                struct  ggml_tensor  * KQ_soft_max = ggml_soft_max (ctx0, KQ_scaled);
15561556
1557-                 // struct ggml_tensor * V_trans =
1558-                 //     ggml_permute(ctx0,
1559-                 //             ggml_cpy(ctx0,
1560-                 //                 Vcur,
1561-                 //                 ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1562-                 //             1, 2, 0, 3);
1563- 
1564-                 // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1565- 
15661557                struct  ggml_tensor  * V =
15671558                    ggml_cpy (ctx0,
15681559                            ggml_permute (ctx0,
15691560                                ggml_reshape_3d (ctx0,
15701561                                    Vcur,
15711562                                    n_state/n_head, n_head, n_ctx),
1572-                                 0 , 2 , 1 , 3 ),
1573-                             ggml_new_tensor_3d (ctx0, wctx.wtype , n_state/n_head, n_ctx , n_head)
1563+                                 1 , 2 , 0 , 3 ),
1564+                             ggml_new_tensor_3d (ctx0, wctx.wtype , n_ctx,  n_state/n_head, n_head)
15741565                            );
15751566
1576-                 struct  ggml_tensor  * KQV = ggml_mul_mat (ctx0, ggml_transpose (ctx0, V) , KQ_soft_max);
1567+                 struct  ggml_tensor  * KQV = ggml_mul_mat (ctx0, V , KQ_soft_max);
15771568#endif 
15781569                struct  ggml_tensor  * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
15791570
0 commit comments