Skip to content

Commit 97e6ce2

Browse files
committed
feat(go bindings): add VAD and Diarization parameters
1 parent ebbcf3f commit 97e6ce2

File tree

8 files changed

+264
-35
lines changed

8 files changed

+264
-35
lines changed

bindings/go/params.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,44 @@ func (p *Params) SetPrintTimestamps(v bool) {
4747
p.print_timestamps = toBool(v)
4848
}
4949

50+
// Enable tinydiarize speaker turn detection
51+
func (p *Params) SetDiarize(v bool) {
52+
p.tdrz_enable = toBool(v)
53+
}
54+
55+
// Voice Activity Detection (VAD)
56+
func (p *Params) SetVAD(v bool) {
57+
p.vad = toBool(v)
58+
}
59+
60+
func (p *Params) SetVADModelPath(path string) {
61+
p.vad_model_path = C.CString(path)
62+
}
63+
64+
func (p *Params) SetVADThreshold(t float32) {
65+
p.vad_params.threshold = C.float(t)
66+
}
67+
68+
func (p *Params) SetVADMinSpeechMs(ms int) {
69+
p.vad_params.min_speech_duration_ms = C.int(ms)
70+
}
71+
72+
func (p *Params) SetVADMinSilenceMs(ms int) {
73+
p.vad_params.min_silence_duration_ms = C.int(ms)
74+
}
75+
76+
func (p *Params) SetVADMaxSpeechSec(s float32) {
77+
p.vad_params.max_speech_duration_s = C.float(s)
78+
}
79+
80+
func (p *Params) SetVADSpeechPadMs(ms int) {
81+
p.vad_params.speech_pad_ms = C.int(ms)
82+
}
83+
84+
func (p *Params) SetVADSamplesOverlap(sec float32) {
85+
p.vad_params.samples_overlap = C.float(sec)
86+
}
87+
5088
// Set language id
5189
func (p *Params) SetLanguage(lang int) error {
5290
if lang == -1 {

bindings/go/pkg/whisper/consts.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,10 @@ const SampleRate = whisper.SampleRate
2828

2929
// SampleBits is the number of bytes per sample.
3030
const SampleBits = whisper.SampleBits
31+
32+
type SamplingStrategy whisper.SamplingStrategy
33+
34+
const (
35+
SAMPLING_GREEDY SamplingStrategy = SamplingStrategy(whisper.SAMPLING_GREEDY)
36+
SAMPLING_BEAM_SEARCH SamplingStrategy = SamplingStrategy(whisper.SAMPLING_BEAM_SEARCH)
37+
)

bindings/go/pkg/whisper/context.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ type context struct {
1919
Parameters
2020
}
2121

22-
func newContext(model Model, params whisper.Params) (Context, error) {
22+
func newContext(model Model, params Parameters) (Context, error) {
2323
c := new(context)
2424
c.model = model
2525

26-
c.params = newParameters(&params)
26+
c.params = params
2727
c.Parameters = c.params
2828

2929
// allocate isolated state per context
@@ -132,7 +132,7 @@ func (context *context) Process(
132132
context.params.SetSingleSegment(true)
133133
}
134134

135-
lowLevelParams := context.params.WhisperParams()
135+
lowLevelParams := context.params.UnsafeParams()
136136
if lowLevelParams == nil {
137137
return fmt.Errorf("lowLevelParams is nil: %w", ErrInternalAppError)
138138
}
@@ -249,11 +249,12 @@ func (context *context) IsLANG(t Token, lang string) bool {
249249
// State-backed helper functions
250250
func toSegmentFromState(ctx *whisper.Context, st *whisper.State, n int) Segment {
251251
return Segment{
252-
Num: n,
253-
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text_from_state(st, n)),
254-
Start: time.Duration(ctx.Whisper_full_get_segment_t0_from_state(st, n)) * time.Millisecond * 10,
255-
End: time.Duration(ctx.Whisper_full_get_segment_t1_from_state(st, n)) * time.Millisecond * 10,
256-
Tokens: toTokensFromState(ctx, st, n),
252+
Num: n,
253+
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text_from_state(st, n)),
254+
Start: time.Duration(ctx.Whisper_full_get_segment_t0_from_state(st, n)) * time.Millisecond * 10,
255+
End: time.Duration(ctx.Whisper_full_get_segment_t1_from_state(st, n)) * time.Millisecond * 10,
256+
Tokens: toTokensFromState(ctx, st, n),
257+
SpeakerTurnNext: ctx.Whisper_full_get_segment_speaker_turn_next_from_state(st, n),
257258
}
258259
}
259260

bindings/go/pkg/whisper/context_test.go

Lines changed: 94 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func TestSetLanguage(t *testing.T) {
1717
model, err := whisper.New(ModelPath)
1818
assert.NoError(err)
1919
assert.NotNil(model)
20-
defer model.Close()
20+
defer func() { _ = model.Close() }()
2121

2222
context, err := model.NewContext()
2323
assert.NoError(err)
@@ -35,7 +35,7 @@ func TestContextModelIsMultilingual(t *testing.T) {
3535
model, err := whisper.New(ModelPath)
3636
assert.NoError(err)
3737
assert.NotNil(model)
38-
defer model.Close()
38+
defer func() { _ = model.Close() }()
3939

4040
context, err := model.NewContext()
4141
assert.NoError(err)
@@ -54,7 +54,7 @@ func TestLanguage(t *testing.T) {
5454
model, err := whisper.New(ModelPath)
5555
assert.NoError(err)
5656
assert.NotNil(model)
57-
defer model.Close()
57+
defer func() { _ = model.Close() }()
5858

5959
context, err := model.NewContext()
6060
assert.NoError(err)
@@ -72,7 +72,7 @@ func TestProcess(t *testing.T) {
7272

7373
fh, err := os.Open(SamplePath)
7474
assert.NoError(err)
75-
defer fh.Close()
75+
defer func() { _ = fh.Close() }()
7676

7777
// Decode the WAV file - load the full buffer
7878
dec := wav.NewDecoder(fh)
@@ -85,7 +85,7 @@ func TestProcess(t *testing.T) {
8585
model, err := whisper.New(ModelPath)
8686
assert.NoError(err)
8787
assert.NotNil(model)
88-
defer model.Close()
88+
defer func() { _ = model.Close() }()
8989

9090
context, err := model.NewContext()
9191
assert.NoError(err)
@@ -99,7 +99,7 @@ func TestDetectedLanguage(t *testing.T) {
9999

100100
fh, err := os.Open(SamplePath)
101101
assert.NoError(err)
102-
defer fh.Close()
102+
defer func() { _ = fh.Close() }()
103103

104104
// Decode the WAV file - load the full buffer
105105
dec := wav.NewDecoder(fh)
@@ -112,7 +112,7 @@ func TestDetectedLanguage(t *testing.T) {
112112
model, err := whisper.New(ModelPath)
113113
assert.NoError(err)
114114
assert.NotNil(model)
115-
defer model.Close()
115+
defer func() { _ = model.Close() }()
116116

117117
context, err := model.NewContext()
118118
assert.NoError(err)
@@ -139,7 +139,7 @@ func TestContext_ConcurrentProcessing(t *testing.T) {
139139

140140
fh, err := os.Open(SamplePath)
141141
assert.NoError(err)
142-
defer fh.Close()
142+
defer func() { _ = fh.Close() }()
143143

144144
dec := wav.NewDecoder(fh)
145145
buf, err := dec.FullPCMBuffer()
@@ -150,12 +150,12 @@ func TestContext_ConcurrentProcessing(t *testing.T) {
150150
model, err := whisper.New(ModelPath)
151151
assert.NoError(err)
152152
assert.NotNil(model)
153-
defer model.Close()
153+
defer func() { _ = model.Close() }()
154154

155155
ctx, err := model.NewContext()
156156
assert.NoError(err)
157157
assert.NotNil(ctx)
158-
defer ctx.Close()
158+
defer func() { _ = ctx.Close() }()
159159

160160
err = ctx.Process(data, nil, nil, nil)
161161
assert.NoError(err)
@@ -179,7 +179,7 @@ func TestContext_Parallel_DifferentInputs(t *testing.T) {
179179

180180
fh, err := os.Open(SamplePath)
181181
assert.NoError(err)
182-
defer fh.Close()
182+
defer func() { _ = fh.Close() }()
183183

184184
dec := wav.NewDecoder(fh)
185185
buf, err := dec.FullPCMBuffer()
@@ -195,14 +195,14 @@ func TestContext_Parallel_DifferentInputs(t *testing.T) {
195195
model, err := whisper.New(ModelPath)
196196
assert.NoError(err)
197197
assert.NotNil(model)
198-
defer model.Close()
198+
defer func() { _ = model.Close() }()
199199

200200
ctx1, err := model.NewContext()
201201
assert.NoError(err)
202-
defer ctx1.Close()
202+
defer func() { _ = ctx1.Close() }()
203203
ctx2, err := model.NewContext()
204204
assert.NoError(err)
205-
defer ctx2.Close()
205+
defer func() { _ = ctx2.Close() }()
206206

207207
// Run in parallel - each context has isolated whisper_state
208208
var wg sync.WaitGroup
@@ -258,7 +258,7 @@ func TestContext_Close(t *testing.T) {
258258
model, err := whisper.New(ModelPath)
259259
assert.NoError(err)
260260
assert.NotNil(model)
261-
defer model.Close()
261+
defer func() { _ = model.Close() }()
262262

263263
ctx, err := model.NewContext()
264264
assert.NoError(err)
@@ -294,3 +294,82 @@ func Test_Close_Context_of_Closed_Model(t *testing.T) {
294294
require.NoError(t, model.Close())
295295
require.NoError(t, ctx.Close())
296296
}
297+
298+
func TestContext_VAD_And_Diarization_Params_DoNotPanic(t *testing.T) {
299+
assert := assert.New(t)
300+
301+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
302+
t.Skip("Skipping test, model not found:", ModelPath)
303+
}
304+
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
305+
t.Skip("Skipping test, sample not found:", SamplePath)
306+
}
307+
308+
fh, err := os.Open(SamplePath)
309+
assert.NoError(err)
310+
defer func() { _ = fh.Close() }()
311+
312+
dec := wav.NewDecoder(fh)
313+
buf, err := dec.FullPCMBuffer()
314+
assert.NoError(err)
315+
assert.Equal(uint16(1), dec.NumChans)
316+
data := buf.AsFloat32Buffer().Data
317+
318+
model, err := whisper.New(ModelPath)
319+
assert.NoError(err)
320+
defer func() { _ = model.Close() }()
321+
322+
ctx, err := model.NewContext()
323+
assert.NoError(err)
324+
defer func() { _ = ctx.Close() }()
325+
326+
p := ctx.Params()
327+
p.SetDiarize(true)
328+
p.SetVAD(true)
329+
p.SetVADThreshold(0.5)
330+
p.SetVADMinSpeechMs(200)
331+
p.SetVADMinSilenceMs(100)
332+
p.SetVADMaxSpeechSec(10)
333+
p.SetVADSpeechPadMs(30)
334+
p.SetVADSamplesOverlap(0.02)
335+
336+
err = ctx.Process(data, nil, nil, nil)
337+
assert.NoError(err)
338+
}
339+
340+
func TestContext_SpeakerTurnNext_Field_Present(t *testing.T) {
341+
assert := assert.New(t)
342+
343+
if _, err := os.Stat(ModelPath); os.IsNotExist(err) {
344+
t.Skip("Skipping test, model not found:", ModelPath)
345+
}
346+
if _, err := os.Stat(SamplePath); os.IsNotExist(err) {
347+
t.Skip("Skipping test, sample not found:", SamplePath)
348+
}
349+
350+
fh, err := os.Open(SamplePath)
351+
assert.NoError(err)
352+
defer func() { _ = fh.Close() }()
353+
354+
dec := wav.NewDecoder(fh)
355+
buf, err := dec.FullPCMBuffer()
356+
assert.NoError(err)
357+
assert.Equal(uint16(1), dec.NumChans)
358+
data := buf.AsFloat32Buffer().Data
359+
360+
model, err := whisper.New(ModelPath)
361+
assert.NoError(err)
362+
defer func() { _ = model.Close() }()
363+
364+
ctx, err := model.NewContext()
365+
assert.NoError(err)
366+
defer func() { _ = ctx.Close() }()
367+
368+
err = ctx.Process(data, nil, nil, nil)
369+
assert.NoError(err)
370+
371+
seg, err := ctx.NextSegment()
372+
assert.NoError(err)
373+
t.Logf("SpeakerTurnNext: %v", seg.SpeakerTurnNext)
374+
_ = seg.SpeakerTurnNext // ensure field exists and is readable
375+
}

bindings/go/pkg/whisper/interface.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ type TokenIdentifier interface {
4848
IsText(Token) (bool, error)
4949
}
5050

51+
type ParamsConfigure func(Parameters)
52+
5153
// Model is the interface to a whisper model. Create a new model with the
5254
// function whisper.New(string)
5355
type Model interface {
@@ -56,6 +58,11 @@ type Model interface {
5658
// Return a new speech-to-text context.
5759
NewContext() (Context, error)
5860

61+
NewParams(
62+
sampling SamplingStrategy,
63+
configure ParamsConfigure,
64+
) (Parameters, error)
65+
5966
// Return true if the model is multilingual.
6067
IsMultilingual() bool
6168

@@ -94,6 +101,25 @@ type Parameters interface {
94101
SetEntropyThold(t float32)
95102
SetInitialPrompt(prompt string)
96103

104+
SetNoContext(bool)
105+
SetPrintSpecial(bool)
106+
SetPrintProgress(bool)
107+
SetPrintRealtime(bool)
108+
SetPrintTimestamps(bool)
109+
110+
// Diarization (tinydiarize)
111+
SetDiarize(bool)
112+
113+
// Voice Activity Detection (VAD)
114+
SetVAD(bool)
115+
SetVADModelPath(string)
116+
SetVADThreshold(float32)
117+
SetVADMinSpeechMs(int)
118+
SetVADMinSilenceMs(int)
119+
SetVADMaxSpeechSec(float32)
120+
SetVADSpeechPadMs(int)
121+
SetVADSamplesOverlap(float32)
122+
97123
// Set the temperature
98124
SetTemperature(t float32)
99125

@@ -108,7 +134,8 @@ type Parameters interface {
108134
// Getter methods
109135
Language() string
110136
Threads() int
111-
WhisperParams() *whisper.Params
137+
138+
UnsafeParams() *whisper.Params
112139
}
113140

114141
// Context is the speech recognition context.
@@ -231,6 +258,9 @@ type Segment struct {
231258

232259
// The tokens of the segment.
233260
Tokens []Token
261+
262+
// True if the next segment is predicted as a speaker turn (tinydiarize)
263+
SpeakerTurnNext bool
234264
}
235265

236266
// Token is a text or special token

0 commit comments

Comments
 (0)