Skip to content

Commit ee8d602

Browse files
Support OpenAI config values in vectorizers (#171)
Addresses part of #160
1 parent e332bea commit ee8d602

File tree

4 files changed

+39
-18
lines changed

4 files changed

+39
-18
lines changed

redisvl/utils/vectorize/text/azureopenai.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
3636
"api_key": "your_api_key", # OR set AZURE_OPENAI_API_KEY in your env
3737
"api_version": "your_api_version", # OR set OPENAI_API_VERSION in your env
3838
"azure_endpoint": "your_azure_endpoint", # OR set AZURE_OPENAI_ENDPOINT in your env
39-
}
39+
}
4040
)
4141
embedding = vectorizer.embed("Hello, world!")
4242
@@ -61,7 +61,8 @@ def __init__(
6161
'Deployment name' not the 'Model name'. Defaults to
6262
'text-embedding-ada-002'.
6363
api_config (Optional[Dict], optional): Dictionary containing the
64-
API key, API version and Azure endpoint. Defaults to None.
64+
API key, API version, Azure endpoint, and any other API options.
65+
Defaults to None.
6566
6667
Raises:
6768
ImportError: If the openai library is not installed.
@@ -75,6 +76,9 @@ def _initialize_clients(self, api_config: Optional[Dict]):
7576
Setup the OpenAI clients using the provided API key or an
7677
environment variable.
7778
"""
79+
if api_config is None:
80+
api_config = {}
81+
7882
# Dynamic import of the openai module
7983
try:
8084
from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -86,7 +90,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
8690

8791
# Fetch the API key, version and endpoint from api_config or environment variable
8892
azure_endpoint = (
89-
api_config.get("azure_endpoint")
93+
api_config.pop("azure_endpoint")
9094
if api_config
9195
else os.getenv("AZURE_OPENAI_ENDPOINT")
9296
)
@@ -99,7 +103,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
99103
)
100104

101105
api_version = (
102-
api_config.get("api_version")
106+
api_config.pop("api_version")
103107
if api_config
104108
else os.getenv("OPENAI_API_VERSION")
105109
)
@@ -112,7 +116,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
112116
)
113117

114118
api_key = (
115-
api_config.get("api_key")
119+
api_config.pop("api_key")
116120
if api_config
117121
else os.getenv("AZURE_OPENAI_API_KEY")
118122
)
@@ -125,10 +129,16 @@ def _initialize_clients(self, api_config: Optional[Dict]):
125129
)
126130

127131
self._client = AzureOpenAI(
128-
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
132+
api_key=api_key,
133+
api_version=api_version,
134+
azure_endpoint=azure_endpoint,
135+
**api_config,
129136
)
130137
self._aclient = AsyncAzureOpenAI(
131-
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
138+
api_key=api_key,
139+
api_version=api_version,
140+
azure_endpoint=azure_endpoint,
141+
**api_config,
132142
)
133143

134144
def _set_model_dims(self, model) -> int:

redisvl/utils/vectorize/text/openai.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
model (str): Model to use for embedding. Defaults to
5656
'text-embedding-ada-002'.
5757
api_config (Optional[Dict], optional): Dictionary containing the
58-
API key. Defaults to None.
58+
API key and any additional OpenAI API options. Defaults to None.
5959
6060
Raises:
6161
ImportError: If the openai library is not installed.
@@ -69,6 +69,9 @@ def _initialize_clients(self, api_config: Optional[Dict]):
6969
Setup the OpenAI clients using the provided API key or an
7070
environment variable.
7171
"""
72+
if api_config is None:
73+
api_config = {}
74+
7275
# Dynamic import of the openai module
7376
try:
7477
from openai import AsyncOpenAI, OpenAI
@@ -78,9 +81,9 @@ def _initialize_clients(self, api_config: Optional[Dict]):
7881
Please install with `pip install openai`"
7982
)
8083

81-
# Fetch the API key from api_config or environment variable
84+
# Pull the API key from api_config or environment variable
8285
api_key = (
83-
api_config.get("api_key") if api_config else os.getenv("OPENAI_API_KEY")
86+
api_config.pop("api_key") if api_config else os.getenv("OPENAI_API_KEY")
8487
)
8588
if not api_key:
8689
raise ValueError(
@@ -89,8 +92,8 @@ def _initialize_clients(self, api_config: Optional[Dict]):
8992
environment variable."
9093
)
9194

92-
self._client = OpenAI(api_key=api_key)
93-
self._aclient = AsyncOpenAI(api_key=api_key)
95+
self._client = OpenAI(api_key=api_key, **api_config)
96+
self._aclient = AsyncOpenAI(api_key=api_key, **api_config)
9497

9598
def _set_model_dims(self, model) -> int:
9699
try:

tests/integration/test_query.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -401,16 +401,21 @@ def test_sort_range_query(index, sorted_range_query):
401401
t = Text("job") % ""
402402
search(sorted_range_query, index, t, 7, sort=True)
403403

404+
404405
def test_query_with_chunk_number_zero():
405406
doc_base_id = "8675309"
406407
file_id = "e9ffbac9ff6f67cc"
407408
chunk_num = 0
408409

409410
filter_conditions = (
410-
(Tag("doc_base_id") == doc_base_id) &
411-
(Tag("file_id") == file_id) &
412-
(Num("chunk_number") == chunk_num)
411+
(Tag("doc_base_id") == doc_base_id)
412+
& (Tag("file_id") == file_id)
413+
& (Num("chunk_number") == chunk_num)
413414
)
414415

415-
expected_query_str = '((@doc_base_id:{8675309} @file_id:{e9ffbac9ff6f67cc}) @chunk_number:[0 0])'
416-
assert str(filter_conditions) == expected_query_str, "Query with chunk_number zero is incorrect"
416+
expected_query_str = (
417+
"((@doc_base_id:{8675309} @file_id:{e9ffbac9ff6f67cc}) @chunk_number:[0 0])"
418+
)
419+
assert (
420+
str(filter_conditions) == expected_query_str
421+
), "Query with chunk_number zero is incorrect"

tests/unit/test_filter.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ def test_filters_combination():
286286
tf4 = Geo("geo_field") == GeoRadius(1.0, 2.0, 3, "km")
287287
assert str(tf1 & tf2 & tf3 & tf4) == str(tf1 & tf4)
288288

289+
289290
def test_num_filter_zero():
290291
num_filter = Num("chunk_number") == 0
291-
assert str(num_filter) == "@chunk_number:[0 0]", "Num filter should handle zero correctly"
292+
assert (
293+
str(num_filter) == "@chunk_number:[0 0]"
294+
), "Num filter should handle zero correctly"

0 commit comments

Comments
 (0)