Skip to content

Commit ecbda54

Browse files
authored
fix(ai): Add audio dispatcher (#7483)
Adds an audio dispatcher for dispatching the recording of the microphone and the playback of audio to separate threads with elevated priorities. The threads are also configured with a StrictMode thread policy that uses `detectNetwork` to catch improper usage (network calls on audio threads). This should help avoid subtle deadlocks with coroutines and provide a smoother recording/playback experience in apps with higher thread traffic.
1 parent 082c510 commit ecbda54

File tree

3 files changed

+74
-19
lines changed

3 files changed

+74
-19
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,8 @@
1717
package com.google.firebase.ai.common.util
1818

1919
import android.media.AudioRecord
20-
import kotlin.time.Duration.Companion.milliseconds
2120
import kotlinx.coroutines.delay
22-
import kotlinx.coroutines.flow.callbackFlow
2321
import kotlinx.coroutines.flow.flow
24-
import kotlinx.coroutines.isActive
25-
import kotlinx.coroutines.yield
2622

2723
/**
2824
* The minimum buffer size for this instance.

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ internal class AudioHelper(
162162
fun build(): AudioHelper {
163163
val playbackTrack =
164164
AudioTrack(
165-
AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_MEDIA).setContentType(AudioAttributes.CONTENT_TYPE_SPEECH).build(),
165+
AudioAttributes.Builder()
166+
.setUsage(AudioAttributes.USAGE_MEDIA)
167+
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
168+
.build(),
166169
AudioFormat.Builder()
167170
.setSampleRate(24000)
168171
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
package com.google.firebase.ai.type
1818

1919
import android.Manifest.permission.RECORD_AUDIO
20+
import android.annotation.SuppressLint
2021
import android.content.pm.PackageManager
2122
import android.media.AudioFormat
2223
import android.media.AudioTrack
24+
import android.os.Process
25+
import android.os.StrictMode
26+
import android.os.StrictMode.ThreadPolicy
2327
import android.util.Log
2428
import androidx.annotation.RequiresPermission
2529
import androidx.core.content.ContextCompat
30+
import com.google.firebase.BuildConfig
2631
import com.google.firebase.FirebaseApp
2732
import com.google.firebase.ai.common.JSON
2833
import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -33,19 +38,23 @@ import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
3338
import io.ktor.websocket.Frame
3439
import io.ktor.websocket.close
3540
import io.ktor.websocket.readBytes
36-
import kotlinx.coroutines.CoroutineName
3741
import java.util.concurrent.ConcurrentLinkedQueue
42+
import java.util.concurrent.Executors
43+
import java.util.concurrent.ThreadFactory
3844
import java.util.concurrent.atomic.AtomicBoolean
45+
import java.util.concurrent.atomic.AtomicLong
3946
import kotlin.coroutines.CoroutineContext
47+
import kotlinx.coroutines.CoroutineName
4048
import kotlinx.coroutines.CoroutineScope
41-
import kotlinx.coroutines.Dispatchers
49+
import kotlinx.coroutines.asCoroutineDispatcher
4250
import kotlinx.coroutines.cancel
4351
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
4452
import kotlinx.coroutines.delay
4553
import kotlinx.coroutines.flow.Flow
4654
import kotlinx.coroutines.flow.buffer
4755
import kotlinx.coroutines.flow.catch
4856
import kotlinx.coroutines.flow.flow
57+
import kotlinx.coroutines.flow.flowOn
4958
import kotlinx.coroutines.flow.launchIn
5059
import kotlinx.coroutines.flow.onCompletion
5160
import kotlinx.coroutines.flow.onEach
@@ -67,11 +76,21 @@ internal constructor(
6776
private val firebaseApp: FirebaseApp,
6877
) {
6978
/**
70-
* Coroutine scope that we batch data on for [startAudioConversation].
79+
* Coroutine scope that we batch data on for network related behavior.
80+
*
81+
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
82+
*/
83+
private var networkScope = CancelledCoroutineScope
84+
85+
/**
86+
* Coroutine scope that we batch data on for audio recording and playback.
87+
*
88+
* Separate from [networkScope] to ensure interchanging of dispatchers doesn't cause any deadlocks
89+
* or issues.
7190
*
7291
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
7392
*/
74-
private var scope = CancelledCoroutineScope
93+
private var audioScope = CancelledCoroutineScope
7594

7695
/**
7796
* Playback audio data sent from the model.
@@ -129,16 +148,17 @@ internal constructor(
129148
}
130149

131150
FirebaseAIException.catchAsync {
132-
if (scope.isActive) {
151+
if (networkScope.isActive || audioScope.isActive) {
133152
Log.w(
134153
TAG,
135154
"startAudioConversation called after the recording has already started. " +
136155
"Call stopAudioConversation to close the previous connection."
137156
)
138157
return@catchAsync
139158
}
140-
// TODO: maybe it should be THREAD_PRIORITY_AUDIO anyways for playback and recording (not network though)
141-
scope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Scope"))
159+
networkScope =
160+
CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
161+
audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
142162
audioHelper = AudioHelper.build()
143163

144164
recordUserAudio()
@@ -158,7 +178,8 @@ internal constructor(
158178
FirebaseAIException.catch {
159179
if (!startedReceiving.getAndSet(false)) return@catch
160180

161-
scope.cancel()
181+
networkScope.cancel()
182+
audioScope.cancel()
162183
playBackQueue.clear()
163184

164185
audioHelper?.release()
@@ -228,7 +249,8 @@ internal constructor(
228249
FirebaseAIException.catch {
229250
if (!startedReceiving.getAndSet(false)) return@catch
230251

231-
scope.cancel()
252+
networkScope.cancel()
253+
audioScope.cancel()
232254
playBackQueue.clear()
233255

234256
audioHelper?.release()
@@ -325,21 +347,22 @@ internal constructor(
325347
audioHelper
326348
?.listenToRecording()
327349
?.buffer(UNLIMITED)
350+
?.flowOn(audioDispatcher)
328351
?.accumulateUntil(MIN_BUFFER_SIZE)
329352
?.onEach {
330353
sendMediaStream(listOf(MediaData(it, "audio/pcm")))
331354
delay(0)
332355
}
333356
?.catch { throw FirebaseAIException.from(it) }
334-
?.launchIn(scope)
357+
?.launchIn(networkScope)
335358
}
336359

337360
/**
338361
* Processes responses from the model during an audio conversation.
339362
*
340363
* Audio messages are added to [playBackQueue].
341364
*
342-
* Launched asynchronously on [scope].
365+
* Launched asynchronously on [networkScope].
343366
*
344367
* @param functionCallHandler A callback function that is invoked whenever the server receives a
345368
* function call.
@@ -393,18 +416,18 @@ internal constructor(
393416
}
394417
}
395418
}
396-
.launchIn(scope)
419+
.launchIn(networkScope)
397420
}
398421

399422
/**
400423
* Listens for playback data from the model and plays the audio.
401424
*
402425
* Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
403426
*
404-
* Launched asynchronously on [scope].
427+
* Launched asynchronously on [networkScope].
405428
*/
406429
private fun listenForModelPlayback(enableInterruptions: Boolean = false) {
407-
scope.launch {
430+
audioScope.launch {
408431
while (isActive) {
409432
val playbackData = playBackQueue.poll()
410433
if (playbackData == null) {
@@ -490,5 +513,38 @@ internal constructor(
490513
AudioFormat.CHANNEL_OUT_MONO,
491514
AudioFormat.ENCODING_PCM_16BIT
492515
)
516+
@SuppressLint("ThreadPoolCreation")
517+
val audioDispatcher =
518+
Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
519+
}
520+
}
521+
522+
internal class AudioThreadFactory : ThreadFactory {
523+
private val threadCount = AtomicLong()
524+
private val policy: ThreadPolicy = audioPolicy()
525+
526+
override fun newThread(task: Runnable?): Thread? {
527+
val thread =
528+
DEFAULT.newThread {
529+
Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
530+
StrictMode.setThreadPolicy(policy)
531+
task?.run()
532+
}
533+
thread.name = "Firebase Audio Thread #${threadCount.andIncrement}"
534+
return thread
535+
}
536+
537+
companion object {
538+
val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()
539+
540+
private fun audioPolicy(): ThreadPolicy {
541+
val builder = ThreadPolicy.Builder().detectNetwork()
542+
543+
if (BuildConfig.DEBUG) {
544+
builder.penaltyDeath()
545+
}
546+
547+
return builder.penaltyLog().build()
548+
}
493549
}
494550
}

0 commit comments

Comments
 (0)