Skip to content

Commit 67eee3d

Browse files
Refactor search index to improve connection handling (#192)
We were leaning on a hack to run some async code in a sync setting. This was both dangerous and an anti-pattern. We also needed to refactor some of the shared and non-shared content between the BaseSearchIndex and derivatives.
1 parent cb61457 commit 67eee3d

File tree

8 files changed

+180
-184
lines changed

8 files changed

+180
-184
lines changed

docs/examples/openai_qna.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@
651651
"client = redis.Redis.from_url(\"redis://localhost:6379\")\n",
652652
"schema = IndexSchema.from_yaml(\"wiki_schema.yaml\")\n",
653653
"\n",
654-
"index = AsyncSearchIndex(schema, client)\n",
654+
"index = await AsyncSearchIndex(schema).set_client(client)\n",
655655
"\n",
656656
"await index.create()"
657657
]

docs/user_guide/getting_started_01.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@
486486
"client = Redis.from_url(\"redis://localhost:6379\")\n",
487487
"\n",
488488
"index = AsyncSearchIndex.from_dict(schema)\n",
489-
"index.set_client(client)"
489+
"await index.set_client(client)"
490490
]
491491
},
492492
{

redisvl/extensions/session_manager/standard_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self._client = RedisConnectionFactory.get_redis_connection(
5252
redis_url, **connection_kwargs
5353
)
54-
RedisConnectionFactory.validate_redis(self._client)
54+
RedisConnectionFactory.validate_sync_redis(self._client)
5555

5656
self.set_scope(session_tag, user_tag)
5757

redisvl/index/index.py

+105-63
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,24 @@ def decorator(func):
9494
@wraps(func)
9595
def wrapper(self, *args, **kwargs):
9696
result = func(self, *args, **kwargs)
97-
RedisConnectionFactory.validate_redis(self._redis_client, self._lib_name)
97+
RedisConnectionFactory.validate_sync_redis(
98+
self._redis_client, self._lib_name
99+
)
100+
return result
101+
102+
return wrapper
103+
104+
return decorator
105+
106+
107+
def setup_async_redis():
108+
def decorator(func):
109+
@wraps(func)
110+
async def wrapper(self, *args, **kwargs):
111+
result = await func(self, *args, **kwargs)
112+
await RedisConnectionFactory.validate_async_redis(
113+
self._redis_client, self._lib_name
114+
)
98115
return result
99116

100117
return wrapper
@@ -140,41 +157,10 @@ class BaseSearchIndex:
140157
StorageType.JSON: JsonStorage,
141158
}
142159

143-
def __init__(
144-
self,
145-
schema: IndexSchema,
146-
redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None,
147-
redis_url: Optional[str] = None,
148-
connection_args: Dict[str, Any] = {},
149-
**kwargs,
150-
):
151-
"""Initialize the RedisVL search index with a schema, Redis client
152-
(or URL string with other connection args), connection_args, and other
153-
kwargs.
154-
155-
Args:
156-
schema (IndexSchema): Index schema object.
157-
redis_client(Union[redis.Redis, aredis.Redis], optional): An
158-
instantiated redis client.
159-
redis_url (str, optional): The URL of the Redis server to
160-
connect to.
161-
connection_args (Dict[str, Any], optional): Redis client connection
162-
args.
163-
"""
164-
# final validation on schema object
165-
if not isinstance(schema, IndexSchema):
166-
raise ValueError("Must provide a valid IndexSchema object")
167-
168-
self.schema = schema
169-
170-
self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
160+
schema: IndexSchema
171161

172-
# set up redis connection
173-
self._redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None
174-
if redis_client is not None:
175-
self.set_client(redis_client)
176-
elif redis_url is not None:
177-
self.connect(redis_url, **connection_args)
162+
def __init__(*args, **kwargs):
163+
pass
178164

179165
@property
180166
def _storage(self) -> BaseStorage:
@@ -237,8 +223,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
237223
238224
Args:
239225
schema_dict (Dict[str, Any]): A dictionary containing the schema.
240-
connection_args (Dict[str, Any], optional): Redis client connection
241-
args.
242226
243227
Returns:
244228
SearchIndex: A RedisVL SearchIndex object.
@@ -262,14 +246,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
262246
schema = IndexSchema.from_dict(schema_dict)
263247
return cls(schema=schema, **kwargs)
264248

265-
def connect(self, redis_url: Optional[str] = None, **kwargs):
266-
"""Connect to Redis at a given URL."""
267-
raise NotImplementedError
268-
269-
def set_client(self, client: Union[redis.Redis, aredis.Redis]):
270-
"""Manually set the Redis client to use with the search index."""
271-
raise NotImplementedError
272-
273249
def disconnect(self):
274250
"""Disconnect from the Redis database."""
275251
self._redis_client = None
@@ -323,6 +299,43 @@ class SearchIndex(BaseSearchIndex):
323299
324300
"""
325301

302+
def __init__(
303+
self,
304+
schema: IndexSchema,
305+
redis_client: Optional[redis.Redis] = None,
306+
redis_url: Optional[str] = None,
307+
connection_args: Dict[str, Any] = {},
308+
**kwargs,
309+
):
310+
"""Initialize the RedisVL search index with a schema, Redis client
311+
(or URL string with other connection args), connection_args, and other
312+
kwargs.
313+
314+
Args:
315+
schema (IndexSchema): Index schema object.
316+
redis_client(Optional[redis.Redis]): An
317+
instantiated redis client.
318+
redis_url (Optional[str]): The URL of the Redis server to
319+
connect to.
320+
connection_args (Dict[str, Any], optional): Redis client connection
321+
args.
322+
"""
323+
# final validation on schema object
324+
if not isinstance(schema, IndexSchema):
325+
raise ValueError("Must provide a valid IndexSchema object")
326+
327+
self.schema = schema
328+
329+
self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
330+
331+
# set up redis connection
332+
self._redis_client: Optional[redis.Redis] = None
333+
334+
if redis_client is not None:
335+
self.set_client(redis_client)
336+
elif redis_url is not None:
337+
self.connect(redis_url, **connection_args)
338+
326339
@classmethod
327340
def from_existing(
328341
cls,
@@ -342,7 +355,7 @@ def from_existing(
342355
)
343356

344357
# Validate modules
345-
installed_modules = RedisConnectionFactory._get_modules(redis_client)
358+
installed_modules = RedisConnectionFactory.get_modules(redis_client)
346359
validate_modules(installed_modules, [{"name": "search", "ver": 20810}])
347360

348361
# Fetch index info and convert to schema
@@ -380,15 +393,15 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
380393
return self.set_client(client)
381394

382395
@setup_redis()
383-
def set_client(self, client: redis.Redis, **kwargs):
396+
def set_client(self, redis_client: redis.Redis, **kwargs):
384397
"""Manually set the Redis client to use with the search index.
385398
386399
This method configures the search index to use a specific synchronous
387400
Redis client. It is useful for cases where an external,
388401
custom-configured client is preferred instead of creating a new one.
389402
390403
Args:
391-
client (redis.Redis): A Redis or Async Redis
404+
redis_client (redis.Redis): A synchronous Redis
392405
client instance to be used for the connection.
393406
394407
Raises:
@@ -404,10 +417,10 @@ def set_client(self, client: redis.Redis, **kwargs):
404417
index.set_client(client)
405418
406419
"""
407-
if not isinstance(client, redis.Redis):
420+
if not isinstance(redis_client, redis.Redis):
408421
raise TypeError("Invalid Redis client instance")
409422

410-
self._redis_client = client
423+
self._redis_client = redis_client
411424

412425
return self
413426

@@ -759,7 +772,7 @@ class AsyncSearchIndex(BaseSearchIndex):
759772
760773
# initialize the index object with schema from file
761774
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
762-
index.connect(redis_url="redis://localhost:6379")
775+
await index.connect(redis_url="redis://localhost:6379")
763776
764777
# create the index
765778
await index.create(overwrite=True)
@@ -772,6 +785,34 @@ class AsyncSearchIndex(BaseSearchIndex):
772785
773786
"""
774787

788+
def __init__(
789+
self,
790+
schema: IndexSchema,
791+
**kwargs,
792+
):
793+
"""Initialize the RedisVL async search index with a schema.
794+
795+
Args:
796+
schema (IndexSchema): Index schema object.
797+
**kwargs: Optional settings such as ``lib_name``. A Redis connection
797+
is not accepted here; use ``set_client()`` or ``connect()`` instead.
799+
"""
800+
# final validation on schema object
801+
if not isinstance(schema, IndexSchema):
802+
raise ValueError("Must provide a valid IndexSchema object")
803+
804+
self.schema = schema
805+
806+
self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
807+
808+
# set up empty redis connection
809+
self._redis_client: Optional[aredis.Redis] = None
810+
811+
if "redis_client" in kwargs or "redis_url" in kwargs:
812+
logger.warning(
813+
"Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex"
814+
)
815+
775816
@classmethod
776817
async def from_existing(
777818
cls,
@@ -791,18 +832,18 @@ async def from_existing(
791832
)
792833

793834
# Validate modules
794-
installed_modules = await RedisConnectionFactory._get_modules_async(
795-
redis_client
796-
)
835+
installed_modules = await RedisConnectionFactory.get_modules_async(redis_client)
797836
validate_modules(installed_modules, [{"name": "search", "ver": 20810}])
798837

799838
# Fetch index info and convert to schema
800839
index_info = await cls._info(name, redis_client)
801840
schema_dict = convert_index_info_to_schema(index_info)
802841
schema = IndexSchema.from_dict(schema_dict)
803-
return cls(schema, redis_client, **kwargs)
842+
index = cls(schema, **kwargs)
843+
await index.set_client(redis_client)
844+
return index
804845

805-
def connect(self, redis_url: Optional[str] = None, **kwargs):
846+
async def connect(self, redis_url: Optional[str] = None, **kwargs):
806847
"""Connect to a Redis instance using the provided `redis_url`, falling
807848
back to the `REDIS_URL` environment variable (if available).
808849
@@ -828,18 +869,18 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
828869
client = RedisConnectionFactory.connect(
829870
redis_url=redis_url, use_async=True, **kwargs
830871
)
831-
return self.set_client(client)
872+
return await self.set_client(client)
832873

833-
@setup_redis()
834-
def set_client(self, client: aredis.Redis):
874+
@setup_async_redis()
875+
async def set_client(self, redis_client: aredis.Redis):
835876
"""Manually set the Redis client to use with the search index.
836877
837878
This method configures the search index to use a specific
838879
Async Redis client. It is useful for cases where an external,
839880
custom-configured client is preferred instead of creating a new one.
840881
841882
Args:
842-
client (aredis.Redis): An Async Redis
883+
redis_client (aredis.Redis): An Async Redis
843884
client instance to be used for the connection.
844885
845886
Raises:
@@ -853,13 +894,13 @@ def set_client(self, client: aredis.Redis):
853894
# async Redis client and index
854895
client = aredis.Redis.from_url("redis://localhost:6379")
855896
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
856-
index.set_client(client)
897+
await index.set_client(client)
857898
858899
"""
859-
if not isinstance(client, aredis.Redis):
900+
if not isinstance(redis_client, aredis.Redis):
860901
raise TypeError("Invalid Redis client instance")
861902

862-
self._redis_client = client
903+
self._redis_client = redis_client
863904

864905
return self
865906

@@ -889,6 +930,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
889930
await index.create(overwrite=True, drop=True)
890931
"""
891932
redis_fields = self.schema.redis_fields
933+
892934
if not redis_fields:
893935
raise ValueError("No fields defined for index")
894936
if not isinstance(overwrite, bool):

0 commit comments

Comments
 (0)