Refactor live bidi #6870

Merged
merged 43 commits into from
Apr 15, 2025
Commits
fa6fcf4
Temp Stash
daymxn Apr 11, 2025
9654496
Fix audiohelper bug
daymxn Apr 11, 2025
c906509
Rename class
daymxn Apr 11, 2025
909a730
Remove mediadata
daymxn Apr 11, 2025
d3bf812
Remove id changes, but leave a TODO
daymxn Apr 11, 2025
19147e3
Remove testing artifacts
daymxn Apr 11, 2025
61f6f4c
Update APIController.kt
daymxn Apr 11, 2025
6df1ef0
Formatting
daymxn Apr 11, 2025
b3e5bd9
Add docs to AudioHelper
daymxn Apr 11, 2025
28171b4
Update AudioHelper.kt
daymxn Apr 11, 2025
5df85dc
Document accumulateUntil
daymxn Apr 11, 2025
1bed4eb
Add back stopReceiving
daymxn Apr 11, 2025
89ec5a9
Add additional documentation
daymxn Apr 11, 2025
b3b9678
Cleanup javadocs
daymxn Apr 11, 2025
fbae7cb
Use blocking instead of background dispatcher
daymxn Apr 11, 2025
3d390a1
Emit empty buffer if no data is read
daymxn Apr 11, 2025
51d1fec
Add documentation for util methods
daymxn Apr 11, 2025
7de0eb7
Decode setupComplete to json
daymxn Apr 11, 2025
b137cd9
Update javadocs
daymxn Apr 11, 2025
3b408dc
Update java javadocs to match kotlin
daymxn Apr 11, 2025
cc8969a
Update CHANGELOG.md
daymxn Apr 11, 2025
8c99ced
fmt
daymxn Apr 11, 2025
c9dda5c
Add missing copyright
daymxn Apr 11, 2025
a0cb879
Update api.txt
daymxn Apr 11, 2025
b79f27f
Add back MediaData
daymxn Apr 14, 2025
d534ef0
Update CHANGELOG.md
daymxn Apr 14, 2025
73a0e3b
Merge branch 'main' into daymon-update-bidi
daymxn Apr 14, 2025
97d46cb
Use ByteArrayOutputStream
daymxn Apr 15, 2025
9fae0b8
Add catching for exceptions
daymxn Apr 15, 2025
0df186c
Handle the return value on write to playback track
daymxn Apr 15, 2025
b90e10d
Add note about startAudioConversation and sendFunctionResponse
daymxn Apr 15, 2025
f40dcfa
Catch recorder.stop exception
daymxn Apr 15, 2025
795dcfe
Add docs for exceptions thrown
daymxn Apr 15, 2025
e6475c7
Use a fold instead of a collect
daymxn Apr 15, 2025
b7b0f13
Merge branch 'main' into daymon-update-bidi
daymxn Apr 15, 2025
2e2892a
Update AudioHelper.kt
daymxn Apr 15, 2025
87f1ffa
Update android.kt
daymxn Apr 15, 2025
1ff1618
Update LiveSession.kt
daymxn Apr 15, 2025
644a50b
Update android.kt
daymxn Apr 15, 2025
0a1a48a
Merge branch 'main' into daymon-update-bidi
daymxn Apr 15, 2025
8f199ef
Revert "Update android.kt"
daymxn Apr 15, 2025
e38febd
Update android.kt
daymxn Apr 15, 2025
ab359fb
Update LiveSessionFutures.kt
daymxn Apr 15, 2025
5 changes: 5 additions & 0 deletions firebase-vertexai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
* [feature] Added support for `HarmBlockThreshold.OFF`. See the
[model documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters#how_to_configure_content_filters){: .external}
for more information.
+ * [fixed] Improved thread usage when using a `LiveGenerativeModel`. (#6870)
+ * [fixed] Fixed an issue with `LiveContentResponse` audio data not being present when the model was
+   interrupted or the turn completed. (#6870)
+ * [fixed] Fixed an issue with `LiveSession` not converting exceptions to `FirebaseVertexAIException`. (#6870)


# 16.3.0
* [feature] Emits a warning when attempting to use an incompatible model with
2 changes: 1 addition & 1 deletion firebase-vertexai/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ package com.google.firebase.vertexai.type {
method public suspend Object? send(String text, kotlin.coroutines.Continuation<? super kotlin.Unit>);
method public suspend Object? sendFunctionResponse(java.util.List<com.google.firebase.vertexai.type.FunctionResponsePart> functionList, kotlin.coroutines.Continuation<? super kotlin.Unit>);
method public suspend Object? sendMediaStream(java.util.List<com.google.firebase.vertexai.type.MediaData> mediaChunks, kotlin.coroutines.Continuation<? super kotlin.Unit>);
- method public suspend Object? startAudioConversation(kotlin.jvm.functions.Function1<? super com.google.firebase.vertexai.type.FunctionCallPart,com.google.firebase.vertexai.type.FunctionResponsePart>? functionCallHandler = null, kotlin.coroutines.Continuation<? super kotlin.Unit>);
+ method @RequiresPermission(android.Manifest.permission.RECORD_AUDIO) public suspend Object? startAudioConversation(kotlin.jvm.functions.Function1<? super com.google.firebase.vertexai.type.FunctionCallPart,com.google.firebase.vertexai.type.FunctionResponsePart>? functionCallHandler = null, kotlin.coroutines.Continuation<? super kotlin.Unit>);
method public void stopAudioConversation();
method public void stopReceiving();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.google.firebase.vertexai
import android.util.Log
import com.google.firebase.Firebase
import com.google.firebase.FirebaseApp
- import com.google.firebase.annotations.concurrent.Background
+ import com.google.firebase.annotations.concurrent.Blocking
import com.google.firebase.app
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
Expand All @@ -41,7 +41,7 @@ import kotlin.coroutines.CoroutineContext
public class FirebaseVertexAI
internal constructor(
private val firebaseApp: FirebaseApp,
- @Background private val backgroundDispatcher: CoroutineContext,
+ @Blocking private val blockingDispatcher: CoroutineContext,
private val location: String,
private val appCheckProvider: Provider<InteropAppCheckTokenProvider>,
private val internalAuthProvider: Provider<InternalAuthProvider>,
Expand Down Expand Up @@ -133,7 +133,7 @@ internal constructor(
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
firebaseApp.options.apiKey,
firebaseApp,
- backgroundDispatcher,
+ blockingDispatcher,
generationConfig,
tools,
systemInstruction,
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.google.firebase.vertexai

import androidx.annotation.GuardedBy
import com.google.firebase.FirebaseApp
- import com.google.firebase.annotations.concurrent.Background
+ import com.google.firebase.annotations.concurrent.Blocking
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.inject.Provider
Expand All @@ -31,7 +31,7 @@ import kotlin.coroutines.CoroutineContext
*/
internal class FirebaseVertexAIMultiResourceComponent(
private val app: FirebaseApp,
- @Background val backgroundDispatcher: CoroutineContext,
+ @Blocking val blockingDispatcher: CoroutineContext,
private val appCheckProvider: Provider<InteropAppCheckTokenProvider>,
private val internalAuthProvider: Provider<InternalAuthProvider>,
) {
Expand All @@ -43,7 +43,7 @@ internal class FirebaseVertexAIMultiResourceComponent(
instances[location]
?: FirebaseVertexAI(
app,
- backgroundDispatcher,
+ blockingDispatcher,
location,
appCheckProvider,
internalAuthProvider
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.google.firebase.vertexai

import androidx.annotation.Keep
import com.google.firebase.FirebaseApp
- import com.google.firebase.annotations.concurrent.Background
+ import com.google.firebase.annotations.concurrent.Blocking
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.components.Component
Expand All @@ -41,13 +41,13 @@ internal class FirebaseVertexAIRegistrar : ComponentRegistrar {
Component.builder(FirebaseVertexAIMultiResourceComponent::class.java)
.name(LIBRARY_NAME)
.add(Dependency.required(firebaseApp))
- .add(Dependency.required(backgroundDispatcher))
+ .add(Dependency.required(blockingDispatcher))
.add(Dependency.optionalProvider(appCheckInterop))
.add(Dependency.optionalProvider(internalAuthProvider))
.factory { container ->
FirebaseVertexAIMultiResourceComponent(
container[firebaseApp],
- container.get(backgroundDispatcher),
+ container.get(blockingDispatcher),
container.getProvider(appCheckInterop),
container.getProvider(internalAuthProvider)
)
Expand All @@ -62,7 +62,7 @@ internal class FirebaseVertexAIRegistrar : ComponentRegistrar {
private val firebaseApp = unqualified(FirebaseApp::class.java)
private val appCheckInterop = unqualified(InteropAppCheckTokenProvider::class.java)
private val internalAuthProvider = unqualified(InternalAuthProvider::class.java)
- private val backgroundDispatcher =
-   Qualified.qualified(Background::class.java, CoroutineDispatcher::class.java)
+ private val blockingDispatcher =
+   Qualified.qualified(Blocking::class.java, CoroutineDispatcher::class.java)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
package com.google.firebase.vertexai

import com.google.firebase.FirebaseApp
- import com.google.firebase.annotations.concurrent.Background
+ import com.google.firebase.annotations.concurrent.Blocking
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.AppCheckHeaderProvider
- import com.google.firebase.vertexai.type.BidiGenerateContentClientMessage
+ import com.google.firebase.vertexai.common.JSON
import com.google.firebase.vertexai.type.Content
+ import com.google.firebase.vertexai.type.LiveClientSetupMessage
import com.google.firebase.vertexai.type.LiveGenerationConfig
import com.google.firebase.vertexai.type.LiveSession
import com.google.firebase.vertexai.type.PublicPreviewAPI
Expand All @@ -38,6 +39,7 @@ import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
+ import kotlinx.serialization.json.JsonObject

/**
* Represents a multimodal model (like Gemini) capable of real-time content generation based on
Expand All @@ -47,7 +49,7 @@ import kotlinx.serialization.json.Json
public class LiveGenerativeModel
internal constructor(
private val modelName: String,
- @Background private val backgroundDispatcher: CoroutineContext,
+ @Blocking private val blockingDispatcher: CoroutineContext,
private val config: LiveGenerationConfig? = null,
private val tools: List<Tool>? = null,
private val systemInstruction: Content? = null,
Expand All @@ -58,7 +60,7 @@ internal constructor(
modelName: String,
apiKey: String,
firebaseApp: FirebaseApp,
- backgroundDispatcher: CoroutineContext,
+ blockingDispatcher: CoroutineContext,
config: LiveGenerationConfig? = null,
tools: List<Tool>? = null,
systemInstruction: Content? = null,
Expand All @@ -68,7 +70,7 @@ internal constructor(
internalAuthProvider: InternalAuthProvider? = null,
) : this(
modelName,
- backgroundDispatcher,
+ blockingDispatcher,
config,
tools,
systemInstruction,
Expand All @@ -93,7 +95,7 @@ internal constructor(
@OptIn(ExperimentalSerializationApi::class)
public suspend fun connect(): LiveSession {
val clientMessage =
- BidiGenerateContentClientMessage(
+ LiveClientSetupMessage(
modelName,
config?.toInternal(),
tools?.map { it.toInternal() },
Expand All @@ -104,10 +106,11 @@ internal constructor(
try {
val webSession = controller.getWebSocketSession(location)
webSession.send(Frame.Text(data))
- val receivedJson = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8)
- // TODO: Try to decode the json instead of string matching.
- return if (receivedJson.contains("setupComplete")) {
-   LiveSession(session = webSession, backgroundDispatcher = backgroundDispatcher)
+ val receivedJsonStr = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8)
+ val receivedJson = JSON.parseToJsonElement(receivedJsonStr)
+
+ return if (receivedJson is JsonObject && "setupComplete" in receivedJson) {
+   LiveSession(session = webSession, blockingDispatcher = blockingDispatcher)
} else {
webSession.close()
throw ServiceConnectionHandshakeFailedException("Unable to connect to the server")
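
The new handshake check can be tried in isolation. What follows is a minimal sketch, assuming the default `Json` instance from kotlinx.serialization stands in for the SDK's internal `JSON` object. `isSetupComplete` is a hypothetical helper and not part of this PR.

```kotlin
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject

// Hypothetical helper (not in the PR): true only when the payload is a JSON
// object with a top-level "setupComplete" key. Treating malformed JSON as a
// failed handshake is an assumption of this sketch.
fun isSetupComplete(payload: String): Boolean {
  val element =
    try {
      Json.parseToJsonElement(payload)
    } catch (e: SerializationException) {
      return false
    }
  return element is JsonObject && "setupComplete" in element
}

fun main() {
  println(isSetupComplete("""{"setupComplete": {}}"""))
  println(isSetupComplete("""{"error": "setupComplete was not sent"}"""))
  println(isSetupComplete("\"setupComplete\""))
}
```

The second and third payloads both contain the substring `setupComplete`, so the removed `contains` check would have accepted them. The decoded check accepts only the first.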
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ internal constructor(

suspend fun getWebSocketSession(location: String): ClientWebSocketSession =
client.webSocketSession(getBidiEndpoint(location))

fun generateContentStream(
request: GenerateContentRequest
): Flow<GenerateContentResponse.Internal> =
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.common.util

import android.media.AudioRecord
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.yield

/**
 * The minimum buffer size for this instance.
 *
 * The same as calling [AudioRecord.getMinBufferSize], except the params are pre-populated.
 */
internal val AudioRecord.minBufferSize: Int
  get() = AudioRecord.getMinBufferSize(sampleRate, channelConfiguration, audioFormat)

/**
 * Reads from this [AudioRecord] and returns the data in a flow.
 *
 * Will yield when this instance is not recording.
 */
internal fun AudioRecord.readAsFlow() = flow {
  val buffer = ByteArray(minBufferSize)

  while (true) {
    if (recordingState != AudioRecord.RECORDSTATE_RECORDING) {
      yield()
      continue
    }

    val bytesRead = read(buffer, 0, buffer.size)
    if (bytesRead > 0) {
      emit(buffer.copyOf(bytesRead))
    }
  }
}
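
`readAsFlow` reuses one buffer but emits `buffer.copyOf(bytesRead)`, so each downstream chunk is an independent array of exactly the bytes read. The sketch below shows the same pattern without Android or coroutines, assuming a plain `InputStream` as the source. `readChunks` is a hypothetical name, and the end-of-stream `break` is an addition of this sketch (the loop in the diff has no exit of its own).

```kotlin
import java.io.ByteArrayInputStream
import java.io.InputStream

// Hypothetical stand-in for AudioRecord.readAsFlow: one reusable buffer, but
// every yielded chunk is a fresh copy sized to the bytes actually read.
fun InputStream.readChunks(bufferSize: Int): Sequence<ByteArray> = sequence {
  val buffer = ByteArray(bufferSize)
  while (true) {
    val bytesRead = read(buffer, 0, buffer.size)
    if (bytesRead < 0) break // End of stream: an assumption of this sketch.
    if (bytesRead > 0) yield(buffer.copyOf(bytesRead))
  }
}

fun main() {
  val input = ByteArrayInputStream(byteArrayOf(1, 2, 3, 4, 5))
  val chunks = input.readChunks(bufferSize = 2).toList()
  println(chunks.map { it.toList() }) // prints [[1, 2], [3, 4], [5]]
}
```

Yielding `buffer` itself instead of a copy would make every collected chunk alias the same array, so earlier chunks would be overwritten by later reads.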
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@

package com.google.firebase.vertexai.common.util

import java.io.ByteArrayOutputStream
import java.lang.reflect.Field
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.fold

/**
* Removes the last character from the [StringBuilder].
Expand All @@ -39,3 +48,56 @@ internal fun StringBuilder.removeLast(): StringBuilder =
* ```
*/
internal inline fun <reified T : Annotation> Field.getAnnotation() = getAnnotation(T::class.java)

/**
 * Collects bytes from this flow and doesn't emit them back until [minSize] is reached.
 *
 * For example:
 * ```
 * val byteArr = flowOf(byteArrayOf(1), byteArrayOf(2, 3, 4), byteArrayOf(5, 6, 7, 8))
 * val expectedResult = listOf(byteArrayOf(1, 2, 3, 4), byteArrayOf(5, 6, 7, 8))
 *
 * byteArr.accumulateUntil(4).toList() shouldContainExactly expectedResult
 * ```
 *
 * @param minSize The minimum amount of bytes the array should have before being sent downstream
 * @param emitLeftOvers If the flow completes and there are bytes left over that don't meet the
 * [minSize], send them anyway.
 */
internal fun Flow<ByteArray>.accumulateUntil(
  minSize: Int,
  emitLeftOvers: Boolean = false
): Flow<ByteArray> = flow {
  val remaining =
    fold(ByteArrayOutputStream()) { buffer, it ->
      buffer.apply {
        write(it, 0, it.size)
        if (size() >= minSize) {
          emit(toByteArray())
          reset()
        }
      }
    }

  if (emitLeftOvers && remaining.size() > 0) {
    emit(remaining.toByteArray())
  }
}
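
Two details of `accumulateUntil` are easy to miss: an emitted array can be larger than `minSize`, and trailing bytes are dropped unless `emitLeftOvers` is set. The sketch below applies the same buffering rule to a plain `List<ByteArray>` so it runs without coroutines. `accumulateChunks` is a hypothetical analogue for illustration, not the function from the PR.

```kotlin
import java.io.ByteArrayOutputStream

// Hypothetical list-based analogue of accumulateUntil, for illustration only.
fun List<ByteArray>.accumulateChunks(
  minSize: Int,
  emitLeftOvers: Boolean = false
): List<ByteArray> {
  val out = mutableListOf<ByteArray>()
  val buffer = ByteArrayOutputStream()
  for (chunk in this) {
    buffer.write(chunk, 0, chunk.size)
    // Flush everything buffered so far once the threshold is met.
    if (buffer.size() >= minSize) {
      out.add(buffer.toByteArray())
      buffer.reset()
    }
  }
  // Bytes below the threshold survive only when explicitly requested.
  if (emitLeftOvers && buffer.size() > 0) out.add(buffer.toByteArray())
  return out
}

fun main() {
  val chunks = listOf(byteArrayOf(1), byteArrayOf(2, 3, 4, 5, 6), byteArrayOf(7))
  println(chunks.accumulateChunks(4).map { it.toList() })
  println(chunks.accumulateChunks(4, emitLeftOvers = true).map { it.toList() })
}
```

With a threshold of 4, the first two chunks are flushed together as one 6-byte array. The final single byte appears only in the second call.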

/**
 * Create a [Job] that is a child of the [currentCoroutineContext], if any.
 *
 * This is useful when you want a coroutine scope to be canceled when its parent scope is canceled,
 * and you don't have full control over the parent scope, but you don't want the cancellation of the
 * child to impact the parent.
 *
 * If the parent coroutine context does not have a job, an empty one will be created.
 */
internal suspend inline fun childJob() = Job(currentCoroutineContext()[Job] ?: Job())

/**
 * A constant value pointing to a cancelled [CoroutineScope].
 *
 * Useful when you want to initialize a mutable [CoroutineScope] in a canceled state.
 */
internal val CancelledCoroutineScope = CoroutineScope(EmptyCoroutineContext).apply { cancel() }
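
The parent/child behavior that `childJob` relies on, and the idea behind `CancelledCoroutineScope`, can be checked with kotlinx.coroutines alone. This is a minimal sketch using only public `Job` and `CoroutineScope` APIs. The variable names are illustrative, and how the SDK actually uses these helpers is not shown in this diff.

```kotlin
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive

fun main() {
  val parent = Job()

  // Cancelling a child Job leaves the parent active.
  val firstChild = Job(parent)
  firstChild.cancel()
  println("parent active after child cancel: ${parent.isActive}")

  // Cancelling the parent cancels its remaining children.
  val secondChild = Job(parent)
  parent.cancel()
  println("child cancelled after parent cancel: ${secondChild.isCancelled}")

  // A scope that starts out cancelled, built the same way as CancelledCoroutineScope.
  var scope: CoroutineScope = CoroutineScope(EmptyCoroutineContext).apply { cancel() }
  println("initial scope active: ${scope.isActive}")

  // Later, swap in a live scope when work actually starts.
  scope = CoroutineScope(Job())
  println("replacement scope active: ${scope.isActive}")
}
```

Starting from a cancelled scope lets a mutable field stay non-nullable while still reporting `isActive == false` until real work begins.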