Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.ClassDiscriminatorMode
import kotlinx.serialization.json.Json

@OptIn(ExperimentalSerializationApi::class)
Expand All @@ -75,6 +76,7 @@ internal val JSON = Json {
prettyPrint = false
isLenient = true
explicitNulls = false
classDiscriminatorMode = ClassDiscriminatorMode.NONE
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ public class FunctionDeclaration(
internal val schema: Schema =
Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false)

internal fun toInternal() = Internal(name, description, schema.toInternal())
internal fun toInternal() = Internal(name, description, schema.toInternalOpenApi())

@Serializable
internal data class Internal(
val name: String,
val description: String,
val parameters: Schema.Internal
val parameters: Schema.InternalOpenAPI
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ private constructor(
frequencyPenalty = frequencyPenalty,
presencePenalty = presencePenalty,
responseMimeType = responseMimeType,
responseSchema = responseSchema?.toInternal(),
responseSchema = responseSchema?.toInternalOpenApi(),
responseModalities = responseModalities?.map { it.toInternal() },
thinkingConfig = thinkingConfig?.toInternal()
)
Expand All @@ -216,7 +216,7 @@ private constructor(
@SerialName("response_mime_type") val responseMimeType: String? = null,
@SerialName("presence_penalty") val presencePenalty: Float? = null,
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
@SerialName("response_schema") val responseSchema: Schema.Internal? = null,
@SerialName("response_schema") val responseSchema: Schema.InternalOpenAPI? = null,
@SerialName("response_modalities") val responseModalities: List<String>? = null,
@SerialName("thinking_config") val thinkingConfig: ThinkingConfig.Internal? = null
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import java.io.ByteArrayOutputStream
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
Expand Down
119 changes: 110 additions & 9 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt
Original file line number Diff line number Diff line change
Expand Up @@ -322,46 +322,147 @@ internal constructor(
public fun anyOf(schemas: List<Schema>): Schema = Schema(type = "ANYOF", anyOf = schemas)
}

internal fun toInternal(): Internal {
internal fun toInternalOpenApi(): InternalOpenAPI {
val cleanedType =
if (type == "ANYOF") {
null
} else {
type
}
return Internal(
return InternalOpenAPI(
cleanedType,
description,
format,
nullable,
enum,
properties?.mapValues { it.value.toInternal() },
properties?.mapValues { it.value.toInternalOpenApi() },
required,
items?.toInternal(),
items?.toInternalOpenApi(),
title,
minItems,
maxItems,
minimum,
maximum,
anyOf?.map { it.toInternal() },
anyOf?.map { it.toInternalOpenApi() },
)
}

internal fun toInternalJson(): InternalJson {
val outType =
if (type == "ANYOF" || (type == "STRING" && format == "enum")) {
null
} else {
type.lowercase()
}

val (outMinimum, outMaximum) =
if (outType == "integer" && format == "int32") {
(minimum ?: Integer.MIN_VALUE.toDouble()) to (maximum ?: Integer.MAX_VALUE.toDouble())
} else {
minimum to maximum
}

val outFormat =
if (
(outType == "integer" && format == "int32") ||
(outType == "number" && format == "float") ||
format == "enum"
) {
null
} else {
format
}

if (nullable == true) {
return InternalJsonNullable(
outType?.let { listOf(it, "null") },
description,
outFormat,
enum?.let {
buildList {
addAll(it)
add("null")
}
},
properties?.mapValues { it.value.toInternalJson() },
required,
items?.toInternalJson(),
title,
minItems,
maxItems,
outMinimum,
outMaximum,
anyOf?.map { it.toInternalJson() },
)
}
return InternalJsonNonNull(
outType,
description,
outFormat,
enum,
properties?.mapValues { it.value.toInternalJson() },
required,
items?.toInternalJson(),
title,
minItems,
maxItems,
outMinimum,
outMaximum,
anyOf?.map { it.toInternalJson() },
)
}

@Serializable
internal data class Internal(
internal data class InternalOpenAPI(
val type: String? = null,
val description: String? = null,
val format: String? = null,
val nullable: Boolean? = false,
val enum: List<String>? = null,
val properties: Map<String, Internal>? = null,
val properties: Map<String, InternalOpenAPI>? = null,
val required: List<String>? = null,
val items: Internal? = null,
val items: InternalOpenAPI? = null,
val title: String? = null,
val minItems: Int? = null,
val maxItems: Int? = null,
val minimum: Double? = null,
val maximum: Double? = null,
val anyOf: List<Internal>? = null,
val anyOf: List<InternalOpenAPI>? = null,
)

@Serializable internal sealed interface InternalJson

@Serializable
internal data class InternalJsonNonNull(
val type: String? = null,
val description: String? = null,
val format: String? = null,
val enum: List<String>? = null,
val properties: Map<String, InternalJson>? = null,
val required: List<String>? = null,
val items: InternalJson? = null,
val title: String? = null,
val minItems: Int? = null,
val maxItems: Int? = null,
val minimum: Double? = null,
val maximum: Double? = null,
val anyOf: List<InternalJson>? = null,
) : InternalJson

@Serializable
internal data class InternalJsonNullable(
val type: List<String>? = null,
val description: String? = null,
val format: String? = null,
val enum: List<String>? = null,
val properties: Map<String, InternalJson>? = null,
val required: List<String>? = null,
val items: InternalJson? = null,
val title: String? = null,
val minItems: Int? = null,
val maxItems: Int? = null,
val minimum: Double? = null,
val maximum: Double? = null,
val anyOf: List<InternalJson>? = null,
) : InternalJson
}
70 changes: 68 additions & 2 deletions firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package com.google.firebase.ai
import com.google.firebase.ai.type.Schema
import com.google.firebase.ai.type.StringFormat
import io.kotest.assertions.json.shouldEqualJson
import java.io.File
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.ClassDiscriminatorMode
import kotlinx.serialization.json.Json
import org.junit.Test

Expand Down Expand Up @@ -93,7 +95,7 @@ internal class SchemaTests {
"""
.trimIndent()

Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson)
}

@Test
Expand Down Expand Up @@ -216,6 +218,70 @@ internal class SchemaTests {
"""
.trimIndent()

Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
Json.encodeToString(schemaDeclaration.toInternalOpenApi()).shouldEqualJson(expectedJson)
}

@Test
fun `schema encoding openAPI spec test`() {
val expectedSerialization = getSchemaJson("open-api-schema.json")
val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalOpenApi())
serializedSchema.shouldEqualJson(expectedSerialization)
}

@Test
fun `schema encoding jsonSchema spec test`() {
val expectedSerialization = getSchemaJson("json-schema.json")
val serializedSchema = JSON_ENCODER.encodeToString(TEST_SCHEMA.toInternalJson())
serializedSchema.shouldEqualJson(expectedSerialization)
}

internal fun getSchemaJson(filename: String): String {
return File("src/test/resources/vertexai-sdk-test-data/mock-responses/schema/${filename}")
.readText()
}

private val JSON_ENCODER = Json { classDiscriminatorMode = ClassDiscriminatorMode.NONE }

private val TEST_SCHEMA =
Schema.obj(
properties =
mapOf(
"integerTest" to Schema.integer(title = "integerTest", nullable = true),
"longTest" to
Schema.long(
title = "longTest",
nullable = false,
minimum = 0.0,
maximum = 5.0,
description = "a test long"
),
"floatTest" to Schema.float(title = "floatTest", nullable = false),
"doubleTest" to Schema.double(title = "doubleTest", nullable = true),
"listTest" to
Schema.array(
items = Schema.integer(nullable = false),
title = "listTest",
nullable = false,
minItems = 0,
maxItems = 5
),
"booleanTest" to Schema.boolean(title = "booleanTest", nullable = false),
"stringTest" to
Schema.string(title = "stringTest", format = StringFormat.Custom("email")),
"objTest" to
Schema.obj(
properties =
mapOf(
"testInt" to Schema.integer(title = "testInt", nullable = false),
),
title = "objTest",
description = "class kdoc should be used if property kdocs aren't present",
nullable = false
),
"enumTest" to Schema.enumeration(values = listOf("val1", "val2", "val3"))
),
optionalProperties = listOf("booleanTest"),
description = "A test kdoc",
nullable = false
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ internal class SerializationTests {
}
"""
.trimIndent()
val actualJson = descriptorToJson(Schema.Internal.serializer().descriptor)
val actualJson = descriptorToJson(Schema.InternalOpenAPI.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}

Expand Down
2 changes: 1 addition & 1 deletion firebase-ai/update_responses.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# This script replaces mock response files for Vertex AI unit tests with a fresh
# clone of the shared repository of Vertex AI test data.

RESPONSES_VERSION='v14.*' # The major version of mock responses to use
RESPONSES_VERSION='v15.*' # The major version of mock responses to use
REPO_NAME="vertexai-sdk-test-data"
REPO_LINK="https://github.com/FirebaseExtended/$REPO_NAME.git"

Expand Down
Loading