@@ -2917,59 +2917,63 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
2917
2917
}
2918
2918
2919
2919
#ifdef GGML_USE_METAL
2920
+ // TODO: add a parameter to enable/disable the GPU
2920
2921
state->ctx_metal = ggml_metal_init (1 );
2921
2922
if (!state->ctx_metal ) {
2922
2923
log (" %s: ggml_metal_init() failed\n " , __func__);
2923
2924
delete state;
2924
2925
return nullptr ;
2925
2926
}
2926
2927
2927
- log (" %s: Metal context initialized\n " , __func__);
2928
+ if (state->ctx_metal ) {
2929
+ log (" %s: Metal context initialized\n " , __func__);
2928
2930
2929
- // this allocates all Metal resources and memory buffers
2931
+ // this allocates all Metal resources and memory buffers
2930
2932
2931
- void * data_ptr = NULL ;
2932
- size_t data_size = 0 ;
2933
+ void * data_ptr = NULL ;
2934
+ size_t data_size = 0 ;
2933
2935
2934
- // TODO: add mmap support
2935
- // if (params.use_mmap) {
2936
- // data_ptr = ctx->model.mapping->addr;
2937
- // data_size = ctx->model.mapping->size;
2938
- // } else {
2939
- // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2940
- // data_size = ggml_get_mem_size (ctx->model.ctx);
2941
- // }
2936
+ // TODO: add mmap support
2937
+ // if (params.use_mmap) {
2938
+ // data_ptr = ctx->model.mapping->addr;
2939
+ // data_size = ctx->model.mapping->size;
2940
+ // } else {
2941
+ // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2942
+ // data_size = ggml_get_mem_size (ctx->model.ctx);
2943
+ // }
2942
2944
2943
- data_ptr = ggml_get_mem_buffer (ctx->model .ctx );
2944
- data_size = ggml_get_mem_size (ctx->model .ctx );
2945
+ data_ptr = ggml_get_mem_buffer (ctx->model .ctx );
2946
+ data_size = ggml_get_mem_size (ctx->model .ctx );
2945
2947
2946
- const size_t max_size = ggml_get_max_tensor_size (ctx->model .ctx );
2948
+ const size_t max_size = ggml_get_max_tensor_size (ctx->model .ctx );
2947
2949
2948
- log (" %s: max tensor size = %8.2f MB\n " , __func__, max_size/1024.0 /1024.0 );
2950
+ log (" %s: max tensor size = %8.2f MB\n " , __func__, max_size/1024.0 /1024.0 );
2949
2951
2950
2952
#define WHISPER_METAL_CHECK_BUF (result ) \
2951
- if (!(result)) { \
2952
- log (" %s: failed to add metal buffer\n " , __func__); \
2953
- delete state; \
2954
- return nullptr ; \
2955
- }
2953
+ if (!(result)) { \
2954
+ log (" %s: failed to add metal buffer\n " , __func__); \
2955
+ delete state; \
2956
+ return nullptr ; \
2957
+ }
2956
2958
2957
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data" , data_ptr, data_size, max_size));
2959
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data" , data_ptr, data_size, max_size));
2958
2960
2959
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_conv" , state->alloc_conv .meta .data (), state->alloc_conv .meta .size (), 0 ));
2960
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_encode" , state->alloc_encode .meta .data (), state->alloc_encode .meta .size (), 0 ));
2961
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_cross" , state->alloc_cross .meta .data (), state->alloc_cross .meta .size (), 0 ));
2962
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_decode" , state->alloc_decode .meta .data (), state->alloc_decode .meta .size (), 0 ));
2961
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_conv" , state->alloc_conv .meta .data (), state->alloc_conv .meta .size (), 0 ));
2962
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_encode" , state->alloc_encode .meta .data (), state->alloc_encode .meta .size (), 0 ));
2963
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_cross" , state->alloc_cross .meta .data (), state->alloc_cross .meta .size (), 0 ));
2964
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_decode" , state->alloc_decode .meta .data (), state->alloc_decode .meta .size (), 0 ));
2963
2965
2964
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_conv" , state->alloc_conv .data .data (), state->alloc_conv .data .size (), 0 ));
2965
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_encode" , state->alloc_encode .data .data (), state->alloc_encode .data .size (), 0 ));
2966
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_cross" , state->alloc_cross .data .data (), state->alloc_cross .data .size (), 0 ));
2967
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_decode" , state->alloc_decode .data .data (), state->alloc_decode .data .size (), 0 ));
2966
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_conv" , state->alloc_conv .data .data (), state->alloc_conv .data .size (), 0 ));
2967
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_encode" , state->alloc_encode .data .data (), state->alloc_encode .data .size (), 0 ));
2968
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_cross" , state->alloc_cross .data .data (), state->alloc_cross .data .size (), 0 ));
2969
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_decode" , state->alloc_decode .data .data (), state->alloc_decode .data .size (), 0 ));
2968
2970
2969
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_cross" , state->kv_cross .buf .data (), state->kv_cross .buf .size (), 0 ));
2971
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_cross" , state->kv_cross .buf .data (), state->kv_cross .buf .size (), 0 ));
2970
2972
2971
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_self_0" , state->decoders [0 ].kv_self .buf .data (), state->decoders [0 ].kv_self .buf .size (), 0 ));
2973
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_self_0" , state->decoders [0 ].kv_self .buf .data (), state->decoders [0 ].kv_self .buf .size (), 0 ));
2972
2974
#undef WHISPER_METAL_CHECK_BUF
2975
+
2976
+ }
2973
2977
#endif
2974
2978
2975
2979
state->rng = std::mt19937 (0 );
@@ -4493,17 +4497,19 @@ int whisper_full_with_state(
4493
4497
4494
4498
// TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
4495
4499
#ifdef GGML_USE_METAL
4500
+ if (state->ctx_metal ) {
4496
4501
#define WHISPER_METAL_CHECK_BUF (result ) \
4497
- if (!(result)) { \
4498
- log (" %s: failed to add metal buffer\n " , __func__); \
4499
- return 0 ; \
4500
- }
4502
+ if (!(result)) { \
4503
+ log (" %s: failed to add metal buffer\n " , __func__); \
4504
+ return 0 ; \
4505
+ }
4501
4506
4502
- const std::string kv_name = " kv_self_" + std::to_string (j);
4503
- auto & kv_self = decoder.kv_self ;
4507
+ const std::string kv_name = " kv_self_" + std::to_string (j);
4508
+ auto & kv_self = decoder.kv_self ;
4504
4509
4505
- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , kv_name.c_str (), kv_self.buf .data (), kv_self.buf .size (), 0 ));
4510
+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , kv_name.c_str (), kv_self.buf .data (), kv_self.buf .size (), 0 ));
4506
4511
#undef WHISPER_METAL_CHECK_BUF
4512
+ }
4507
4513
#endif
4508
4514
}
4509
4515
}
0 commit comments