1717package com.google.firebase.ai.type
1818
1919import android.Manifest.permission.RECORD_AUDIO
20+ import android.annotation.SuppressLint
2021import android.content.pm.PackageManager
2122import android.media.AudioFormat
2223import android.media.AudioTrack
24+ import android.os.Process
25+ import android.os.StrictMode
26+ import android.os.StrictMode.ThreadPolicy
2327import android.util.Log
2428import androidx.annotation.RequiresPermission
2529import androidx.core.content.ContextCompat
30+ import com.google.firebase.BuildConfig
2631import com.google.firebase.FirebaseApp
2732import com.google.firebase.ai.common.JSON
2833import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -33,19 +38,23 @@ import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
3338import io.ktor.websocket.Frame
3439import io.ktor.websocket.close
3540import io.ktor.websocket.readBytes
36- import kotlinx.coroutines.CoroutineName
3741import java.util.concurrent.ConcurrentLinkedQueue
42+ import java.util.concurrent.Executors
43+ import java.util.concurrent.ThreadFactory
3844import java.util.concurrent.atomic.AtomicBoolean
45+ import java.util.concurrent.atomic.AtomicLong
3946import kotlin.coroutines.CoroutineContext
47+ import kotlinx.coroutines.CoroutineName
4048import kotlinx.coroutines.CoroutineScope
41- import kotlinx.coroutines.Dispatchers
49+ import kotlinx.coroutines.asCoroutineDispatcher
4250import kotlinx.coroutines.cancel
4351import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
4452import kotlinx.coroutines.delay
4553import kotlinx.coroutines.flow.Flow
4654import kotlinx.coroutines.flow.buffer
4755import kotlinx.coroutines.flow.catch
4856import kotlinx.coroutines.flow.flow
57+ import kotlinx.coroutines.flow.flowOn
4958import kotlinx.coroutines.flow.launchIn
5059import kotlinx.coroutines.flow.onCompletion
5160import kotlinx.coroutines.flow.onEach
@@ -67,11 +76,21 @@ internal constructor(
6776 private val firebaseApp: FirebaseApp ,
6877) {
6978 /* *
70- * Coroutine scope that we batch data on for [startAudioConversation].
79+ * Coroutine scope that we batch data on for network related behavior.
80+ *
81+ * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
82+ */
83+ private var networkScope = CancelledCoroutineScope
84+
85+ /* *
86+ * Coroutine scope that we batch data on for audio recording and playback.
87+ *
88+ * Separate from [networkScope] to ensure interchanging of dispatchers doesn't cause any deadlocks
89+ * or issues.
7190 *
7291 * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
7392 */
74- private var scope = CancelledCoroutineScope
93+ private var audioScope = CancelledCoroutineScope
7594
7695 /* *
7796 * Playback audio data sent from the model.
@@ -129,16 +148,17 @@ internal constructor(
129148 }
130149
131150 FirebaseAIException .catchAsync {
132- if (scope .isActive) {
151+ if (networkScope.isActive || audioScope .isActive) {
133152 Log .w(
134153 TAG ,
135154 " startAudioConversation called after the recording has already started. " +
136155 " Call stopAudioConversation to close the previous connection."
137156 )
138157 return @catchAsync
139158 }
140- // TODO: maybe it should be THREAD_PRIORITY_AUDIO anyways for playback and recording (not network though)
141- scope = CoroutineScope (blockingDispatcher + childJob() + CoroutineName (" LiveSession Scope" ))
159+ networkScope =
160+ CoroutineScope (blockingDispatcher + childJob() + CoroutineName (" LiveSession Network" ))
161+ audioScope = CoroutineScope (audioDispatcher + childJob() + CoroutineName (" LiveSession Audio" ))
142162 audioHelper = AudioHelper .build()
143163
144164 recordUserAudio()
@@ -158,7 +178,8 @@ internal constructor(
158178 FirebaseAIException .catch {
159179 if (! startedReceiving.getAndSet(false )) return @catch
160180
161- scope.cancel()
181+ networkScope.cancel()
182+ audioScope.cancel()
162183 playBackQueue.clear()
163184
164185 audioHelper?.release()
@@ -228,7 +249,8 @@ internal constructor(
228249 FirebaseAIException .catch {
229250 if (! startedReceiving.getAndSet(false )) return @catch
230251
231- scope.cancel()
252+ networkScope.cancel()
253+ audioScope.cancel()
232254 playBackQueue.clear()
233255
234256 audioHelper?.release()
@@ -325,21 +347,22 @@ internal constructor(
325347 audioHelper
326348 ?.listenToRecording()
327349 ?.buffer(UNLIMITED )
350+ ?.flowOn(audioDispatcher)
328351 ?.accumulateUntil(MIN_BUFFER_SIZE )
329352 ?.onEach {
330353 sendMediaStream(listOf (MediaData (it, " audio/pcm" )))
331354 delay(0 )
332355 }
333356 ?.catch { throw FirebaseAIException .from(it) }
334- ?.launchIn(scope )
357+ ?.launchIn(networkScope )
335358 }
336359
337360 /* *
338361 * Processes responses from the model during an audio conversation.
339362 *
340363 * Audio messages are added to [playBackQueue].
341364 *
342- * Launched asynchronously on [scope ].
365+ * Launched asynchronously on [networkScope ].
343366 *
344367 * @param functionCallHandler A callback function that is invoked whenever the server receives a
345368 * function call.
@@ -393,18 +416,18 @@ internal constructor(
393416 }
394417 }
395418 }
396- .launchIn(scope )
419+ .launchIn(networkScope )
397420 }
398421
399422 /* *
400423 * Listens for playback data from the model and plays the audio.
401424 *
402425 * Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
403426 *
404- * Launched asynchronously on [scope ].
427+ * Launched asynchronously on [audioScope ].
405428 */
406429 private fun listenForModelPlayback (enableInterruptions : Boolean = false) {
407- scope .launch {
430+ audioScope .launch {
408431 while (isActive) {
409432 val playbackData = playBackQueue.poll()
410433 if (playbackData == null ) {
@@ -490,5 +513,38 @@ internal constructor(
490513 AudioFormat .CHANNEL_OUT_MONO ,
491514 AudioFormat .ENCODING_PCM_16BIT
492515 )
516+ @SuppressLint(" ThreadPoolCreation" )
517+ val audioDispatcher =
518+ Executors .newCachedThreadPool(AudioThreadFactory ()).asCoroutineDispatcher()
519+ }
520+ }
521+
/**
 * [ThreadFactory] that produces the threads used for audio work.
 *
 * Thread creation is delegated to [DEFAULT]. The runnable handed to the delegate first sets the
 * thread priority to [Process.THREAD_PRIORITY_AUDIO] and installs the audio [ThreadPolicy], and
 * only then runs the task it was asked to execute.
 */
internal class AudioThreadFactory : ThreadFactory {
  /** Source of the sequence number that is embedded in each thread name. */
  private val threadCount = AtomicLong()

  /** StrictMode policy installed on every thread created by this factory. */
  private val policy: ThreadPolicy = audioPolicy()

  /**
   * Creates a named thread that configures itself for audio work before running [task].
   *
   * @param task the work to run on the new thread; a `null` task results in a thread that only
   * applies the priority and policy.
   * @return the thread obtained from [DEFAULT], renamed with the next sequence number.
   */
  override fun newThread(task: Runnable?): Thread? {
    // Priority and policy are applied from inside the runnable, i.e. on the newly created thread,
    // before the caller's task gets to run.
    val audioTask = Runnable {
      Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
      StrictMode.setThreadPolicy(policy)
      task?.run()
    }

    val worker = DEFAULT.newThread(audioTask)
    worker.name = " Firebase Audio Thread #${threadCount.andIncrement} "
    return worker
  }

  companion object {
    /** Delegate factory that actually instantiates the threads. */
    val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()

    /**
     * Builds the thread policy for audio threads.
     *
     * The policy detects network usage. `penaltyDeath` is added only when [BuildConfig.DEBUG] is
     * true; `penaltyLog` is always added.
     */
    private fun audioPolicy(): ThreadPolicy =
      ThreadPolicy.Builder()
        .detectNetwork()
        .apply { if (BuildConfig.DEBUG) penaltyDeath() }
        .penaltyLog()
        .build()
  }
}
0 commit comments