From d7d991dfc95577f0b5c4f3c038b998a16133d145 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 17 Dec 2024 17:00:02 -0800 Subject: [PATCH 1/5] temp hack to make realtime work with azure --- src/openai/lib/azure.py | 2 + .../resources/beta/realtime/realtime.py | 38 +++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 13d9f31838..ee32f13b95 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -224,6 +224,7 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + self._azure_deployment = azure_deployment @override def copy( @@ -471,6 +472,7 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + self._azure_deployment = azure_deployment @override def copy( diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index c79fd46217..d940ed44ae 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -319,11 +319,26 @@ async def __aenter__(self) -> AsyncRealtimeConnection: except ImportError as exc: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc + auth_headers = self.__client.auth_headers + extra_query = self.__extra_query + if self.__client.__class__.__name__ == "AsyncAzureOpenAI": + extra_query = { + **self.__extra_query, + "api-version": self.__client._api_version, + "deployment": self.__client._azure_deployment or self.__model + } + if self.__client.api_key != "": + auth_headers = {"api-key": self.__client.api_key} + else: + token = await self.__client._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + url = self._prepare_url().copy_with( params={ **self.__client.base_url.params, "model": self.__model, - **self.__extra_query, + **extra_query, }, ) log.debug("Connecting to %s", url) @@ -336,7 +351,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: user_agent_header=self.__client.user_agent, additional_headers=_merge_mappings( { - **self.__client.auth_headers, + **auth_headers, "OpenAI-Beta": "realtime=v1", }, self.__extra_headers, @@ -496,11 +511,26 @@ def __enter__(self) -> RealtimeConnection: except ImportError as exc: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc + auth_headers = self.__client.auth_headers + extra_query = self.__extra_query + if self.__client.__class__.__name__ == "AzureOpenAI": + extra_query = { + **self.__extra_query, + "api-version": self.__client._api_version, + "deployment": self.__client._azure_deployment or self.__model + } + if self.__client.api_key != "": + auth_headers = {"api-key": self.__client.api_key} + else: + token = self.__client._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + url = self._prepare_url().copy_with( params={ **self.__client.base_url.params, "model": self.__model, - **self.__extra_query, + **extra_query, }, ) log.debug("Connecting to %s", url) @@ -513,7 +543,7 @@ def __enter__(self) -> RealtimeConnection: user_agent_header=self.__client.user_agent, additional_headers=_merge_mappings( { - **self.__client.auth_headers, + **auth_headers, "OpenAI-Beta": "realtime=v1", }, self.__extra_headers, From 69d816dc053e95738a7a0783843f6c64ad07fd0b Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 18 Dec 2024 10:59:11 -0800 Subject: [PATCH 2/5] lint --- src/openai/_utils/__init__.py | 2 ++ src/openai/_utils/_utils.py | 12 ++++++++++++ src/openai/resources/beta/realtime/realtime.py | 10 ++++++---- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index af2c9bb77e..bd01c088dc 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -25,6 +25,7 @@ coerce_integer as coerce_integer, file_from_path as file_from_path, parse_datetime as parse_datetime, + is_azure_client as is_azure_client, strip_not_given as strip_not_given, deepcopy_minimal as deepcopy_minimal, get_async_library as get_async_library, @@ -32,6 +33,7 @@ get_required_header as get_required_header, maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, + is_async_azure_client as is_async_azure_client, ) from ._typing import ( is_list_type as is_list_type, diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index e5811bba42..c816fa6f91 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -5,6 +5,7 @@ import inspect import functools from typing import ( + TYPE_CHECKING, Any, Tuple, Mapping, @@ -30,6 +31,9 @@ _SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) CallableT = TypeVar("CallableT", bound=Callable[..., Any]) +if TYPE_CHECKING: + from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI + def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: return [item for sublist in t for item in sublist] @@ -412,3 +416,11 @@ def json_safe(data: object) -> object: return data.isoformat() return data + + +def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]: + return hasattr(client, "_azure_ad_token_provider") + + +def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]: + return hasattr(client, "_azure_ad_token_provider") diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index d940ed44ae..dd5c830778 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -21,9 +21,11 @@ ) from ...._types import NOT_GIVEN, Query, Headers, NotGiven from ...._utils import ( + is_azure_client, maybe_transform, strip_not_given, async_maybe_transform, + is_async_azure_client, ) from ...._compat import cached_property from ...._models import construct_type_unchecked @@ -321,11 +323,11 @@ async def __aenter__(self) -> AsyncRealtimeConnection: auth_headers = self.__client.auth_headers extra_query = self.__extra_query - if self.__client.__class__.__name__ == "AsyncAzureOpenAI": + if is_async_azure_client(self.__client): extra_query = { **self.__extra_query, "api-version": self.__client._api_version, - "deployment": self.__client._azure_deployment or self.__model + "deployment": self.__client._azure_deployment or self.__model, } if self.__client.api_key != "": auth_headers = {"api-key": self.__client.api_key} @@ -513,11 +515,11 @@ def __enter__(self) -> RealtimeConnection: auth_headers = self.__client.auth_headers extra_query = self.__extra_query - if self.__client.__class__.__name__ == "AzureOpenAI": + if is_azure_client(self.__client): extra_query = { **self.__extra_query, "api-version": self.__client._api_version, - "deployment": self.__client._azure_deployment or self.__model + "deployment": self.__client._azure_deployment or self.__model, } if self.__client.api_key != "": auth_headers = {"api-key": self.__client.api_key} From 80672db0e080caed0c1a3f0e13d7548711b44f38 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 18 Dec 2024 13:02:45 -0800 Subject: [PATCH 3/5] extract azure logic out of enter --- src/openai/_utils/__init__.py | 2 + src/openai/_utils/_utils.py | 44 +++++++++++++++++-- .../resources/beta/realtime/realtime.py | 30 +++---------- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index bd01c088dc..61b07d1f4b 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -34,6 +34,8 @@ maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, is_async_azure_client as is_async_azure_client, + configure_azure_realtime as configure_azure_realtime, + configure_azure_realtime_async as configure_azure_realtime_async, ) from ._typing import ( is_list_type as is_list_type, diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index c816fa6f91..1398734cf0 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -22,7 +22,7 @@ import sniffio -from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike +from .._types import Query, NotGiven, FileTypes, NotGivenOr, HeadersLike from .._compat import parse_date as parse_date, parse_datetime as parse_datetime _T = TypeVar("_T") @@ -419,8 +419,46 @@ def json_safe(data: object) -> object: def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]: - return hasattr(client, "_azure_ad_token_provider") + from ..lib.azure import AzureOpenAI + + return isinstance(client, AzureOpenAI) def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]: - return hasattr(client, "_azure_ad_token_provider") + from ..lib.azure import AsyncAzureOpenAI + + return isinstance(client, AsyncAzureOpenAI) + + +def configure_azure_realtime(client: AzureOpenAI, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]: + auth_headers = {} + query = { + **extra_query, + "api-version": client._api_version, + "deployment": client._azure_deployment or model, + } + if client.api_key != "": + auth_headers = {"api-key": client.api_key} + else: + token = client._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + return query, auth_headers + + +async def configure_azure_realtime_async( + client: AsyncAzureOpenAI, model: str, extra_query: Query +) -> tuple[Query, dict[str, str]]: + auth_headers = {} + query = { + **extra_query, + "api-version": client._api_version, + "deployment": client._azure_deployment or model, + } + if client.api_key != "": + auth_headers = {"api-key": client.api_key} + else: + token = await client._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + return query, auth_headers diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index dd5c830778..3c4f3e4f05 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -26,6 +26,8 @@ strip_not_given, async_maybe_transform, is_async_azure_client, + configure_azure_realtime, + configure_azure_realtime_async, ) from ...._compat import cached_property from ...._models import construct_type_unchecked @@ -321,20 +323,10 @@ async def __aenter__(self) -> AsyncRealtimeConnection: except ImportError as exc: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc - auth_headers = self.__client.auth_headers extra_query = self.__extra_query + auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): - extra_query = { - **self.__extra_query, - "api-version": self.__client._api_version, - "deployment": self.__client._azure_deployment or self.__model, - } - if self.__client.api_key != "": - auth_headers = {"api-key": self.__client.api_key} - else: - token = await self.__client._get_azure_ad_token() - if token: - auth_headers = {"Authorization": f"Bearer {token}"} + extra_query, auth_headers = await configure_azure_realtime_async(self.__client, self.__model, extra_query) url = self._prepare_url().copy_with( params={ @@ -513,20 +505,10 @@ def __enter__(self) -> RealtimeConnection: except ImportError as exc: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc - auth_headers = self.__client.auth_headers extra_query = self.__extra_query + auth_headers = self.__client.auth_headers if is_azure_client(self.__client): - extra_query = { - **self.__extra_query, - "api-version": self.__client._api_version, - "deployment": self.__client._azure_deployment or self.__model, - } - if self.__client.api_key != "": - auth_headers = {"api-key": self.__client.api_key} - else: - token = self.__client._get_azure_ad_token() - if token: - auth_headers = {"Authorization": f"Bearer {token}"} + extra_query, auth_headers = configure_azure_realtime(self.__client, self.__model, extra_query) url = self._prepare_url().copy_with( params={ From d34eb2c12c13a981ee94d417d7fa65d73b03bef2 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 18 Dec 2024 17:09:13 -0800 Subject: [PATCH 4/5] remove azure_deployment for now --- src/openai/_utils/_utils.py | 4 ++-- src/openai/lib/azure.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index 1398734cf0..6d88ad06aa 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -435,7 +435,7 @@ def configure_azure_realtime(client: AzureOpenAI, model: str, extra_query: Query query = { **extra_query, "api-version": client._api_version, - "deployment": client._azure_deployment or model, + "deployment": model, } if client.api_key != "": auth_headers = {"api-key": client.api_key} @@ -453,7 +453,7 @@ async def configure_azure_realtime_async( query = { **extra_query, "api-version": client._api_version, - "deployment": client._azure_deployment or model, + "deployment": model, } if client.api_key != "": auth_headers = {"api-key": client.api_key} diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index ee32f13b95..13d9f31838 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -224,7 +224,6 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider - self._azure_deployment = azure_deployment @override def copy( @@ -472,7 +471,6 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider - self._azure_deployment = azure_deployment @override def copy( From 297ab10ef1d5b3c83f3e0664d8433ebb5b428ff3 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 19 Dec 2024 09:13:01 -0800 Subject: [PATCH 5/5] move configure_azure to clients --- src/openai/_utils/__init__.py | 2 -- src/openai/_utils/_utils.py | 36 +------------------ src/openai/lib/azure.py | 32 ++++++++++++++++- .../resources/beta/realtime/realtime.py | 6 ++-- 4 files changed, 34 insertions(+), 42 deletions(-) diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index 61b07d1f4b..bd01c088dc 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -34,8 +34,6 @@ maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, is_async_azure_client as is_async_azure_client, - configure_azure_realtime as configure_azure_realtime, - configure_azure_realtime_async as configure_azure_realtime_async, ) from ._typing import ( is_list_type as is_list_type, diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index 6d88ad06aa..d6734e6b8f 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -22,7 +22,7 @@ import sniffio -from .._types import Query, NotGiven, FileTypes, NotGivenOr, HeadersLike +from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike from .._compat import parse_date as parse_date, parse_datetime as parse_datetime _T = TypeVar("_T") @@ -428,37 +428,3 @@ def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]: from ..lib.azure import AsyncAzureOpenAI return isinstance(client, AsyncAzureOpenAI) - - -def configure_azure_realtime(client: AzureOpenAI, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]: - auth_headers = {} - query = { - **extra_query, - "api-version": client._api_version, - "deployment": model, - } - if client.api_key != "": - auth_headers = {"api-key": client.api_key} - else: - token = client._get_azure_ad_token() - if token: - auth_headers = {"Authorization": f"Bearer {token}"} - return query, auth_headers - - -async def configure_azure_realtime_async( - client: AsyncAzureOpenAI, model: str, extra_query: Query -) -> tuple[Query, dict[str, str]]: - auth_headers = {} - query = { - **extra_query, - "api-version": client._api_version, - "deployment": model, - } - if client.api_key != "": - auth_headers = {"api-key": client.api_key} - else: - token = await client._get_azure_ad_token() - if token: - auth_headers = {"Authorization": f"Bearer {token}"} - return query, auth_headers diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 13d9f31838..f857d76e51 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -7,7 +7,7 @@ import httpx -from .._types import NOT_GIVEN, Omit, Timeout, NotGiven +from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven from .._utils import is_given, is_mapping from .._client import OpenAI, AsyncOpenAI from .._compat import model_copy @@ -307,6 +307,21 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: return options + def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]: + auth_headers = {} + query = { + **extra_query, + "api-version": self._api_version, + "deployment": model, + } + if self.api_key != "": + auth_headers = {"api-key": self.api_key} + else: + token = self._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + return query, auth_headers + class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI): @overload @@ -555,3 +570,18 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp raise ValueError("Unable to handle auth") return options + + async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]: + auth_headers = {} + query = { + **extra_query, + "api-version": self._api_version, + "deployment": model, + } + if self.api_key != "": + auth_headers = {"api-key": self.api_key} + else: + token = await self._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + return query, auth_headers diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 3c4f3e4f05..b39b410ecf 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -26,8 +26,6 @@ strip_not_given, async_maybe_transform, is_async_azure_client, - configure_azure_realtime, - configure_azure_realtime_async, ) from ...._compat import cached_property from ...._models import construct_type_unchecked @@ -326,7 +324,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: extra_query = self.__extra_query auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): - extra_query, auth_headers = await configure_azure_realtime_async(self.__client, self.__model, extra_query) + extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) url = self._prepare_url().copy_with( params={ @@ -508,7 +506,7 @@ def __enter__(self) -> RealtimeConnection: extra_query = self.__extra_query auth_headers = self.__client.auth_headers if is_azure_client(self.__client): - extra_query, auth_headers = configure_azure_realtime(self.__client, self.__model, extra_query) + extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query) url = self._prepare_url().copy_with( params={