Skip to content

Commit 9398498

Browse files
committed
whisper : check state->ctx_metal not null
1 parent 951a119 commit 9398498

File tree

1 file changed

+45
-39
lines changed

1 file changed

+45
-39
lines changed

whisper.cpp

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2917,59 +2917,63 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
29172917
}
29182918

29192919
#ifdef GGML_USE_METAL
2920+
// TODO: Param for enable GPU
29202921
state->ctx_metal = ggml_metal_init(1);
29212922
if (!state->ctx_metal) {
29222923
log("%s: ggml_metal_init() failed\n", __func__);
29232924
delete state;
29242925
return nullptr;
29252926
}
29262927

2927-
log("%s: Metal context initialized\n", __func__);
2928+
if (state->ctx_metal) {
2929+
log("%s: Metal context initialized\n", __func__);
29282930

2929-
// this allocates all Metal resources and memory buffers
2931+
// this allocates all Metal resources and memory buffers
29302932

2931-
void * data_ptr = NULL;
2932-
size_t data_size = 0;
2933+
void * data_ptr = NULL;
2934+
size_t data_size = 0;
29332935

2934-
// TODO: add mmap support
2935-
//if (params.use_mmap) {
2936-
// data_ptr = ctx->model.mapping->addr;
2937-
// data_size = ctx->model.mapping->size;
2938-
//} else {
2939-
// data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2940-
// data_size = ggml_get_mem_size (ctx->model.ctx);
2941-
//}
2936+
// TODO: add mmap support
2937+
//if (params.use_mmap) {
2938+
// data_ptr = ctx->model.mapping->addr;
2939+
// data_size = ctx->model.mapping->size;
2940+
//} else {
2941+
// data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2942+
// data_size = ggml_get_mem_size (ctx->model.ctx);
2943+
//}
29422944

2943-
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2944-
data_size = ggml_get_mem_size (ctx->model.ctx);
2945+
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2946+
data_size = ggml_get_mem_size (ctx->model.ctx);
29452947

2946-
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
2948+
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
29472949

2948-
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
2950+
log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
29492951

29502952
#define WHISPER_METAL_CHECK_BUF(result) \
2951-
if (!(result)) { \
2952-
log("%s: failed to add metal buffer\n", __func__); \
2953-
delete state; \
2954-
return nullptr; \
2955-
}
2953+
if (!(result)) { \
2954+
log("%s: failed to add metal buffer\n", __func__); \
2955+
delete state; \
2956+
return nullptr; \
2957+
}
29562958

2957-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
2959+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size));
29582960

2959-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
2960-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
2961-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
2962-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
2961+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0));
2962+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0));
2963+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0));
2964+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0));
29632965

2964-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
2965-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
2966-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
2967-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
2966+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0));
2967+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0));
2968+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0));
2969+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0));
29682970

2969-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
2971+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0));
29702972

2971-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
2973+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0));
29722974
#undef WHISPER_METAL_CHECK_BUF
2975+
2976+
}
29732977
#endif
29742978

29752979
state->rng = std::mt19937(0);
@@ -4493,17 +4497,19 @@ int whisper_full_with_state(
44934497

44944498
// TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
44954499
#ifdef GGML_USE_METAL
4500+
if (state->ctx_metal) {
44964501
#define WHISPER_METAL_CHECK_BUF(result) \
4497-
if (!(result)) { \
4498-
log("%s: failed to add metal buffer\n", __func__); \
4499-
return 0; \
4500-
}
4502+
if (!(result)) { \
4503+
log("%s: failed to add metal buffer\n", __func__); \
4504+
return 0; \
4505+
}
45014506

4502-
const std::string kv_name = "kv_self_" + std::to_string(j);
4503-
auto & kv_self = decoder.kv_self;
4507+
const std::string kv_name = "kv_self_" + std::to_string(j);
4508+
auto & kv_self = decoder.kv_self;
45044509

4505-
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
4510+
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
45064511
#undef WHISPER_METAL_CHECK_BUF
4512+
}
45074513
#endif
45084514
}
45094515
}

0 commit comments

Comments
 (0)