Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package com.openai.client.okhttp

import com.fasterxml.jackson.databind.json.JsonMapper
import com.openai.azure.AzureOpenAIServiceVersion
import com.openai.azure.AzureUrlPathMode
import com.openai.client.OpenAIClient
import com.openai.client.OpenAIClientImpl
import com.openai.core.ClientOptions
Expand All @@ -12,7 +13,6 @@ import com.openai.core.http.AsyncStreamResponse
import com.openai.core.http.Headers
import com.openai.core.http.HttpClient
import com.openai.core.http.QueryParams
import com.openai.core.jsonMapper
import com.openai.credential.Credential
import java.net.Proxy
import java.time.Clock
Expand Down Expand Up @@ -204,6 +204,10 @@ class OpenAIOkHttpClient private constructor() {
clientOptions.azureServiceVersion(azureServiceVersion)
}

fun azureUrlPathMode(azureUrlPathMode: AzureUrlPathMode) = apply {
clientOptions.azureUrlPathMode(azureUrlPathMode)
}

fun organization(organization: String?) = apply { clientOptions.organization(organization) }

/** Alias for calling [Builder.organization] with `organization.orElse(null)`. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package com.openai.client.okhttp

import com.fasterxml.jackson.databind.json.JsonMapper
import com.openai.azure.AzureOpenAIServiceVersion
import com.openai.azure.AzureUrlPathMode
import com.openai.client.OpenAIClientAsync
import com.openai.client.OpenAIClientAsyncImpl
import com.openai.core.ClientOptions
Expand All @@ -12,7 +13,6 @@ import com.openai.core.http.AsyncStreamResponse
import com.openai.core.http.Headers
import com.openai.core.http.HttpClient
import com.openai.core.http.QueryParams
import com.openai.core.jsonMapper
import com.openai.credential.Credential
import java.net.Proxy
import java.time.Clock
Expand Down Expand Up @@ -204,6 +204,10 @@ class OpenAIOkHttpClientAsync private constructor() {
clientOptions.azureServiceVersion(azureServiceVersion)
}

fun azureUrlPath(azureUrlPathMode: AzureUrlPathMode) = apply {
clientOptions.azureUrlPathMode(azureUrlPathMode)
}

fun organization(organization: String?) = apply { clientOptions.organization(organization) }

/** Alias for calling [Builder.organization] with `organization.orElse(null)`. */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.openai.azure

import java.net.URI

/** Represents the category of an Azure URL. */
internal enum class AzureUrlCategory {
/** Azure host _not_ ending with `/openai/v1`. */
AZURE_LEGACY,
/** Azure host ending with `/openai/v1`. */
AZURE_UNIFIED,
/** Anything else. */
NON_AZURE;

fun isAzure(): Boolean =
when (this) {
AZURE_LEGACY,
AZURE_UNIFIED -> true
NON_AZURE -> false
}

companion object {

fun categorizeBaseUrl(baseUrl: String, pathMode: AzureUrlPathMode): AzureUrlCategory {
val trimmedBaseUrl = baseUrl.trim().trimEnd('/')
val host = URI.create(trimmedBaseUrl).host
return when {
// Azure OpenAI resource URL with the old schema.
host.endsWith(".openai.azure.com", ignoreCase = true) ||
// Azure OpenAI resource URL with the OpenAI unified schema.
host.endsWith(".services.ai.azure.com", ignoreCase = true) ||
// Azure OpenAI resource URL, but with a schema different to the known ones.
host.endsWith(".azure-api.net", ignoreCase = true) ||
host.endsWith(".cognitiveservices.azure.com", ignoreCase = true) ->
when (pathMode) {
AzureUrlPathMode.LEGACY -> AZURE_LEGACY
AzureUrlPathMode.UNIFIED ->
if (trimmedBaseUrl.endsWith("/openai/v1")) AZURE_UNIFIED
else AZURE_LEGACY
}

else -> NON_AZURE
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.openai.azure

/**
* To force the deployment or model named to be part of the URL path for Azure OpenAI requests, use
* [AzureUrlPathMode.LEGACY]. The default is [AzureUrlPathMode.UNIFIED].
*/
enum class AzureUrlPathMode {
LEGACY,
UNIFIED,
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ package com.openai.azure

import com.openai.core.ClientOptions
import com.openai.core.http.HttpRequest
import com.openai.core.isAzureEndpoint
import com.openai.credential.BearerTokenCredential

@JvmSynthetic
internal fun HttpRequest.Builder.addPathSegmentsForAzure(
clientOptions: ClientOptions,
deploymentModel: String?,
): HttpRequest.Builder = apply {
if (isAzureEndpoint(clientOptions.baseUrl())) {
val urlCategory =
AzureUrlCategory.categorizeBaseUrl(clientOptions.baseUrl(), clientOptions.azureUrlPathMode)
if (urlCategory == AzureUrlCategory.AZURE_LEGACY) {
// Legacy known Azure endpoints are treated the old way.
addPathSegment("openai")
deploymentModel?.let { addPathSegments("deployments", it) }
}
Expand All @@ -20,10 +22,9 @@ internal fun HttpRequest.Builder.addPathSegmentsForAzure(
internal fun HttpRequest.Builder.replaceBearerTokenForAzure(
clientOptions: ClientOptions
): HttpRequest.Builder = apply {
if (
isAzureEndpoint(clientOptions.baseUrl()) &&
clientOptions.credential is BearerTokenCredential
) {
val urlCategory =
AzureUrlCategory.categorizeBaseUrl(clientOptions.baseUrl(), clientOptions.azureUrlPathMode)
if (urlCategory.isAzure() && clientOptions.credential is BearerTokenCredential) {
replaceHeaders("Authorization", "Bearer ${clientOptions.credential.token()}")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package com.openai.core

import com.fasterxml.jackson.databind.json.JsonMapper
import com.openai.azure.AzureOpenAIServiceVersion
import com.openai.azure.AzureUrlCategory
import com.openai.azure.AzureUrlPathMode
import com.openai.azure.credential.AzureApiKeyCredential
import com.openai.core.http.AsyncStreamResponse
import com.openai.core.http.Headers
Expand Down Expand Up @@ -98,6 +100,7 @@ private constructor(
@get:JvmName("maxRetries") val maxRetries: Int,
@get:JvmName("credential") val credential: Credential,
@get:JvmName("azureServiceVersion") val azureServiceVersion: AzureOpenAIServiceVersion?,
@get:JvmName("azureUrlPathMode") val azureUrlPathMode: AzureUrlPathMode,
private val organization: String?,
private val project: String?,
private val webhookSecret: String?,
Expand Down Expand Up @@ -163,6 +166,7 @@ private constructor(
private var maxRetries: Int = 2
private var credential: Credential? = null
private var azureServiceVersion: AzureOpenAIServiceVersion? = null
private var azureUrlPathMode: AzureUrlPathMode = AzureUrlPathMode.UNIFIED
private var organization: String? = null
private var project: String? = null
private var webhookSecret: String? = null
Expand All @@ -182,6 +186,7 @@ private constructor(
maxRetries = clientOptions.maxRetries
credential = clientOptions.credential
azureServiceVersion = clientOptions.azureServiceVersion
azureUrlPathMode = clientOptions.azureUrlPathMode
organization = clientOptions.organization
project = clientOptions.project
webhookSecret = clientOptions.webhookSecret
Expand Down Expand Up @@ -297,6 +302,10 @@ private constructor(
this.azureServiceVersion = azureServiceVersion
}

fun azureUrlPathMode(azureUrlPathMode: AzureUrlPathMode) = apply {
this.azureUrlPathMode = azureUrlPathMode
}

fun organization(organization: String?) = apply { this.organization = organization }

/** Alias for calling [Builder.organization] with `organization.orElse(null)`. */
Expand Down Expand Up @@ -485,14 +494,20 @@ private constructor(
}

baseUrl?.let {
if (isAzureEndpoint(it)) {
// Default Azure OpenAI version is used if Azure user doesn't
// specific a service API version in 'queryParams'.
replaceQueryParams(
"api-version",
(azureServiceVersion ?: AzureOpenAIServiceVersion.latestStableVersion())
.value,
)
when (AzureUrlCategory.categorizeBaseUrl(it, azureUrlPathMode)) {
// Legacy Azure routes will still require an api-version value.
AzureUrlCategory.AZURE_LEGACY ->
replaceQueryParams(
"api-version",
(azureServiceVersion ?: AzureOpenAIServiceVersion.latestStableVersion())
.value,
)
// We only add the value if it's defined by the user for unified Azure routes.
AzureUrlCategory.AZURE_UNIFIED ->
azureServiceVersion?.let { version ->
replaceQueryParams("api-version", version.value)
}
AzureUrlCategory.NON_AZURE -> {}
}
}

Expand Down Expand Up @@ -532,6 +547,7 @@ private constructor(
maxRetries,
credential,
azureServiceVersion,
azureUrlPathMode,
organization,
project,
webhookSecret,
Expand Down
11 changes: 0 additions & 11 deletions openai-java-core/src/main/kotlin/com/openai/core/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,4 @@ internal fun Any?.contentToString(): String {
return string
}

@JvmSynthetic
internal fun isAzureEndpoint(baseUrl: String): Boolean {
// Azure Endpoint should be in the format of `https://<region>.openai.azure.com`.
// Or `https://<region>.azure-api.net` for Azure OpenAI Management URL.
// Or `<user>-random-<region>.cognitiveservices.azure.com`.
val trimmedBaseUrl = baseUrl.trim().trimEnd('/')
return trimmedBaseUrl.endsWith(".openai.azure.com", true) ||
trimmedBaseUrl.endsWith(".azure-api.net", true) ||
trimmedBaseUrl.endsWith(".cognitiveservices.azure.com", true)
}

internal interface Enum
Loading
Loading