diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 68a150223c7..eae3f46c705 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -23,11 +23,11 @@ var _ Model = (*model)(nil) /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE -func New(path string) (Model, error) { +func New(path string, flashAttn, disableLog bool) (Model, error) { model := new(model) if _, err := os.Stat(path); err != nil { return nil, err - } else if ctx := whisper.Whisper_init(path); ctx == nil { + } else if ctx := whisper.Whisper_init(path, flashAttn, disableLog); ctx == nil { return nil, ErrUnableToLoadModel } else { model.ctx = ctx diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 525b72d2318..51580562952 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -57,6 +57,7 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_ params.progress_callback_user_data = (void*)(ctx); return params; } + */ import "C" @@ -69,6 +70,7 @@ type ( TokenData C.struct_whisper_token_data SamplingStrategy C.enum_whisper_sampling_strategy Params C.struct_whisper_full_params + ContextParams C.struct_whisper_whisper_context_params ) /////////////////////////////////////////////////////////////////////////////// @@ -99,9 +101,14 @@ var ( // Allocates all memory needed for the model and loads the model from the given file. // Returns NULL on failure. -func Whisper_init(path string) *Context { +func Whisper_init(path string, flashAttn, disableLog bool) *Context { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) + if disableLog{ + C.whisper_log_disable() + } + params := C.whisper_context_default_params() + params.flash_attn = toBool(flashAttn) if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil { return (*Context)(ctx) } else { diff --git a/include/whisper.h b/include/whisper.h index 4aeda98f334..043bfd2414d 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -728,6 +728,8 @@ extern "C" { WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data); + WHISPER_API void whisper_log_disable(); + // Get the no_speech probability for the specified segment WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment); WHISPER_API float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment); diff --git a/src/whisper.cpp b/src/whisper.cpp index cb887d4593b..7978b6303f6 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -8887,6 +8887,12 @@ static void whisper_exp_compute_token_level_timestamps_dtw( ggml_free(gctx); } +static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } + +void whisper_log_disable() { + whisper_log_set(cb_log_disable, NULL); +} + void whisper_log_set(ggml_log_callback log_callback, void * user_data) { g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default; g_state.log_callback_user_data = user_data;