Skip to content

Disable log and switch flash_attn to go binding. #3200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bindings/go/pkg/whisper/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE

func New(path string) (Model, error) {
func New(path string, flashAttn, disableLog bool) (Model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
} else if ctx := whisper.Whisper_init(path); ctx == nil {
} else if ctx := whisper.Whisper_init(path, flashAttn, disableLog); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
model.ctx = ctx
Expand Down
9 changes: 8 additions & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
params.progress_callback_user_data = (void*)(ctx);
return params;
}

*/
import "C"

Expand All @@ -69,6 +70,7 @@ type (
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Params C.struct_whisper_full_params
ContextParams C.struct_whisper_whisper_context_params
)

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -99,9 +101,14 @@ var (

// Allocates all memory needed for the model and loads the model from the given file.
// Returns NULL on failure.
func Whisper_init(path string) *Context {
func Whisper_init(path string, flashAttn, disableLog bool) *Context {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about adding these context parameters like this. Perhaps instead there should be a method that takes ContextParams as an argument similar to the Java binding?

cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if disableLog{
C.whisper_log_disable()
}
params := C.whisper_context_default_params()
params.flash_attn = toBool(flashAttn)
if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil {
return (*Context)(ctx)
} else {
Expand Down
2 changes: 2 additions & 0 deletions include/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ extern "C" {

WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);

WHISPER_API void whisper_log_disable();

// Get the no_speech probability for the specified segment
WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment);
WHISPER_API float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment);
Expand Down
6 changes: 6 additions & 0 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8887,6 +8887,12 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
ggml_free(gctx);
}

static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }

void whisper_log_disable() {
whisper_log_set(cb_log_disable, NULL);
}

void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
g_state.log_callback_user_data = user_data;
Expand Down