From efb3de0695d0d9dfc3b2ef35b0bfef7fe7babdaa Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 23 Jul 2024 10:04:53 -0400 Subject: [PATCH 1/3] vectorize route references together --- redisvl/extensions/router/semantic.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 93621036..15d9b6b5 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,4 +1,3 @@ -import hashlib from typing import Any, Dict, List, Optional, Type import redis.commands.search.reducers as reducers @@ -175,13 +174,17 @@ def _add_routes(self, routes: List[Route]): keys: List[str] = [] for route in routes: + # embed route references as a single batch + reference_vectors = self.vectorizer.embed_many( + [reference for reference in route.references], as_buffer=True + ) # set route references - for reference in route.references: + for i, reference in enumerate(route.references): route_references.append( { "route_name": route.name, "reference": reference, - "vector": self.vectorizer.embed(reference, as_buffer=True), + "vector": reference_vectors[i] } ) keys.append(self._route_ref_key(route.name, reference)) @@ -462,3 +465,17 @@ def clear(self) -> None: """Flush all routes from the semantic router index.""" self._index.clear() self.routes = [] + + @classmethod + def from_dict(cls): + pass + + def to_dict(self): + pass + + @classmethod + def from_yaml(cls): + pass + + def to_yaml(self): + pass From bbfd33c93e5b87a1e9f80e469e95d1f21ee14b02 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 24 Jul 2024 15:25:14 -0400 Subject: [PATCH 2/3] add vectorizer from dict factory method --- redisvl/utils/utils.py | 26 +++++++++++++++++++++ redisvl/utils/vectorize/__init__.py | 19 ++++++++++++++- redisvl/utils/vectorize/base.py | 14 +++++++++++ redisvl/utils/vectorize/text/azureopenai.py | 4 ++++ redisvl/utils/vectorize/text/cohere.py | 4 ++++ redisvl/utils/vectorize/text/custom.py | 4 ++++ redisvl/utils/vectorize/text/huggingface.py | 4 ++++ redisvl/utils/vectorize/text/mistral.py | 4 ++++ redisvl/utils/vectorize/text/openai.py | 4 ++++ redisvl/utils/vectorize/text/vertexai.py | 4 ++++ 10 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 redisvl/utils/utils.py diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py new file mode 100644 index 00000000..96c5250e --- /dev/null +++ b/redisvl/utils/utils.py @@ -0,0 +1,26 @@ +from enum import Enum +from typing import Any, Dict + +from pydantic.v1 import BaseModel + + +def model_to_dict(model: BaseModel) -> Dict[str, Any]: + """ + Custom serialization function that converts a Pydantic model to a dict, + serializing Enum fields to their values, and handling nested models and lists. + """ + + def serialize_item(item): + if isinstance(item, Enum): + return item.value.lower() + elif isinstance(item, dict): + return {key: serialize_item(value) for key, value in item.items()} + elif isinstance(item, list): + return [serialize_item(element) for element in item] + else: + return item + + serialized_data = model.dict(exclude_none=True) + for key, value in serialized_data.items(): + serialized_data[key] = serialize_item(value) + return serialized_data diff --git a/redisvl/utils/vectorize/__init__.py b/redisvl/utils/vectorize/__init__.py index ecac2768..52c8363e 100644 --- a/redisvl/utils/vectorize/__init__.py +++ b/redisvl/utils/vectorize/__init__.py @@ -1,4 +1,4 @@ -from redisvl.utils.vectorize.base import BaseVectorizer +from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer from redisvl.utils.vectorize.text.cohere import CohereTextVectorizer from redisvl.utils.vectorize.text.custom import CustomTextVectorizer @@ -17,3 +17,20 @@ "MistralAITextVectorizer", "CustomTextVectorizer", ] + + +def vectorizer_from_dict(vectorizer: dict) -> BaseVectorizer: + vectorizer_type = Vectorizers(vectorizer["type"]) + model = vectorizer["model"] + if vectorizer_type == Vectorizers.cohere: + return CohereTextVectorizer(model) + elif vectorizer_type == Vectorizers.openai: + return OpenAITextVectorizer(model) + elif vectorizer_type == Vectorizers.azure_openai: + return AzureOpenAITextVectorizer(model) + elif vectorizer_type == Vectorizers.hf: + return HFTextVectorizer(model) + elif vectorizer_type == Vectorizers.mistral: + return MistralAITextVectorizer(model) + elif vectorizer_type == Vectorizers.vertexai: + return VertexAITextVectorizer(model) diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index f5ef8198..3ea2dccd 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Callable, List, Optional from pydantic.v1 import BaseModel, validator @@ -6,10 +7,23 @@ from redisvl.redis.utils import array_to_buffer +class Vectorizers(Enum): + azure_openai = "azure_openai" + openai = "openai" + cohere = "cohere" + mistral = "mistral" + vertexai = "vertexai" + hf = "hf" + + class BaseVectorizer(BaseModel, ABC): model: str dims: int + @property + def type(self) -> str: + return "base" + @validator("dims") @classmethod def check_dims(cls, value): diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 1a0ab15a..734fef5b 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -312,3 +312,7 @@ async def aembed( text = preprocess(text) result = await self._aclient.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer) + + @property + def type(self) -> str: + return "azure_openai" diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index cec856dd..47275d40 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -253,3 +253,7 @@ async def aembed( **kwargs, ) -> List[float]: raise NotImplementedError + + @property + def type(self) -> str: + return "cohere" diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index 7ccca839..56155f5b 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -291,3 +291,7 @@ async def aembed_many( results = await self._aembed_many_func(batch, **kwargs) embeddings += [self._process_embedding(r, as_buffer) for r in results] return embeddings + + @property + def type(self) -> str: + return "custom" diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index cb72652e..d5e255c9 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -162,3 +162,7 @@ async def aembed( **kwargs, ) -> List[float]: raise NotImplementedError + + @property + def type(self) -> str: + return "hf" diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index 4bb9c4fd..8776ef3d 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -260,3 +260,7 @@ async def aembed( text = preprocess(text) result = await self._aclient.embeddings(model=self.model, input=[text]) return self._process_embedding(result.data[0].embedding, as_buffer) + + @property + def type(self) -> str: + return "mistral" diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 9d6904eb..5921bda8 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -266,3 +266,7 @@ async def aembed( text = preprocess(text) result = await self._aclient.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer) + + @property + def type(self) -> str: + return "openai" diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 1d67c672..b7248003 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -212,3 +212,7 @@ async def aembed( **kwargs, ) -> List[float]: raise NotImplementedError + + @property + def type(self) -> str: + return "vertexai" From 41e3b3b305c80e7d98b22d40ad5a3305d6cbdb9f Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 24 Jul 2024 15:25:26 -0400 Subject: [PATCH 3/3] finish up semantic router loading --- docs/user_guide/router.yaml | 35 +++++ docs/user_guide/semantic_router_08.ipynb | 120 +++++++++++++-- redisvl/extensions/router/semantic.py | 179 ++++++++++++++++++++-- redisvl/schema/schema.py | 25 +-- schemas/semantic_router.yaml | 23 +++ tests/integration/test_semantic_router.py | 69 +++++++++ tests/unit/test_schema.py | 2 +- 7 files changed, 406 insertions(+), 47 deletions(-) create mode 100644 docs/user_guide/router.yaml create mode 100644 schemas/semantic_router.yaml diff --git a/docs/user_guide/router.yaml b/docs/user_guide/router.yaml new file mode 100644 index 00000000..ec0453c4 --- /dev/null +++ b/docs/user_guide/router.yaml @@ -0,0 +1,35 @@ +name: topic-router +routes: +- name: technology + references: + - what are the latest advancements in AI? + - tell me about the newest gadgets + - what's trending in tech? + metadata: + category: tech + priority: '1' +- name: sports + references: + - who won the game last night? + - tell me about the upcoming sports events + - what's the latest in the world of sports? + - sports + - basketball and football + metadata: + category: sports + priority: '2' +- name: entertainment + references: + - what are the top movies right now? + - who won the best actor award? + - what's new in the entertainment industry? + metadata: + category: entertainment + priority: '3' +vectorizer: + type: hf + model: sentence-transformers/all-mpnet-base-v2 +routing_config: + distance_threshold: 1.0 + max_k: 3 + aggregation_method: min diff --git a/docs/user_guide/semantic_router_08.ipynb b/docs/user_guide/semantic_router_08.ipynb index 0fc01beb..bfe8b193 100644 --- a/docs/user_guide/semantic_router_08.ipynb +++ b/docs/user_guide/semantic_router_08.ipynb @@ -171,7 +171,7 @@ { "data": { "text/plain": [ - "RouteMatch(name='technology', distance=0.119614183903)" + "RouteMatch(name='technology', distance=0.119614243507)" ] }, "execution_count": 5, @@ -244,9 +244,9 @@ { "data": { "text/plain": [ - "[RouteMatch(name='sports', distance=0.758580672741),\n", - " RouteMatch(name='entertainment', distance=0.812423805396),\n", - " RouteMatch(name='technology', distance=0.884235262871)]" + "[RouteMatch(name='sports', distance=0.758580708504),\n", + " RouteMatch(name='entertainment', distance=0.812423825264),\n", + " RouteMatch(name='technology', distance=0.88423516353)]" ] }, "execution_count": 8, @@ -268,9 +268,9 @@ { "data": { "text/plain": [ - "[RouteMatch(name='sports', distance=0.663254022598),\n", - " RouteMatch(name='entertainment', distance=0.712985336781),\n", - " RouteMatch(name='technology', distance=0.832674443722)]" + "[RouteMatch(name='sports', distance=0.663253903389),\n", + " RouteMatch(name='entertainment', distance=0.712985396385),\n", + " RouteMatch(name='technology', distance=0.832674384117)]" ] }, "execution_count": 9, @@ -321,9 +321,9 @@ { "data": { "text/plain": [ - "[RouteMatch(name='sports', distance=0.663254022598),\n", - " RouteMatch(name='entertainment', distance=0.712985336781),\n", - " RouteMatch(name='technology', distance=0.832674443722)]" + "[RouteMatch(name='sports', distance=0.663253903389),\n", + " RouteMatch(name='entertainment', distance=0.712985396385),\n", + " RouteMatch(name='technology', distance=0.832674384117)]" ] }, "execution_count": 11, @@ -340,13 +340,109 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Clean up the router" + "## Router serialization" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'name': 'topic-router',\n", + " 'routes': [{'name': 'technology',\n", + " 'references': ['what are the latest advancements in AI?',\n", + " 'tell me about the newest gadgets',\n", + " \"what's trending in tech?\"],\n", + " 'metadata': {'category': 'tech', 'priority': '1'}},\n", + " {'name': 'sports',\n", + " 'references': ['who won the game last night?',\n", + " 'tell me about the upcoming sports events',\n", + " \"what's the latest in the world of sports?\",\n", + " 'sports',\n", + " 'basketball and football'],\n", + " 'metadata': {'category': 'sports', 'priority': '2'}},\n", + " {'name': 'entertainment',\n", + " 'references': ['what are the top movies right now?',\n", + " 'who won the best actor award?',\n", + " \"what's new in the entertainment industry?\"],\n", + " 'metadata': {'category': 'entertainment', 'priority': '3'}}],\n", + " 'vectorizer': {'type': 'hf',\n", + " 'model': 'sentence-transformers/all-mpnet-base-v2'},\n", + " 'routing_config': {'distance_threshold': 1.0,\n", + " 'max_k': 3,\n", + " 'aggregation_method': 'min'}}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "router.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15:16:28 redisvl.index.index INFO Index already exists, not overwriting.\n" + ] + } + ], + "source": [ + "router2 = SemanticRouter.from_dict(router.to_dict(), redis_url=\"redis://localhost:6379\")\n", + "\n", + "assert router2 == router" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "router.to_yaml(\"router.yaml\", overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "15:17:42 redisvl.index.index INFO Index already exists, not overwriting.\n" + ] + } + ], + "source": [ + "router3 = SemanticRouter.from_yaml(\"router.yaml\", redis_url=\"redis://localhost:6379\")\n", + "\n", + "assert router3 == router2 == router" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up the router" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, "outputs": [], "source": [ "# Use clear to flush all routes from the index\n", @@ -355,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 15d9b6b5..ee78b9d3 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,6 +1,8 @@ +from pathlib import Path from typing import Any, Dict, List, Optional, Type import redis.commands.search.reducers as reducers +import yaml from pydantic.v1 import BaseModel, Field, PrivateAttr from redis import Redis from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer @@ -17,7 +19,12 @@ from redisvl.redis.utils import convert_bytes, hashify, make_dict from redisvl.schema import IndexInfo, IndexSchema from redisvl.utils.log import get_logger -from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +from redisvl.utils.utils import model_to_dict +from redisvl.utils.vectorize import ( + BaseVectorizer, + HFTextVectorizer, + vectorizer_from_dict, +) logger = get_logger(__name__) @@ -123,8 +130,10 @@ def _initialize_index( if redis_client: self._index.set_client(redis_client) - else: + elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) + else: + raise ValueError("Must provide either a redis client or redis url string.") existed = self._index.exists() self._index.create(overwrite=overwrite) @@ -184,7 +193,7 @@ def _add_routes(self, routes: List[Route]): { "route_name": route.name, "reference": reference, - "vector": reference_vectors[i] + "vector": reference_vectors[i], } ) keys.append(self._route_ref_key(route.name, reference)) @@ -467,15 +476,163 @@ def clear(self) -> None: self.routes = [] @classmethod - def from_dict(cls): - pass + def from_dict( + cls, + data: Dict[str, Any], + redis_client: Optional[Redis] = None, + redis_url: Optional[str] = None, + overwrite: bool = False, + **kwargs, + ) -> "SemanticRouter": + """Create a SemanticRouter from a dictionary. + + Args: + data (Dict[str, Any]): The dictionary containing the semantic router data. + redis_client (Optional[Redis]): Redis client for connection. + redis_url (Optional[str]): Redis URL for connection. + overwrite (bool): Whether to overwrite existing index. + **kwargs: Additional arguments. + + Returns: + SemanticRouter: The semantic router instance. + + Raises: + ValueError: If required data is missing or invalid. + + .. code-block:: python + + from redisvl.extensions.router import SemanticRouter + router_data = { + "name": "example_router", + "routes": [{"name": "route1", "references": ["ref1"], "distance_threshold": 0.5}], + "vectorizer": {"type": "openai", "model": "text-embedding-ada-002"}, + } + router = SemanticRouter.from_dict(router_data) + """ + try: + name = data["name"] + routes_data = data["routes"] + vectorizer_data = data["vectorizer"] + routing_config_data = data["routing_config"] + except KeyError as e: + raise ValueError(f"Unable to load semantic router from dict: {str(e)}") + + try: + vectorizer = vectorizer_from_dict(vectorizer_data) + except Exception as e: + raise ValueError(f"Unable to load vectorizer: {str(e)}") + + if not vectorizer: + raise ValueError(f"Unable to load vectorizer: {vectorizer_data}") + + routes = [Route(**route) for route in routes_data] + routing_config = RoutingConfig(**routing_config_data) + + return cls( + name=name, + routes=routes, + vectorizer=vectorizer, + routing_config=routing_config, + redis_client=redis_client, + redis_url=redis_url, + overwrite=overwrite, + **kwargs, + ) - def to_dict(self): - pass + def to_dict(self) -> Dict[str, Any]: + """Convert the SemanticRouter instance to a dictionary. + + Returns: + Dict[str, Any]: The dictionary representation of the SemanticRouter. + + .. code-block:: python + + from redisvl.extensions.router import SemanticRouter + router = SemanticRouter(name="example_router", routes=[], redis_url="redis://localhost:6379") + router_dict = router.to_dict() + """ + return { + "name": self.name, + "routes": [model_to_dict(route) for route in self.routes], + "vectorizer": { + "type": self.vectorizer.type, + "model": self.vectorizer.model, + }, + "routing_config": model_to_dict(self.routing_config), + } @classmethod - def from_yaml(cls): - pass + def from_yaml( + cls, + file_path: str, + redis_client: Optional[Redis] = None, + redis_url: Optional[str] = None, + overwrite: bool = False, + **kwargs, + ) -> "SemanticRouter": + """Create a SemanticRouter from a YAML file. + + Args: + file_path (str): The path to the YAML file. + redis_client (Optional[Redis]): Redis client for connection. + redis_url (Optional[str]): Redis URL for connection. + overwrite (bool): Whether to overwrite existing index. + **kwargs: Additional arguments. + + Returns: + SemanticRouter: The semantic router instance. + + Raises: + ValueError: If the file path is invalid. + FileNotFoundError: If the file does not exist. + + .. code-block:: python + + from redisvl.extensions.router import SemanticRouter + router = SemanticRouter.from_yaml("router.yaml", redis_url="redis://localhost:6379") + """ + try: + fp = Path(file_path).resolve() + except OSError as e: + raise ValueError(f"Invalid file path: {file_path}") from e + + if not fp.exists(): + raise FileNotFoundError(f"File {file_path} does not exist") + + with open(fp, "r") as f: + yaml_data = yaml.safe_load(f) + return cls.from_dict( + yaml_data, + redis_client=redis_client, + redis_url=redis_url, + overwrite=overwrite, + **kwargs, + ) + + def to_yaml(self, file_path: str, overwrite: bool = True) -> None: + """Write the semantic router to a YAML file. + + Args: + file_path (str): The path to the YAML file. + overwrite (bool): Whether to overwrite the file if it already exists. + + Raises: + FileExistsError: If the file already exists and overwrite is False. + + .. code-block:: python + + from redisvl.extensions.router import SemanticRouter + router = SemanticRouter( + name="example_router", + routes=[], + redis_url="redis://localhost:6379" + ) + router.to_yaml("router.yaml") + """ + fp = Path(file_path).resolve() + if fp.exists() and not overwrite: + raise FileExistsError(f"Schema file {file_path} already exists.") - def to_yaml(self): - pass + with open(fp, "w") as f: + yaml_data = self.to_dict() + yaml.dump(yaml_data, f, sort_keys=False) diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index ed9cffd4..7f3db845 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -9,33 +9,12 @@ from redisvl.schema.fields import BaseField, FieldFactory from redisvl.utils.log import get_logger +from redisvl.utils.utils import model_to_dict logger = get_logger(__name__) SCHEMA_VERSION = "0.1.0" -def custom_dict(model: BaseModel) -> Dict[str, Any]: - """ - Custom serialization function that converts a Pydantic model to a dict, - serializing Enum fields to their values, and handling nested models and lists. - """ - - def serialize_item(item): - if isinstance(item, Enum): - return item.value.lower() - elif isinstance(item, dict): - return {key: serialize_item(value) for key, value in item.items()} - elif isinstance(item, list): - return [serialize_item(element) for element in item] - else: - return item - - serialized_data = model.dict(exclude_none=True) - for key, value in serialized_data.items(): - serialized_data[key] = serialize_item(value) - return serialized_data - - class StorageType(Enum): """ Enumeration for the storage types supported in Redis. @@ -452,7 +431,7 @@ def to_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The index schema as a dictionary. """ - dict_schema = custom_dict(self) + dict_schema = model_to_dict(self) # cast fields back to a pure list dict_schema["fields"] = [ field for field_name, field in dict_schema["fields"].items() diff --git a/schemas/semantic_router.yaml b/schemas/semantic_router.yaml new file mode 100644 index 00000000..7b504154 --- /dev/null +++ b/schemas/semantic_router.yaml @@ -0,0 +1,23 @@ +name: test-router +routes: +- name: greeting + references: + - hello + - hi + metadata: + type: greeting + distance_threshold: 0.3 +- name: farewell + references: + - bye + - goodbye + metadata: + type: farewell + distance_threshold: 0.3 +vectorizer: + type: hf + model: sentence-transformers/all-mpnet-base-v2 +routing_config: + distance_threshold: 0.3 + max_k: 2 + aggregation_method: avg diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index e8bb2b38..6c54a2b9 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,3 +1,5 @@ +import pathlib + import pytest from redisvl.extensions.router import SemanticRouter @@ -5,6 +7,10 @@ from redisvl.redis.connection import compare_versions +def get_base_path(): + return pathlib.Path(__file__).parent.resolve() + + @pytest.fixture def routes(): return [ @@ -156,3 +162,66 @@ def test_remove_routes(semantic_router): semantic_router.remove_route("unknown_route") assert semantic_router.get("unknown_route") is None + + +def test_to_dict(semantic_router): + router_dict = semantic_router.to_dict() + assert router_dict["name"] == semantic_router.name + assert len(router_dict["routes"]) == len(semantic_router.routes) + assert router_dict["vectorizer"]["type"] == semantic_router.vectorizer.type + + +def test_from_dict(semantic_router): + router_dict = semantic_router.to_dict() + new_router = SemanticRouter.from_dict( + router_dict, redis_client=semantic_router._index.client + ) + assert new_router == semantic_router + + +def test_to_yaml(semantic_router): + yaml_file = str(get_base_path().joinpath("../../schemas/semantic_router.yaml")) + semantic_router.to_yaml(yaml_file, overwrite=True) + assert pathlib.Path(yaml_file).exists() + + +def test_from_yaml(semantic_router): + yaml_file = str(get_base_path().joinpath("../../schemas/semantic_router.yaml")) + new_router = SemanticRouter.from_yaml( + yaml_file, redis_client=semantic_router._index.client, overwrite=True + ) + assert new_router == semantic_router + + +def test_to_dict_missing_fields(): + data = { + "name": "incomplete-router", + "routes": [], + "vectorizer": {"type": "HFTextVectorizer", "model": "bert-base-uncased"}, + } + with pytest.raises(ValueError): + SemanticRouter.from_dict(data) + + +def test_invalid_vectorizer(): + data = { + "name": "invalid-router", + "routes": [], + "vectorizer": {"type": "InvalidVectorizer", "model": "invalid-model"}, + "routing_config": {}, + } + with pytest.raises(ValueError): + SemanticRouter.from_dict(data) + + +def test_yaml_invalid_file_path(): + with pytest.raises(FileNotFoundError): + SemanticRouter.from_yaml("invalid_path.yaml", redis_client=None) + + +def test_idempotent_to_dict(semantic_router): + router_dict = semantic_router.to_dict() + new_router = SemanticRouter.from_dict( + router_dict, redis_client=semantic_router._index.client + ) + assert new_router.to_dict() == router_dict diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f6e84702..26878cb5 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -4,7 +4,7 @@ import pytest from redisvl.schema.fields import TagField, TextField -from redisvl.schema.schema import IndexSchema, StorageType, custom_dict +from redisvl.schema.schema import IndexSchema, StorageType def get_base_path():