From 86435a400fa11448e04cf2d691403568ba1386e6 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 24 Jun 2024 15:50:41 -0400 Subject: [PATCH 1/2] support api_config overrides --- redisvl/utils/vectorize/text/azureopenai.py | 24 +++++++++++++++------ redisvl/utils/vectorize/text/openai.py | 13 ++++++----- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index fc13eb75..1a0ab15a 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -36,7 +36,7 @@ class AzureOpenAITextVectorizer(BaseVectorizer): "api_key": "your_api_key", # OR set AZURE_OPENAI_API_KEY in your env "api_version": "your_api_version", # OR set OPENAI_API_VERSION in your env "azure_endpoint": "your_azure_endpoint", # OR set AZURE_OPENAI_ENDPOINT in your env - } + } ) embedding = vectorizer.embed("Hello, world!") @@ -61,7 +61,8 @@ def __init__( 'Deployment name' not the 'Model name'. Defaults to 'text-embedding-ada-002'. api_config (Optional[Dict], optional): Dictionary containing the - API key, API version and Azure endpoint. Defaults to None. + API key, API version, Azure endpoint, and any other API options. + Defaults to None. Raises: ImportError: If the openai library is not installed. @@ -75,6 +76,9 @@ def _initialize_clients(self, api_config: Optional[Dict]): Setup the OpenAI clients using the provided API key or an environment variable. """ + if api_config is None: + api_config = {} + # Dynamic import of the openai module try: from openai import AsyncAzureOpenAI, AzureOpenAI @@ -86,7 +90,7 @@ def _initialize_clients(self, api_config: Optional[Dict]): # Fetch the API key, version and endpoint from api_config or environment variable azure_endpoint = ( - api_config.get("azure_endpoint") + api_config.pop("azure_endpoint") if api_config else os.getenv("AZURE_OPENAI_ENDPOINT") ) @@ -99,7 +103,7 @@ def _initialize_clients(self, api_config: Optional[Dict]): ) api_version = ( - api_config.get("api_version") + api_config.pop("api_version") if api_config else os.getenv("OPENAI_API_VERSION") ) @@ -112,7 +116,7 @@ def _initialize_clients(self, api_config: Optional[Dict]): ) api_key = ( - api_config.get("api_key") + api_config.pop("api_key") if api_config else os.getenv("AZURE_OPENAI_API_KEY") ) @@ -125,10 +129,16 @@ def _initialize_clients(self, api_config: Optional[Dict]): ) self._client = AzureOpenAI( - api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint + api_key=api_key, + api_version=api_version, + azure_endpoint=azure_endpoint, + **api_config, ) self._aclient = AsyncAzureOpenAI( - api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint + api_key=api_key, + api_version=api_version, + azure_endpoint=azure_endpoint, + **api_config, ) def _set_model_dims(self, model) -> int: diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index b5d2070c..421ea4ac 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -55,7 +55,7 @@ def __init__( model (str): Model to use for embedding. Defaults to 'text-embedding-ada-002'. api_config (Optional[Dict], optional): Dictionary containing the - API key. Defaults to None. + API key and any additional OpenAI API options. Defaults to None. Raises: ImportError: If the openai library is not installed. @@ -69,6 +69,9 @@ def _initialize_clients(self, api_config: Optional[Dict]): Setup the OpenAI clients using the provided API key or an environment variable. """ + if api_config is None: + api_config = {} + # Dynamic import of the openai module try: from openai import AsyncOpenAI, OpenAI @@ -78,9 +81,9 @@ def _initialize_clients(self, api_config: Optional[Dict]): Please install with `pip install openai`" ) - # Fetch the API key from api_config or environment variable + # Pull the API key from api_config or environment variable api_key = ( - api_config.get("api_key") if api_config else os.getenv("OPENAI_API_KEY") + api_config.pop("api_key") if api_config else os.getenv("OPENAI_API_KEY") ) if not api_key: raise ValueError( @@ -89,8 +92,8 @@ def _initialize_clients(self, api_config: Optional[Dict]): environment variable." ) - self._client = OpenAI(api_key=api_key) - self._aclient = AsyncOpenAI(api_key=api_key) + self._client = OpenAI(api_key=api_key, **api_config) + self._aclient = AsyncOpenAI(api_key=api_key, **api_config) def _set_model_dims(self, model) -> int: try: From 50e10626dea69a7a94e2c4ba1bb2b57349d85a46 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 24 Jun 2024 15:50:46 -0400 Subject: [PATCH 2/2] formatting --- tests/integration/test_query.py | 15 ++++++++++----- tests/unit/test_filter.py | 5 ++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 9acc98f5..df348f83 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -401,16 +401,21 @@ def test_sort_range_query(index, sorted_range_query): t = Text("job") % "" search(sorted_range_query, index, t, 7, sort=True) + def test_query_with_chunk_number_zero(): doc_base_id = "8675309" file_id = "e9ffbac9ff6f67cc" chunk_num = 0 filter_conditions = ( - (Tag("doc_base_id") == doc_base_id) & - (Tag("file_id") == file_id) & - (Num("chunk_number") == chunk_num) + (Tag("doc_base_id") == doc_base_id) + & (Tag("file_id") == file_id) + & (Num("chunk_number") == chunk_num) ) - expected_query_str = '((@doc_base_id:{8675309} @file_id:{e9ffbac9ff6f67cc}) @chunk_number:[0 0])' - assert str(filter_conditions) == expected_query_str, "Query with chunk_number zero is incorrect" + expected_query_str = ( + "((@doc_base_id:{8675309} @file_id:{e9ffbac9ff6f67cc}) @chunk_number:[0 0])" + ) + assert ( + str(filter_conditions) == expected_query_str + ), "Query with chunk_number zero is incorrect" diff --git a/tests/unit/test_filter.py b/tests/unit/test_filter.py index fce2783f..067402ea 100644 --- a/tests/unit/test_filter.py +++ b/tests/unit/test_filter.py @@ -286,6 +286,9 @@ def test_filters_combination(): tf4 = Geo("geo_field") == GeoRadius(1.0, 2.0, 3, "km") assert str(tf1 & tf2 & tf3 & tf4) == str(tf1 & tf4) + def test_num_filter_zero(): num_filter = Num("chunk_number") == 0 - assert str(num_filter) == "@chunk_number:[0 0]", "Num filter should handle zero correctly" \ No newline at end of file + assert ( + str(num_filter) == "@chunk_number:[0 0]" + ), "Num filter should handle zero correctly"