diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt index b199698aa7b..f98d5d3c403 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt @@ -67,6 +67,7 @@ import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch import kotlinx.coroutines.withTimeout import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.ClassDiscriminatorMode import kotlinx.serialization.json.Json @OptIn(ExperimentalSerializationApi::class) @@ -75,6 +76,7 @@ internal val JSON = Json { prettyPrint = false isLenient = true explicitNulls = false + classDiscriminatorMode = ClassDiscriminatorMode.NONE } /** diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt index 2b73d5ccfb1..65b753efda7 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/FunctionDeclaration.kt @@ -61,12 +61,12 @@ public class FunctionDeclaration( internal val schema: Schema = Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false) - internal fun toInternal() = Internal(name, description, schema.toInternal()) + internal fun toInternal() = Internal(name, description, schema.toInternalOpenApi()) @Serializable internal data class Internal( val name: String, val description: String, - val parameters: Schema.Internal + val parameters: Schema.InternalOpenAPI ) } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt index 7bab7fdf806..a496098787f 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/GenerationConfig.kt @@ -200,7 +200,7 @@ private constructor( frequencyPenalty = frequencyPenalty, presencePenalty = presencePenalty, responseMimeType = responseMimeType, - responseSchema = responseSchema?.toInternal(), + responseSchema = responseSchema?.toInternalOpenApi(), responseModalities = responseModalities?.map { it.toInternal() }, thinkingConfig = thinkingConfig?.toInternal() ) @@ -216,7 +216,7 @@ private constructor( @SerialName("response_mime_type") val responseMimeType: String? = null, @SerialName("presence_penalty") val presencePenalty: Float? = null, @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema.Internal? = null, + @SerialName("response_schema") val responseSchema: Schema.InternalOpenAPI? = null, @SerialName("response_modalities") val responseModalities: List? = null, @SerialName("thinking_config") val thinkingConfig: ThinkingConfig.Internal? = null ) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt index 4312fd5bdbd..f11436aad56 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt @@ -23,7 +23,6 @@ import java.io.ByteArrayOutputStream import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable -import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt index 5f2f6ca9350..1dfa4ddecb0 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt @@ -322,46 +322,147 @@ internal constructor( public fun anyOf(schemas: List): Schema = Schema(type = "ANYOF", anyOf = schemas) } - internal fun toInternal(): Internal { + internal fun toInternalOpenApi(): InternalOpenAPI { val cleanedType = if (type == "ANYOF") { null } else { type } - return Internal( + return InternalOpenAPI( cleanedType, description, format, nullable, enum, - properties?.mapValues { it.value.toInternal() }, + properties?.mapValues { it.value.toInternalOpenApi() }, required, - items?.toInternal(), + items?.toInternalOpenApi(), title, minItems, maxItems, minimum, maximum, - anyOf?.map { it.toInternal() }, + anyOf?.map { it.toInternalOpenApi() }, + ) + } + + internal fun toInternalJson(): InternalJson { + val outType = + if (type == "ANYOF" || (type == "STRING" && format == "enum")) { + null + } else { + type.lowercase() + } + + val (outMinimum, outMaximum) = + if (outType == "integer" && format == "int32") { + (minimum ?: Integer.MIN_VALUE.toDouble()) to (maximum ?: Integer.MAX_VALUE.toDouble()) + } else { + minimum to maximum + } + + val outFormat = + if ( + (outType == "integer" && format == "int32") || + (outType == "number" && format == "float") || + format == "enum" + ) { + null + } else { + format + } + + if (nullable == true) { + return InternalJsonNullable( + outType?.let { listOf(it, "null") }, + description, + outFormat, + enum?.let { + buildList { + addAll(it) + add("null") + } + }, + properties?.mapValues { it.value.toInternalJson() }, + required, + items?.toInternalJson(), + title, + minItems, + maxItems, + outMinimum, + outMaximum, + anyOf?.map { it.toInternalJson() }, + ) + } + return InternalJsonNonNull( + outType, + description, + outFormat, + enum, + properties?.mapValues { it.value.toInternalJson() }, + required, + items?.toInternalJson(), + title, + minItems, + maxItems, + outMinimum, + outMaximum, + anyOf?.map { it.toInternalJson() }, ) } @Serializable - internal data class Internal( + internal data class InternalOpenAPI( val type: String? = null, val description: String? = null, val format: String? = null, val nullable: Boolean? = false, val enum: List? = null, - val properties: Map? = null, + val properties: Map? = null, val required: List? = null, - val items: Internal? = null, + val items: InternalOpenAPI? = null, val title: String? = null, val minItems: Int? = null, val maxItems: Int? = null, val minimum: Double? = null, val maximum: Double? = null, - val anyOf: List? = null, + val anyOf: List? = null, ) + + @Serializable internal sealed interface InternalJson + + @Serializable + internal data class InternalJsonNonNull( + val type: String? = null, + val description: String? = null, + val format: String? = null, + val enum: List? = null, + val properties: Map? = null, + val required: List? = null, + val items: InternalJson? = null, + val title: String? = null, + val minItems: Int? = null, + val maxItems: Int? = null, + val minimum: Double? = null, + val maximum: Double? = null, + val anyOf: List? = null, + ) : InternalJson + + @Serializable + internal data class InternalJsonNullable( + val type: List? = null, + val description: String? = null, + val format: String? = null, + val enum: List? = null, + val properties: Map? = null, + val required: List? = null, + val items: InternalJson? = null, + val title: String? = null, + val minItems: Int? = null, + val maxItems: Int? = null, + val minimum: Double? = null, + val maximum: Double? = null, + val anyOf: List? = null, + ) : InternalJson } diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt index f9bdf8c835f..b275807d9b5 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt @@ -19,7 +19,9 @@ package com.google.firebase.ai import com.google.firebase.ai.type.Schema import com.google.firebase.ai.type.StringFormat import io.kotest.assertions.json.shouldEqualJson +import java.io.File import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.ClassDiscriminatorMode import kotlinx.serialization.json.Json import org.junit.Test @@ -93,7 +95,7 @@ internal class SchemaTests { """ .trimIndent() - Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson) + Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson) } @Test @@ -216,6 +218,70 @@ internal class SchemaTests { """ .trimIndent() - Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson) + Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson) } + + @Test + fun `schema encoding openAPI spec test`() { + val expectedSerialization = getSchemaJson("open-api-schema.json") + val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalOpenApi()) + serializedSchema.shouldEqualJson(expectedSerialization) + } + + @Test + fun `schema encoding jsonSchema spec test`() { + val expectedSerialization = getSchemaJson("json-schema.json") + val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalJson()) + serializedSchema.shouldEqualJson(expectedSerialization) + } + + internal fun getSchemaJson(filename: String): String { + return File("src/test/resources/vertexai-sdk-test-data/mock-responses/schema/${filename}") + .readText() + } + + private val JSON_ENCODER = Json { classDiscriminatorMode = ClassDiscriminatorMode.NONE } + + private val TEST_SCHEMA = + Schema.obj( + properties = + mapOf( + "integerTest" to Schema.integer(title = "integerTest", nullable = true), + "longTest" to + Schema.long( + title = "longTest", + nullable = false, + minimum = 0.0, + maximum = 5.0, + description = "a test long" + ), + "floatTest" to Schema.float(title = "floatTest", nullable = false), + "doubleTest" to Schema.double(title = "doubleTest", nullable = true), + "listTest" to + Schema.array( + items = Schema.integer(nullable = false), + title = "listTest", + nullable = false, + minItems = 0, + maxItems = 5 + ), + "booleanTest" to Schema.boolean(title = "booleanTest", nullable = false), + "stringTest" to + Schema.string(title = "stringTest", format = StringFormat.Custom("email")), + "objTest" to + Schema.obj( + properties = + mapOf( + "testInt" to Schema.integer(title = "testInt", nullable = false), + ), + title = "objTest", + description = "class kdoc should be used if property kdocs aren't present", + nullable = false + ), + "enumTest" to Schema.enumeration(values = listOf("val1", "val2", "val3")) + ), + optionalProperties = listOf("booleanTest"), + description = "A test kdoc", + nullable = false + ) } diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt index e6bea7a9e6f..476d68261d2 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt @@ -437,7 +437,7 @@ internal class SerializationTests { } """ .trimIndent() - val actualJson = descriptorToJson(Schema.Internal.serializer().descriptor) + val actualJson = descriptorToJson(Schema.InternalOpenAPI.serializer().descriptor) expectedJsonAsString shouldEqualJson actualJson.toString() } diff --git a/firebase-ai/update_responses.sh b/firebase-ai/update_responses.sh index baf2676dfbb..881d63e2c53 100755 --- a/firebase-ai/update_responses.sh +++ b/firebase-ai/update_responses.sh @@ -17,7 +17,7 @@ # This script replaces mock response files for Vertex AI unit tests with a fresh # clone of the shared repository of Vertex AI test data. -RESPONSES_VERSION='v14.*' # The major version of mock responses to use +RESPONSES_VERSION='v15.*' # The major version of mock responses to use REPO_NAME="vertexai-sdk-test-data" REPO_LINK="https://github.com/FirebaseExtended/$REPO_NAME.git"