From af1e141b7001fc0f3f65f226f8ffc4e39a129cd2 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 8 Jul 2024 18:28:19 -0400 Subject: [PATCH 01/16] way too early WIP --- redisvl/extensions/router/__init__.py | 0 redisvl/extensions/router/routes.py | 48 +++++++++ redisvl/extensions/router/semantic.py | 135 ++++++++++++++++++++++++++ 3 files changed, 183 insertions(+) create mode 100644 redisvl/extensions/router/__init__.py create mode 100644 redisvl/extensions/router/routes.py create mode 100644 redisvl/extensions/router/semantic.py diff --git a/redisvl/extensions/router/__init__.py b/redisvl/extensions/router/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/redisvl/extensions/router/routes.py b/redisvl/extensions/router/routes.py new file mode 100644 index 00000000..0d2ec0c1 --- /dev/null +++ b/redisvl/extensions/router/routes.py @@ -0,0 +1,48 @@ + + +from pydantic.v1 import BaseModel, Field, validator +from typing import List, Dict, Optional + + +class Route(BaseModel): + name: str + """The name of the route""" + references: List[str] + """List of reference phrases for the route""" + metadata: Dict[str, str] = Field(default={}) + """Metadata associated with the route""" + + @validator('name') + def name_must_not_be_empty(cls, v): + if not v or not v.strip(): + raise ValueError('Route name must not be empty') + return v + + @validator('references') + def references_must_not_be_empty(cls, v): + if not v: + raise ValueError('References must not be empty') + if any(not ref.strip() for ref in v): + raise ValueError('All references must be non-empty strings') + return v + + +class RoutingConfig(BaseModel): + top_k: int = Field(default=1) + """The maximum number of top matches to return""" + distance_threshold: Optional[float] = None + """The threshold for semantic distance""" + # TODO: need more here + + @validator('top_k') + def top_k_must_be_positive(cls, v): + if v <= 0: + raise ValueError('top_k must be a positive integer') + return v + + @validator('distance_threshold') + def distance_threshold_must_be_valid(cls, v): + if v is not None and (v <= 0 or v > 1): + raise ValueError('distance_threshold must be between 0 and 1') + return v + diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py new file mode 100644 index 00000000..c219bea5 --- /dev/null +++ b/redisvl/extensions/router/semantic.py @@ -0,0 +1,135 @@ +from pydantic.v1 import BaseModel, root_validator, Field +from typing import Any, List, Dict, Optional, Union +from redis import Redis +from redisvl.index import SearchIndex +from redisvl.schema import IndexSchema, IndexInfo +from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +from redisvl.extensions.router.routes import Route, RoutingConfig + +import hashlib + + +class SemanticRouterIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, vector_dims: int): + return cls( + index=IndexInfo(name=name, prefix=name), + fields={ + "route_name": {"name": "route_name", "type": "tag"}, + "reference": {"name": "reference", "type": "text"}, + "vector": { + "name": "vector", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": vector_dims, + "distance_metric": "cosine", + "datatype": "float32" + } + } + } + ) + + +class SemanticRouter(BaseModel): + name: str + """The name of the semantic router""" + vectorizer: BaseVectorizer = Field(default_factory=HFTextVectorizer) + """The vectorizer used to embed route references""" + routes: List[Route] + """List of Route objects""" + routing_config: RoutingConfig = Field(default_factory=RoutingConfig) + """Configuration for routing behavior""" + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data): + super().__init__(**data) + self._initialize_index(**data) + + def _initialize_index(self, **data): + """Initialize the search index and handle Redis connection. + + Args: + data (dict): Initialization data containing Redis connection details. + """ + # Extract connection parameters + redis_url = data.pop("redis_url", "redis://localhost:6379") + redis_client = data.pop("redis_client", None) + connection_args = data.pop("connection_args", {}) + + # Create search index schema + schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) + + # Build search index + self._index = SearchIndex(schema=schema) + + # Handle Redis connection + if redis_client: + self._index.set_client(redis_client) + else: + self._index.connect(redis_url=redis_url, **connection_args) + + if not self._index.exists(): + self._add_routes(self.routes) + + self._index.create(overwrite=False) + + def update_routing_config(self, routing_config: RoutingConfig): + """Update the routing configuration. + + Args: + routing_config (RoutingConfig): The new routing configuration. + """ + self.routing_config = routing_config + # TODO: Ensure Pydantic handles the validation here + # TODO: Determine if we need to persist this to Redis + + def _add_routes(self, routes: List[Route]): + """Add routes to the index. + + Args: + routes (List[Route]): List of routes to be added. + """ + route_references: List[Dict[str, Any]] = [] + keys: List[str] = [] + + for route in routes: + for reference in route.references: + route_references.append({ + "route_name": route.name, + "reference": reference, + "vector": self.vectorizer.embed(reference) + }) + reference_hash = hashlib.sha256(reference.encode("utf-8")).hexdigest() + keys.append(f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}") + + self._index.load(route_references, keys=keys) + + + def __call__( + self, + statement: str, + top_k: Optional[int] = None, + distance_threshold: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Query the semantic router with a given statement. + + Args: + statement (str): The input statement to be queried. + top_k (Optional[int]): The maximum number of top matches to return. + distance_threshold (Optional[float]): The threshold for semantic distance. + + Returns: + List[Dict[str, Any]]: The matching routes and their details. + """ + vector = self.vectorizer.embed(statement) + top_k = top_k if top_k is not None else self.routing_config.top_k + distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold + + # TODO: Implement the query logic based on top_k and distance_threshold + results = [] + + return results From 73332cc8266cf7f7349d5ccdfad4a33e8a47a12c Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 8 Jul 2024 22:41:13 -0400 Subject: [PATCH 02/16] cleanup and also add temp test notebook for experimentation --- redisvl/extensions/router/semantic.py | 96 ++++++++++---- redisvl/extensions/router/test.ipynb | 172 ++++++++++++++++++++++++++ redisvl/schema/schema.py | 6 +- 3 files changed, 248 insertions(+), 26 deletions(-) create mode 100644 redisvl/extensions/router/test.ipynb diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index c219bea5..12d415d9 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,7 +1,8 @@ -from pydantic.v1 import BaseModel, root_validator, Field +from pydantic.v1 import BaseModel, root_validator, Field, PrivateAttr from typing import Any, List, Dict, Optional, Union from redis import Redis from redisvl.index import SearchIndex +from redisvl.query import VectorQuery, RangeQuery from redisvl.schema import IndexSchema, IndexInfo from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer from redisvl.extensions.router.routes import Route, RoutingConfig @@ -15,10 +16,10 @@ class SemanticRouterIndexSchema(IndexSchema): def from_params(cls, name: str, vector_dims: int): return cls( index=IndexInfo(name=name, prefix=name), - fields={ - "route_name": {"name": "route_name", "type": "tag"}, - "reference": {"name": "reference", "type": "text"}, - "vector": { + fields=[ + {"name": "route_name", "type": "tag"}, + {"name": "reference", "type": "text"}, + { "name": "vector", "type": "vector", "attrs": { @@ -28,38 +29,56 @@ def from_params(cls, name: str, vector_dims: int): "datatype": "float32" } } - } + ] ) class SemanticRouter(BaseModel): name: str """The name of the semantic router""" - vectorizer: BaseVectorizer = Field(default_factory=HFTextVectorizer) - """The vectorizer used to embed route references""" routes: List[Route] """List of Route objects""" + vectorizer: BaseVectorizer = Field(default_factory=HFTextVectorizer) + """The vectorizer used to embed route references""" routing_config: RoutingConfig = Field(default_factory=RoutingConfig) """Configuration for routing behavior""" + _index: SearchIndex = PrivateAttr() + class Config: arbitrary_types_allowed = True - def __init__(self, **data): - super().__init__(**data) - self._initialize_index(**data) + def __init__( + self, + name: str, + routes: List[Route], + vectorizer: BaseVectorizer = HFTextVectorizer(), + routing_config: RoutingConfig = RoutingConfig(), + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + overwrite: bool = False, + **kwargs + ): + super().__init__( + name=name, + routes=routes, + vectorizer=vectorizer, + routing_config=routing_config + ) + self._initialize_index(redis_client, redis_url, overwrite) - def _initialize_index(self, **data): + def _initialize_index( + self, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + overwrite: bool = False, + **connection_kwargs + ): """Initialize the search index and handle Redis connection. Args: data (dict): Initialization data containing Redis connection details. """ - # Extract connection parameters - redis_url = data.pop("redis_url", "redis://localhost:6379") - redis_client = data.pop("redis_client", None) - connection_args = data.pop("connection_args", {}) - # Create search index schema schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) @@ -70,12 +89,18 @@ def _initialize_index(self, **data): if redis_client: self._index.set_client(redis_client) else: - self._index.connect(redis_url=redis_url, **connection_args) + self._index.connect(redis_url=redis_url, **connection_kwargs) - if not self._index.exists(): + existed = self._index.exists() + self._index.create(overwrite=overwrite) + + # If the index did not yet exist OR we overwrote it + if not existed or overwrite: self._add_routes(self.routes) - self._index.create(overwrite=False) + # TODO : double check this kind of logic + + def update_routing_config(self, routing_config: RoutingConfig): """Update the routing configuration. @@ -95,13 +120,13 @@ def _add_routes(self, routes: List[Route]): """ route_references: List[Dict[str, Any]] = [] keys: List[str] = [] - + # Iteratively load route references for route in routes: for reference in route.references: route_references.append({ "route_name": route.name, "reference": reference, - "vector": self.vectorizer.embed(reference) + "vector": self.vectorizer.embed(reference, as_buffer=True) }) reference_hash = hashlib.sha256(reference.encode("utf-8")).hexdigest() keys.append(f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}") @@ -129,7 +154,28 @@ def __call__( top_k = top_k if top_k is not None else self.routing_config.top_k distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold - # TODO: Implement the query logic based on top_k and distance_threshold - results = [] + if distance_threshold: + query = RangeQuery( + vector=vector, + vector_field_name="vector", + distance_threshold=distance_threshold, + return_fields=["route_name", "reference"], + num_results=top_k # need to fetch more to be able to do aggregation + ) + else: + query = VectorQuery( + vector=vector, + vector_field_name="vector", + return_fields=["route_name", "reference"], + num_results=top_k # need to fetch more to be able to do aggregation + ) + + route_references = self._index.query(query) + + # TODO use accumulation strategy to aggregation (sum or avg) the scores by the associated route + #top_routes_and_scores = ... + + # TODO fetch the route objects and metadata directly from this class based on top matches + #results = ... - return results + return route_references diff --git a/redisvl/extensions/router/test.ipynb b/redisvl/extensions/router/test.ipynb new file mode 100644 index 00000000..bc151c3a --- /dev/null +++ b/redisvl/extensions/router/test.ipynb @@ -0,0 +1,172 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.extensions.router.routes import Route, RoutingConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Define individual routes manually with metadata\n", + "politics = Route(\n", + " name=\"politics\",\n", + " references=[\n", + " \"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\"\n", + " ],\n", + " metadata={\"priority\": 1}\n", + ")\n", + "\n", + "chitchat = Route(\n", + " name=\"chitchat\",\n", + " references=[\n", + " \"hello\",\n", + " \"how's the weather today?\",\n", + " \"how are things going?\"\n", + " ],\n", + " metadata={\"priority\": 2}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "routes = [politics, chitchat]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "config = RoutingConfig(top_k=5, distance_threshold=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "22:39:55 redisvl.index.index INFO Index already exists, not overwriting.\n" + ] + } + ], + "source": [ + "from redisvl.extensions.router.semantic import SemanticRouter\n", + "import redis\n", + "\n", + "# Create SemanticRouter named \"topic-router\"\n", + "redis_client = redis.Redis()\n", + "routes = [politics, chitchat]\n", + "topic_router = SemanticRouter(\n", + " name=\"topic-router\",\n", + " routes=routes,\n", + " routing_config=config,\n", + " redis_client=redis_client\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Query topic-router with default behavior based on the config\n", + "result = topic_router(\"don't you love politics?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "assert topic_router.routes == routes\n", + "assert topic_router.name == \"topic-router\"\n", + "assert topic_router.name == topic_router._index.name == topic_router._index.prefix\n", + "assert topic_router.routing_config == config" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.top_k" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'id': 'topic-router:chitchat:186465374219e387c415753e03e45813d5d6d1291d7d4cd78107b2d8528817b5',\n", + " 'vector_distance': '0.45913541317',\n", + " 'route_name': 'chitchat',\n", + " 'reference': 'how are things going?'}]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "topic_router(\"how are you\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "redisvl--BNYQ9Uk-py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index d165cff4..9659c79a 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -195,7 +195,11 @@ def validate_and_create_fields(cls, values): """ Validate uniqueness of field names and create valid field instances. """ - index = IndexInfo(**values.get("index")) + # Ensure index is a dictionary for validation + index = values.get("index") + if not isinstance(index, IndexInfo): + index = IndexInfo(**index) + input_fields = values.get("fields", []) prepared_fields: Dict[str, BaseField] = {} # Handle old fields format temporarily From 207fc2fdb2c7872f9322f11fa20884f5b5adb280 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 9 Jul 2024 11:16:07 -0400 Subject: [PATCH 03/16] change param name --- redisvl/extensions/router/routes.py | 23 +++++++++++++++++------ redisvl/extensions/router/semantic.py | 27 +++++++++------------------ 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/redisvl/extensions/router/routes.py b/redisvl/extensions/router/routes.py index 0d2ec0c1..f3eae1ca 100644 --- a/redisvl/extensions/router/routes.py +++ b/redisvl/extensions/router/routes.py @@ -3,6 +3,8 @@ from pydantic.v1 import BaseModel, Field, validator from typing import List, Dict, Optional +from enum import Enum + class Route(BaseModel): name: str @@ -27,17 +29,26 @@ def references_must_not_be_empty(cls, v): return v +class AccumulationMethod(Enum): + # TODO: tidy up the enum usage + simple = "simple" # Take the winner at face value + avg = "avg" # Consider the avg score of all matches + sum = "sum" # Consider the cumulative score of all matches + auto = "auto" # Pick on the user's behalf? + + class RoutingConfig(BaseModel): - top_k: int = Field(default=1) + max_k: int = Field(default=1) """The maximum number of top matches to return""" - distance_threshold: Optional[float] = None + distance_threshold: float = Field(default=0.5) """The threshold for semantic distance""" - # TODO: need more here + accumulation_method: AccumulationMethod = Field(default=AccumulationMethod.auto) + """The accumulation method used to determine the matching route""" - @validator('top_k') - def top_k_must_be_positive(cls, v): + @validator('max_k') + def max_k_must_be_positive(cls, v): if v <= 0: - raise ValueError('top_k must be a positive integer') + raise ValueError('max_k must be a positive integer') return v @validator('distance_threshold') diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 12d415d9..2f659e14 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -137,38 +137,29 @@ def _add_routes(self, routes: List[Route]): def __call__( self, statement: str, - top_k: Optional[int] = None, + max_k: Optional[int] = None, distance_threshold: Optional[float] = None, ) -> List[Dict[str, Any]]: """Query the semantic router with a given statement. Args: statement (str): The input statement to be queried. - top_k (Optional[int]): The maximum number of top matches to return. + max_k (Optional[int]): The maximum number of top matches to return. distance_threshold (Optional[float]): The threshold for semantic distance. Returns: List[Dict[str, Any]]: The matching routes and their details. """ vector = self.vectorizer.embed(statement) - top_k = top_k if top_k is not None else self.routing_config.top_k + max_k = max_k if max_k is not None else self.routing_config.max_k distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold - if distance_threshold: - query = RangeQuery( - vector=vector, - vector_field_name="vector", - distance_threshold=distance_threshold, - return_fields=["route_name", "reference"], - num_results=top_k # need to fetch more to be able to do aggregation - ) - else: - query = VectorQuery( - vector=vector, - vector_field_name="vector", - return_fields=["route_name", "reference"], - num_results=top_k # need to fetch more to be able to do aggregation - ) + query = RangeQuery( + vector=vector, + vector_field_name="vector", + distance_threshold=distance_threshold, + return_fields=["route_name", "reference"], + ) route_references = self._index.query(query) From 3a5fed9c00aad63fc68ab7742920e1d2b0e46be7 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 9 Jul 2024 23:42:53 -0400 Subject: [PATCH 04/16] very early baseline approach --- redisvl/extensions/router/routes.py | 2 - redisvl/extensions/router/semantic.py | 119 ++++++++++++++++++++++---- redisvl/extensions/router/test.ipynb | 118 ++++++++++++++++++++----- 3 files changed, 198 insertions(+), 41 deletions(-) diff --git a/redisvl/extensions/router/routes.py b/redisvl/extensions/router/routes.py index f3eae1ca..d831e992 100644 --- a/redisvl/extensions/router/routes.py +++ b/redisvl/extensions/router/routes.py @@ -30,7 +30,6 @@ def references_must_not_be_empty(cls, v): class AccumulationMethod(Enum): - # TODO: tidy up the enum usage simple = "simple" # Take the winner at face value avg = "avg" # Consider the avg score of all matches sum = "sum" # Consider the cumulative score of all matches @@ -56,4 +55,3 @@ def distance_threshold_must_be_valid(cls, v): if v is not None and (v <= 0 or v > 1): raise ValueError('distance_threshold must be between 0 and 1') return v - diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 2f659e14..02972475 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -5,7 +5,7 @@ from redisvl.query import VectorQuery, RangeQuery from redisvl.schema import IndexSchema, IndexInfo from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -from redisvl.extensions.router.routes import Route, RoutingConfig +from redisvl.extensions.router.routes import Route, RoutingConfig, AccumulationMethod import hashlib @@ -44,6 +44,7 @@ class SemanticRouter(BaseModel): """Configuration for routing behavior""" _index: SearchIndex = PrivateAttr() + _accumulation_method: AccumulationMethod = PrivateAttr() class Config: arbitrary_types_allowed = True @@ -59,6 +60,18 @@ def __init__( overwrite: bool = False, **kwargs ): + """Initialize the SemanticRouter. + + Args: + name (str): The name of the semantic router. + routes (List[Route]): List of Route objects. + vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to HFTextVectorizer(). + routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to RoutingConfig(). + redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. + redis_url (str, optional): Redis URL for connection. Defaults to "redis://localhost:6379". + overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. + **kwargs: Additional arguments. + """ super().__init__( name=name, routes=routes, @@ -66,6 +79,7 @@ def __init__( routing_config=routing_config ) self._initialize_index(redis_client, redis_url, overwrite) + self._accumulation_method = self._pick_accumulation_method() def _initialize_index( self, @@ -77,15 +91,14 @@ def _initialize_index( """Initialize the search index and handle Redis connection. Args: - data (dict): Initialization data containing Redis connection details. + redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. + redis_url (str, optional): Redis URL for connection. Defaults to "redis://localhost:6379". + overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. + **connection_kwargs: Additional connection arguments. """ - # Create search index schema schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) - - # Build search index self._index = SearchIndex(schema=schema) - # Handle Redis connection if redis_client: self._index.set_client(redis_client) else: @@ -94,13 +107,22 @@ def _initialize_index( existed = self._index.exists() self._index.create(overwrite=overwrite) - # If the index did not yet exist OR we overwrote it if not existed or overwrite: self._add_routes(self.routes) - # TODO : double check this kind of logic + def _pick_accumulation_method(self) -> AccumulationMethod: + """Pick the accumulation method based on the routing configuration.""" + if self.routing_config.accumulation_method != AccumulationMethod.auto: + return self.routing_config.accumulation_method + num_route_references = [len(route.references) for route in self.routes] + avg_num_references = sum(num_route_references) / len(num_route_references) + variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references) + if variance < 1: # TODO: Arbitrary threshold for low variance + return AccumulationMethod.sum + else: + return AccumulationMethod.avg def update_routing_config(self, routing_config: RoutingConfig): """Update the routing configuration. @@ -109,8 +131,7 @@ def update_routing_config(self, routing_config: RoutingConfig): routing_config (RoutingConfig): The new routing configuration. """ self.routing_config = routing_config - # TODO: Ensure Pydantic handles the validation here - # TODO: Determine if we need to persist this to Redis + self._accumulation_method = self._pick_accumulation_method() def _add_routes(self, routes: List[Route]): """Add routes to the index. @@ -120,7 +141,7 @@ def _add_routes(self, routes: List[Route]): """ route_references: List[Dict[str, Any]] = [] keys: List[str] = [] - # Iteratively load route references + for route in routes: for reference in route.references: route_references.append({ @@ -133,7 +154,6 @@ def _add_routes(self, routes: List[Route]): self._index.load(route_references, keys=keys) - def __call__( self, statement: str, @@ -154,19 +174,82 @@ def __call__( max_k = max_k if max_k is not None else self.routing_config.max_k distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold + # get the total number of route references in the index + num_route_references = sum( + [len(route.references) for route in self.routes] + ) + # define the baseline range query to fetch relevant route references query = RangeQuery( vector=vector, vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name", "reference"], + # max number of results from range query + num_results=num_route_references ) - + # execute query and accumulate results route_references = self._index.query(query) + top_routes_and_scores = self._reduce_scores(route_references, max_k) + top_routes = self._fetch_routes(top_routes_and_scores) + + return top_routes + + def _reduce_scores( + self, + route_references: List[Dict[str, Any]], + max_k: int + ) -> List[Dict[str, Any]]: + """Group by route name and reduce scores to return max_k routes overall. + + Args: + route_references: List of route references with scores. + max_k: The number of top results to return. - # TODO use accumulation strategy to aggregation (sum or avg) the scores by the associated route - #top_routes_and_scores = ... + Returns: + List[Dict[str, Any]]: Accumulated scores for the top routes. + """ + # TODO: eventually this should be replaced by an AggregationQuery class + scores_by_route = {} + for ref in route_references: + route_name = ref['route_name'] + score = ref['vector_distance'] + if route_name not in scores_by_route: + scores_by_route[route_name] = [] + scores_by_route[route_name].append(float(score)) + + accumulated_scores = [] + for route_name, scores in scores_by_route.items(): + if self._accumulation_method == AccumulationMethod.sum: + accumulated_score = sum(scores) + elif self._accumulation_method == AccumulationMethod.avg: + accumulated_score = sum(scores) / len(scores) + else: + # simple strategy + accumulated_score = scores[0] # take the first score + + accumulated_scores.append({"route_name": route_name, "score": accumulated_score}) + + # Sort by score in descending order and return the max_k results + accumulated_scores.sort(key=lambda x: x["score"], reverse=False) + return accumulated_scores[:max_k] + + def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Fetch route objects and metadata based on top matches. - # TODO fetch the route objects and metadata directly from this class based on top matches - #results = ... + Args: + top_routes_and_scores: List of top routes and their scores. + + Returns: + List[Dict[str, Any]]: Routes with their metadata. + """ + results = [] + for route_info in top_routes_and_scores: + route_name = route_info["route_name"] + route = next((r for r in self.routes if r.name == route_name), None) + if route: + results.append({ + **route.dict(), + "score": route_info["score"], + }) - return route_references + return results diff --git a/redisvl/extensions/router/test.ipynb b/redisvl/extensions/router/test.ipynb index bc151c3a..00749943 100644 --- a/redisvl/extensions/router/test.ipynb +++ b/redisvl/extensions/router/test.ipynb @@ -51,7 +51,7 @@ "metadata": {}, "outputs": [], "source": [ - "config = RoutingConfig(top_k=5, distance_threshold=0.5)" + "config = RoutingConfig(max_k=2, distance_threshold=1.0, accumulation_method=\"avg\")" ] }, { @@ -63,7 +63,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "22:39:55 redisvl.index.index INFO Index already exists, not overwriting.\n" + "23:37:26 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -78,20 +78,11 @@ " name=\"topic-router\",\n", " routes=routes,\n", " routing_config=config,\n", - " redis_client=redis_client\n", + " redis_client=redis_client,\n", + " overwrite=True\n", ")" ] }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# Query topic-router with default behavior based on the config\n", - "result = topic_router(\"don't you love politics?\")" - ] - }, { "cell_type": "code", "execution_count": 7, @@ -112,7 +103,15 @@ { "data": { "text/plain": [ - "5" + "[{'name': 'politics',\n", + " 'references': [\"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\"],\n", + " 'metadata': {'priority': '1'},\n", + " 'score': 0.3825837373735},\n", + " {'name': 'chitchat',\n", + " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", + " 'metadata': {'priority': '2'},\n", + " 'score': 0.8872345884643332}]" ] }, "execution_count": 8, @@ -121,24 +120,30 @@ } ], "source": [ - "config.top_k" + "# Query topic-router with behavior based on the config\n", + "topic_router(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'id': 'topic-router:chitchat:186465374219e387c415753e03e45813d5d6d1291d7d4cd78107b2d8528817b5',\n", - " 'vector_distance': '0.45913541317',\n", - " 'route_name': 'chitchat',\n", - " 'reference': 'how are things going?'}]" + "[{'name': 'chitchat',\n", + " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", + " 'metadata': {'priority': '2'},\n", + " 'score': 0.5357088247936667},\n", + " {'name': 'politics',\n", + " 'references': [\"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\"],\n", + " 'metadata': {'priority': '1'},\n", + " 'score': 0.8782881200315}]" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -146,6 +151,77 @@ "source": [ "topic_router(\"how are you\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Need to work on a way to properly handle the \"sum\" strategy... Adding route references by vector distance when 0 is \"more similar\" means we would have to flip polarity in order to capture the right signal.\n", + "\n", + "Avg is a safer strategy in general.. but assumes that the variance in the amount of references per route is not too high." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "config = RoutingConfig(distance_threshold=0.1, accumulation_method=\"avg\")\n", + "\n", + "topic_router.update_routing_config(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# distance threshold is too low\n", + "topic_router(\"hello world\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'chitchat',\n", + " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", + " 'metadata': {'priority': '2'},\n", + " 'score': 0.243986725807}]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "topic_router(\"hello world\", distance_threshold=0.3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 7b08ad35f01e2d28e7e250a88336f889db94bf5d Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 15 Jul 2024 11:08:17 -0400 Subject: [PATCH 05/16] wip updates using aggregations --- redisvl/extensions/router/semantic.py | 112 +++++++++++--------------- redisvl/extensions/router/test.ipynb | 96 +++++++++++++++------- redisvl/index/index.py | 9 ++- 3 files changed, 124 insertions(+), 93 deletions(-) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 02972475..caf1a123 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,12 +1,18 @@ from pydantic.v1 import BaseModel, root_validator, Field, PrivateAttr from typing import Any, List, Dict, Optional, Union + from redis import Redis +from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer +import redis.commands.search.reducers as reducers + from redisvl.index import SearchIndex from redisvl.query import VectorQuery, RangeQuery from redisvl.schema import IndexSchema, IndexInfo from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer from redisvl.extensions.router.routes import Route, RoutingConfig, AccumulationMethod +from redisvl.redis.utils import make_dict, convert_bytes + import hashlib @@ -44,7 +50,7 @@ class SemanticRouter(BaseModel): """Configuration for routing behavior""" _index: SearchIndex = PrivateAttr() - _accumulation_method: AccumulationMethod = PrivateAttr() + # _accumulation_method: AccumulationMethod = PrivateAttr() class Config: arbitrary_types_allowed = True @@ -79,7 +85,7 @@ def __init__( routing_config=routing_config ) self._initialize_index(redis_client, redis_url, overwrite) - self._accumulation_method = self._pick_accumulation_method() + # self._accumulation_method = self._pick_accumulation_method() def _initialize_index( self, @@ -110,19 +116,19 @@ def _initialize_index( if not existed or overwrite: self._add_routes(self.routes) - def _pick_accumulation_method(self) -> AccumulationMethod: - """Pick the accumulation method based on the routing configuration.""" - if self.routing_config.accumulation_method != AccumulationMethod.auto: - return self.routing_config.accumulation_method + # def _pick_accumulation_method(self) -> AccumulationMethod: + # """Pick the accumulation method based on the routing configuration.""" + # if self.routing_config.accumulation_method != AccumulationMethod.auto: + # return self.routing_config.accumulation_method - num_route_references = [len(route.references) for route in self.routes] - avg_num_references = sum(num_route_references) / len(num_route_references) - variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references) + # num_route_references = [len(route.references) for route in self.routes] + # avg_num_references = sum(num_route_references) / len(num_route_references) + # variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references) - if variance < 1: # TODO: Arbitrary threshold for low variance - return AccumulationMethod.sum - else: - return AccumulationMethod.avg + # if variance < 1: # TODO: Arbitrary threshold for low variance + # return AccumulationMethod.sum + # else: + # return AccumulationMethod.avg def update_routing_config(self, routing_config: RoutingConfig): """Update the routing configuration. @@ -131,7 +137,7 @@ def update_routing_config(self, routing_config: RoutingConfig): routing_config (RoutingConfig): The new routing configuration. """ self.routing_config = routing_config - self._accumulation_method = self._pick_accumulation_method() + # self._accumulation_method = self._pick_accumulation_method() def _add_routes(self, routes: List[Route]): """Add routes to the index. @@ -174,64 +180,41 @@ def __call__( max_k = max_k if max_k is not None else self.routing_config.max_k distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold - # get the total number of route references in the index - num_route_references = sum( - [len(route.references) for route in self.routes] - ) + # # get the total number of route references in the index + # num_route_references = sum( + # [len(route.references) for route in self.routes] + # ) # define the baseline range query to fetch relevant route references - query = RangeQuery( + vector_range_query = RangeQuery( vector=vector, vector_field_name="vector", - distance_threshold=distance_threshold, - return_fields=["route_name", "reference"], - # max number of results from range query - num_results=num_route_references + distance_threshold=2, + return_fields=["route_name"] ) - # execute query and accumulate results - route_references = self._index.query(query) - top_routes_and_scores = self._reduce_scores(route_references, max_k) - top_routes = self._fetch_routes(top_routes_and_scores) - return top_routes + # build redis aggregation query + aggregate_query = str(vector_range_query).split(" RETURN")[0] + aggregate_request = ( + AggregateRequest(aggregate_query) + .group_by( + "@route_name", + reducers.avg("vector_distance").alias("avg"), + reducers.min("vector_distance").alias("score") + ) + .apply(avg_score="1 - @avg", score="1 - @score") + .dialect(2) + ) - def _reduce_scores( - self, - route_references: List[Dict[str, Any]], - max_k: int - ) -> List[Dict[str, Any]]: - """Group by route name and reduce scores to return max_k routes overall. + top_routes_and_scores = [] + aggregate_results = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params) - Args: - route_references: List of route references with scores. - max_k: The number of top results to return. + for result in aggregate_results.rows: + top_routes_and_scores.append(make_dict(convert_bytes(result))) + + top_routes = self._fetch_routes(top_routes_and_scores) + + return top_routes - Returns: - List[Dict[str, Any]]: Accumulated scores for the top routes. - """ - # TODO: eventually this should be replaced by an AggregationQuery class - scores_by_route = {} - for ref in route_references: - route_name = ref['route_name'] - score = ref['vector_distance'] - if route_name not in scores_by_route: - scores_by_route[route_name] = [] - scores_by_route[route_name].append(float(score)) - - accumulated_scores = [] - for route_name, scores in scores_by_route.items(): - if self._accumulation_method == AccumulationMethod.sum: - accumulated_score = sum(scores) - elif self._accumulation_method == AccumulationMethod.avg: - accumulated_score = sum(scores) / len(scores) - else: - # simple strategy - accumulated_score = scores[0] # take the first score - - accumulated_scores.append({"route_name": route_name, "score": accumulated_score}) - - # Sort by score in descending order and return the max_k results - accumulated_scores.sort(key=lambda x: x["score"], reverse=False) - return accumulated_scores[:max_k] def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Fetch route objects and metadata based on top matches. @@ -250,6 +233,7 @@ def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dic results.append({ **route.dict(), "score": route_info["score"], + "avg_score": route_info["avg_score"] }) return results diff --git a/redisvl/extensions/router/test.ipynb b/redisvl/extensions/router/test.ipynb index 00749943..8a60b625 100644 --- a/redisvl/extensions/router/test.ipynb +++ b/redisvl/extensions/router/test.ipynb @@ -15,6 +15,8 @@ "metadata": {}, "outputs": [], "source": [ + "from redisvl.extensions.router.routes import Route\n", + "\n", "# Define individual routes manually with metadata\n", "politics = Route(\n", " name=\"politics\",\n", @@ -63,13 +65,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "23:37:26 redisvl.index.index INFO Index already exists, overwriting.\n" + "13:38:58 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], "source": [ - "from redisvl.extensions.router.semantic import SemanticRouter\n", "import redis\n", + "from redisvl.extensions.router.semantic import SemanticRouter\n", "\n", "# Create SemanticRouter named \"topic-router\"\n", "redis_client = redis.Redis()\n", @@ -85,48 +87,67 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'chitchat',\n", + " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", + " 'metadata': {'priority': '2'},\n", + " 'score': '0.0641258955002',\n", + " 'avg_score': '0.0481971502304'},\n", + " {'name': 'politics',\n", + " 'references': [\"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\"],\n", + " 'metadata': {'priority': '1'},\n", + " 'score': '0.298070549965',\n", + " 'avg_score': '0.207850039005'}]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "assert topic_router.routes == routes\n", - "assert topic_router.name == \"topic-router\"\n", - "assert topic_router.name == topic_router._index.name == topic_router._index.prefix\n", - "assert topic_router.routing_config == config" + "topic_router(\"I am thinking about running for Governor in the state of VA. What do I need to consider?\")" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[{'name': 'politics',\n", + "[{'name': 'chitchat',\n", + " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", + " 'metadata': {'priority': '2'},\n", + " 'score': '0.12840616703',\n", + " 'avg_score': '0.112765411536'},\n", + " {'name': 'politics',\n", " 'references': [\"isn't politics the best thing ever\",\n", " \"why don't you tell me about your political opinions\"],\n", " 'metadata': {'priority': '1'},\n", - " 'score': 0.3825837373735},\n", - " {'name': 'chitchat',\n", - " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", - " 'metadata': {'priority': '2'},\n", - " 'score': 0.8872345884643332}]" + " 'score': '0.764727830887',\n", + " 'avg_score': '0.617416262627'}]" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Query topic-router with behavior based on the config\n", "topic_router(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -135,15 +156,17 @@ "[{'name': 'chitchat',\n", " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", " 'metadata': {'priority': '2'},\n", - " 'score': 0.5357088247936667},\n", + " 'score': '0.54086458683',\n", + " 'avg_score': '0.464291175207'},\n", " {'name': 'politics',\n", " 'references': [\"isn't politics the best thing ever\",\n", " \"why don't you tell me about your political opinions\"],\n", " 'metadata': {'priority': '1'},\n", - " 'score': 0.8782881200315}]" + " 'score': '0.156601548195',\n", + " 'avg_score': '0.121711879969'}]" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -163,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -174,16 +197,26 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[]" + "[{'name': 'chitchat',\n", + " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", + " 'metadata': {'priority': '2'},\n", + " 'score': '0.756013274193',\n", + " 'avg_score': '0.423087894917'},\n", + " {'name': 'politics',\n", + " 'references': [\"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\"],\n", + " 'metadata': {'priority': '1'},\n", + " 'score': '0.175542235374',\n", + " 'avg_score': '0.138914197683'}]" ] }, - "execution_count": 14, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -195,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -204,10 +237,17 @@ "[{'name': 'chitchat',\n", " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", " 'metadata': {'priority': '2'},\n", - " 'score': 0.243986725807}]" + " 'score': '0.756013274193',\n", + " 'avg_score': '0.423087894917'},\n", + " {'name': 'politics',\n", + " 'references': [\"isn't politics the best thing ever\",\n", + " \"why don't you tell me about your political opinions\"],\n", + " 'metadata': {'priority': '1'},\n", + " 'score': '0.175542235374',\n", + " 'avg_score': '0.138914197683'}]" ] }, - "execution_count": 15, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 2187c759..a479f8ed 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -177,7 +177,14 @@ def __init__( self.connect(redis_url, **connection_args) # set up index storage layer - self._storage = self._STORAGE_MAP[self.schema.index.storage_type]( + # self._storage = self._STORAGE_MAP[self.schema.index.storage_type]( + # prefix=self.schema.index.prefix, + # key_separator=self.schema.index.key_separator, + # ) + + @property + def _storage(self): + return self._STORAGE_MAP[self.schema.index.storage_type]( prefix=self.schema.index.prefix, key_separator=self.schema.index.key_separator, ) From 187bfaa2b941bfc29869dcc71ecb6d9dc02aa6a1 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Mon, 15 Jul 2024 15:58:34 -0400 Subject: [PATCH 06/16] Add intial round of tests --- redisvl/extensions/router/__init__.py | 5 + redisvl/extensions/router/routes.py | 25 +- redisvl/extensions/router/semantic.py | 78 +++--- redisvl/extensions/router/test.ipynb | 288 ---------------------- tests/integration/test_semantic_router.py | 63 +++++ tests/unit/test_routes.py | 52 ++++ 6 files changed, 161 insertions(+), 350 deletions(-) delete mode 100644 redisvl/extensions/router/test.ipynb create mode 100644 tests/integration/test_semantic_router.py create mode 100644 tests/unit/test_routes.py diff --git a/redisvl/extensions/router/__init__.py b/redisvl/extensions/router/__init__.py index e69de29b..d1d3a0ab 100644 --- a/redisvl/extensions/router/__init__.py +++ b/redisvl/extensions/router/__init__.py @@ -0,0 +1,5 @@ +from redisvl.extensions.router.semantic import SemanticRouter +from redisvl.extensions.router.routes import Route, RoutingConfig + + +__all__ = ["SemanticRouter", "Route", "RoutingConfig"] diff --git a/redisvl/extensions/router/routes.py b/redisvl/extensions/router/routes.py index d831e992..f1f55340 100644 --- a/redisvl/extensions/router/routes.py +++ b/redisvl/extensions/router/routes.py @@ -1,9 +1,6 @@ - - -from pydantic.v1 import BaseModel, Field, validator -from typing import List, Dict, Optional - from enum import Enum +from pydantic.v1 import BaseModel, Field, validator +from typing import List, Dict class Route(BaseModel): @@ -29,20 +26,18 @@ def references_must_not_be_empty(cls, v): return v -class AccumulationMethod(Enum): - simple = "simple" # Take the winner at face value - avg = "avg" # Consider the avg score of all matches - sum = "sum" # Consider the cumulative score of all matches - auto = "auto" # Pick on the user's behalf? +class RouteSortingMethod(Enum): + avg_distance = "avg_distance" + min_distance = "min_distance" class RoutingConfig(BaseModel): - max_k: int = Field(default=1) - """The maximum number of top matches to return""" distance_threshold: float = Field(default=0.5) - """The threshold for semantic distance""" - accumulation_method: AccumulationMethod = Field(default=AccumulationMethod.auto) - """The accumulation method used to determine the matching route""" + """The threshold for semantic distance.""" + max_k: int = Field(default=1) + """The maximum number of top matches to return.""" + sort_by: RouteSortingMethod = Field(default=RouteSortingMethod.avg_distance) + """The technique used to sort the final route matches before truncating.""" @validator('max_k') def max_k_must_be_positive(cls, v): diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index caf1a123..40cff192 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -2,14 +2,14 @@ from typing import Any, List, Dict, Optional, Union from redis import Redis -from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer +from redis.commands.search.aggregation import AggregateRequest, AggregateResult import redis.commands.search.reducers as reducers from redisvl.index import SearchIndex from redisvl.query import VectorQuery, RangeQuery from redisvl.schema import IndexSchema, IndexInfo from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -from redisvl.extensions.router.routes import Route, RoutingConfig, AccumulationMethod +from redisvl.extensions.router.routes import Route, RoutingConfig, RouteSortingMethod from redisvl.redis.utils import make_dict, convert_bytes @@ -20,6 +20,9 @@ class SemanticRouterIndexSchema(IndexSchema): @classmethod def from_params(cls, name: str, vector_dims: int): + """Load the semantic router index schema from the router name and + vector dimensionality. + """ return cls( index=IndexInfo(name=name, prefix=name), fields=[ @@ -50,7 +53,6 @@ class SemanticRouter(BaseModel): """Configuration for routing behavior""" _index: SearchIndex = PrivateAttr() - # _accumulation_method: AccumulationMethod = PrivateAttr() class Config: arbitrary_types_allowed = True @@ -85,7 +87,6 @@ def __init__( routing_config=routing_config ) self._initialize_index(redis_client, redis_url, overwrite) - # self._accumulation_method = self._pick_accumulation_method() def _initialize_index( self, @@ -116,20 +117,6 @@ def _initialize_index( if not existed or overwrite: self._add_routes(self.routes) - # def _pick_accumulation_method(self) -> AccumulationMethod: - # """Pick the accumulation method based on the routing configuration.""" - # if self.routing_config.accumulation_method != AccumulationMethod.auto: - # return self.routing_config.accumulation_method - - # num_route_references = [len(route.references) for route in self.routes] - # avg_num_references = sum(num_route_references) / len(num_route_references) - # variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references) - - # if variance < 1: # TODO: Arbitrary threshold for low variance - # return AccumulationMethod.sum - # else: - # return AccumulationMethod.avg - def update_routing_config(self, routing_config: RoutingConfig): """Update the routing configuration. @@ -165,6 +152,7 @@ def __call__( statement: str, max_k: Optional[int] = None, distance_threshold: Optional[float] = None, + sort_by: Optional[str] = None ) -> List[Dict[str, Any]]: """Query the semantic router with a given statement. @@ -172,6 +160,7 @@ def __call__( statement (str): The input statement to be queried. max_k (Optional[int]): The maximum number of top matches to return. distance_threshold (Optional[float]): The threshold for semantic distance. + sort_by (Optional[str]): The technique used to sort the final route matches before truncating. Returns: List[Dict[str, Any]]: The matching routes and their details. @@ -179,11 +168,8 @@ def __call__( vector = self.vectorizer.embed(statement) max_k = max_k if max_k is not None else self.routing_config.max_k distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold + sort_by = RouteSortingMethod(sort_by) if sort_by is not None else self.routing_config.sort_by - # # get the total number of route references in the index - # num_route_references = sum( - # [len(route.references) for route in self.routes] - # ) # define the baseline range query to fetch relevant route references vector_range_query = RangeQuery( vector=vector, @@ -198,42 +184,40 @@ def __call__( AggregateRequest(aggregate_query) .group_by( "@route_name", - reducers.avg("vector_distance").alias("avg"), - reducers.min("vector_distance").alias("score") + reducers.avg("vector_distance").alias("avg_distance"), + reducers.min("vector_distance").alias("min_distance") ) - .apply(avg_score="1 - @avg", score="1 - @score") .dialect(2) ) - top_routes_and_scores = [] - aggregate_results = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params) - - for result in aggregate_results.rows: - top_routes_and_scores.append(make_dict(convert_bytes(result))) + # run the aggregation query in Redis + aggregate_result: AggregateResult = ( + self._index.client + .ft(self._index.name) + .aggregate(aggregate_request, vector_range_query.params) + ) - top_routes = self._fetch_routes(top_routes_and_scores) + top_routes_and_scores = sorted([ + self._process_result(result) for result in aggregate_result.rows + ], key=lambda r: r[sort_by.value]) - return top_routes + return top_routes_and_scores[:max_k] - def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Fetch route objects and metadata based on top matches. + def _process_result(self, result: Dict[str, Any]) -> Dict[str, Any]: + """Process resulting route objects and metadata. Args: - top_routes_and_scores: List of top routes and their scores. + result: Aggregation query result object Returns: List[Dict[str, Any]]: Routes with their metadata. """ - results = [] - for route_info in top_routes_and_scores: - route_name = route_info["route_name"] - route = next((r for r in self.routes if r.name == route_name), None) - if route: - results.append({ - **route.dict(), - "score": route_info["score"], - "avg_score": route_info["avg_score"] - }) - - return results + result_dict = make_dict(convert_bytes(result)) + route_name = result_dict["route_name"] + route = next((r for r in self.routes if r.name == route_name), None) + return { + **route.dict(), + "avg_distance": float(result_dict["avg_distance"]), + "min_distance": float(result_dict["min_distance"]) + } diff --git a/redisvl/extensions/router/test.ipynb b/redisvl/extensions/router/test.ipynb deleted file mode 100644 index 8a60b625..00000000 --- a/redisvl/extensions/router/test.ipynb +++ /dev/null @@ -1,288 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from redisvl.extensions.router.routes import Route, RoutingConfig" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from redisvl.extensions.router.routes import Route\n", - "\n", - "# Define individual routes manually with metadata\n", - "politics = Route(\n", - " name=\"politics\",\n", - " references=[\n", - " \"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\"\n", - " ],\n", - " metadata={\"priority\": 1}\n", - ")\n", - "\n", - "chitchat = Route(\n", - " name=\"chitchat\",\n", - " references=[\n", - " \"hello\",\n", - " \"how's the weather today?\",\n", - " \"how are things going?\"\n", - " ],\n", - " metadata={\"priority\": 2}\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "routes = [politics, chitchat]" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "config = RoutingConfig(max_k=2, distance_threshold=1.0, accumulation_method=\"avg\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "13:38:58 redisvl.index.index INFO Index already exists, overwriting.\n" - ] - } - ], - "source": [ - "import redis\n", - "from redisvl.extensions.router.semantic import SemanticRouter\n", - "\n", - "# Create SemanticRouter named \"topic-router\"\n", - "redis_client = redis.Redis()\n", - "routes = [politics, chitchat]\n", - "topic_router = SemanticRouter(\n", - " name=\"topic-router\",\n", - " routes=routes,\n", - " routing_config=config,\n", - " redis_client=redis_client,\n", - " overwrite=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'chitchat',\n", - " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", - " 'metadata': {'priority': '2'},\n", - " 'score': '0.0641258955002',\n", - " 'avg_score': '0.0481971502304'},\n", - " {'name': 'politics',\n", - " 'references': [\"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\"],\n", - " 'metadata': {'priority': '1'},\n", - " 'score': '0.298070549965',\n", - " 'avg_score': '0.207850039005'}]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "topic_router(\"I am thinking about running for Governor in the state of VA. What do I need to consider?\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'chitchat',\n", - " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", - " 'metadata': {'priority': '2'},\n", - " 'score': '0.12840616703',\n", - " 'avg_score': '0.112765411536'},\n", - " {'name': 'politics',\n", - " 'references': [\"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\"],\n", - " 'metadata': {'priority': '1'},\n", - " 'score': '0.764727830887',\n", - " 'avg_score': '0.617416262627'}]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "topic_router(\"don't you love politics?\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'chitchat',\n", - " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", - " 'metadata': {'priority': '2'},\n", - " 'score': '0.54086458683',\n", - " 'avg_score': '0.464291175207'},\n", - " {'name': 'politics',\n", - " 'references': [\"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\"],\n", - " 'metadata': {'priority': '1'},\n", - " 'score': '0.156601548195',\n", - " 'avg_score': '0.121711879969'}]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "topic_router(\"how are you\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Need to work on a way to properly handle the \"sum\" strategy... Adding route references by vector distance when 0 is \"more similar\" means we would have to flip polarity in order to capture the right signal.\n", - "\n", - "Avg is a safer strategy in general.. but assumes that the variance in the amount of references per route is not too high." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "config = RoutingConfig(distance_threshold=0.1, accumulation_method=\"avg\")\n", - "\n", - "topic_router.update_routing_config(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'chitchat',\n", - " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", - " 'metadata': {'priority': '2'},\n", - " 'score': '0.756013274193',\n", - " 'avg_score': '0.423087894917'},\n", - " {'name': 'politics',\n", - " 'references': [\"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\"],\n", - " 'metadata': {'priority': '1'},\n", - " 'score': '0.175542235374',\n", - " 'avg_score': '0.138914197683'}]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# distance threshold is too low\n", - "topic_router(\"hello world\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'chitchat',\n", - " 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n", - " 'metadata': {'priority': '2'},\n", - " 'score': '0.756013274193',\n", - " 'avg_score': '0.423087894917'},\n", - " {'name': 'politics',\n", - " 'references': [\"isn't politics the best thing ever\",\n", - " \"why don't you tell me about your political opinions\"],\n", - " 'metadata': {'priority': '1'},\n", - " 'score': '0.175542235374',\n", - " 'avg_score': '0.138914197683'}]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "topic_router(\"hello world\", distance_threshold=0.3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "redisvl--BNYQ9Uk-py3.10", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py new file mode 100644 index 00000000..dd39e5a7 --- /dev/null +++ b/tests/integration/test_semantic_router.py @@ -0,0 +1,63 @@ +import pytest +from redisvl.extensions.router.routes import Route, RoutingConfig +from redisvl.extensions.router.semantic import SemanticRouter + + +@pytest.fixture +def routes(): + politics = Route( + name="politics", + references=[ + "isn't politics the best thing ever", + "why don't you tell me about your political opinions" + ], + metadata={"priority": "1"} + ) + chitchat = Route( + name="chitchat", + references=[ + "hello", + "how's the weather today?", + "how are things going?" + ], + metadata={"priority": "2"} + ) + return [politics, chitchat] + +@pytest.fixture +def semantic_router(redis_client, routes): + config = RoutingConfig(distance_threshold=1.0) + router = SemanticRouter( + name="topic-router", + routes=routes, + routing_config=config, + redis_client=redis_client, + overwrite=True + ) + return router + + +def test_semantic_router_match_politics(semantic_router): + result = semantic_router("I am thinking about running for Governor in the state of VA. What do I need to consider?") + assert result[0]['route'].name == "politics" + + +def test_semantic_router_match_chitchat(semantic_router): + result = semantic_router("hello") + assert result[0]['route'].name == "chitchat" + + +def test_semantic_router_no_match(semantic_router): + result = semantic_router("unrelated topic") + assert result == [] + + +def test_update_routing_config(semantic_router): + new_config = RoutingConfig(distance_threshold=0.1, sort_by='avg_distance') + + semantic_router.update_routing_config(new_config) + result = semantic_router("hello world") + assert result == [] + + result = semantic_router("hello world", distance_threshold=0.3) + assert len(result) > 0 diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py new file mode 100644 index 00000000..10e45e48 --- /dev/null +++ b/tests/unit/test_routes.py @@ -0,0 +1,52 @@ +import pytest +from pydantic.v1 import ValidationError +from redisvl.extensions.router.routes import Route, RoutingConfig, RouteSortingMethod + + +def test_route_creation(): + route = Route( + name="test_route", + references=["test reference 1", "test reference 2"], + metadata={"priority": "1"} + ) + assert route.name == "test_route" + assert route.references == ["test reference 1", "test reference 2"] + assert route.metadata == {"priority": "1"} + + +def test_route_name_empty(): + with pytest.raises(ValidationError): + Route(name="", references=["test reference"]) + + +def test_route_references_empty(): + with pytest.raises(ValidationError): + Route(name="test_route", references=[]) + + +def test_route_references_non_empty_strings(): + with pytest.raises(ValidationError): + Route(name="test_route", references=["", "test reference"]) + + +def test_routing_config_creation(): + config = RoutingConfig( + distance_threshold=0.5, + max_k=1, + sort_by=RouteSortingMethod.avg_distance + ) + assert config.distance_threshold == 0.5 + assert config.max_k == 1 + assert config.sort_by == RouteSortingMethod.avg_distance + + +def test_routing_config_invalid_max_k(): + with pytest.raises(ValidationError): + RoutingConfig(distance_threshold=0.5, max_k=0) + + +def test_routing_config_invalid_distance_threshold(): + with pytest.raises(ValidationError): + RoutingConfig(distance_threshold=-0.1, max_k=1) + with pytest.raises(ValidationError): + RoutingConfig(distance_threshold=1.1, max_k=1) From ffcd8688b9a655f9e83b00bc8563cac934524757 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 12:23:20 -0400 Subject: [PATCH 07/16] update router and improve tests --- redisvl/extensions/router/__init__.py | 3 +- redisvl/extensions/router/routes.py | 52 ---- redisvl/extensions/router/schema.py | 80 ++++++ redisvl/extensions/router/semantic.py | 318 +++++++++++++++------- tests/integration/test_semantic_router.py | 132 ++++++--- tests/unit/test_route_schema.py | 124 +++++++++ tests/unit/test_routes.py | 52 ---- 7 files changed, 522 insertions(+), 239 deletions(-) delete mode 100644 redisvl/extensions/router/routes.py create mode 100644 redisvl/extensions/router/schema.py create mode 100644 tests/unit/test_route_schema.py delete mode 100644 tests/unit/test_routes.py diff --git a/redisvl/extensions/router/__init__.py b/redisvl/extensions/router/__init__.py index d1d3a0ab..25249f3b 100644 --- a/redisvl/extensions/router/__init__.py +++ b/redisvl/extensions/router/__init__.py @@ -1,5 +1,4 @@ +from redisvl.extensions.router.schema import Route, RoutingConfig from redisvl.extensions.router.semantic import SemanticRouter -from redisvl.extensions.router.routes import Route, RoutingConfig - __all__ = ["SemanticRouter", "Route", "RoutingConfig"] diff --git a/redisvl/extensions/router/routes.py b/redisvl/extensions/router/routes.py deleted file mode 100644 index f1f55340..00000000 --- a/redisvl/extensions/router/routes.py +++ /dev/null @@ -1,52 +0,0 @@ -from enum import Enum -from pydantic.v1 import BaseModel, Field, validator -from typing import List, Dict - - -class Route(BaseModel): - name: str - """The name of the route""" - references: List[str] - """List of reference phrases for the route""" - metadata: Dict[str, str] = Field(default={}) - """Metadata associated with the route""" - - @validator('name') - def name_must_not_be_empty(cls, v): - if not v or not v.strip(): - raise ValueError('Route name must not be empty') - return v - - @validator('references') - def references_must_not_be_empty(cls, v): - if not v: - raise ValueError('References must not be empty') - if any(not ref.strip() for ref in v): - raise ValueError('All references must be non-empty strings') - return v - - -class RouteSortingMethod(Enum): - avg_distance = "avg_distance" - min_distance = "min_distance" - - -class RoutingConfig(BaseModel): - distance_threshold: float = Field(default=0.5) - """The threshold for semantic distance.""" - max_k: int = Field(default=1) - """The maximum number of top matches to return.""" - sort_by: RouteSortingMethod = Field(default=RouteSortingMethod.avg_distance) - """The technique used to sort the final route matches before truncating.""" - - @validator('max_k') - def max_k_must_be_positive(cls, v): - if v <= 0: - raise ValueError('max_k must be a positive integer') - return v - - @validator('distance_threshold') - def distance_threshold_must_be_valid(cls, v): - if v is not None and (v <= 0 or v > 1): - raise ValueError('distance_threshold must be between 0 and 1') - return v diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py new file mode 100644 index 00000000..58ac2d85 --- /dev/null +++ b/redisvl/extensions/router/schema.py @@ -0,0 +1,80 @@ +from enum import Enum +from typing import Dict, List, Optional + +from pydantic.v1 import BaseModel, Field, validator + + +class Route(BaseModel): + """Model representing a routing path with associated metadata and thresholds.""" + + name: str + """The name of the route.""" + references: List[str] + """List of reference phrases for the route.""" + metadata: Dict[str, str] = Field(default={}) + """Metadata associated with the route.""" + distance_threshold: Optional[float] = Field(default=None) + """Distance threshold for matching the route.""" + + @validator("name") + def name_must_not_be_empty(cls, v): + if not v or not v.strip(): + raise ValueError("Route name must not be empty") + return v + + @validator("references") + def references_must_not_be_empty(cls, v): + if not v: + raise ValueError("References must not be empty") + if any(not ref.strip() for ref in v): + raise ValueError("All references must be non-empty strings") + return v + + @validator("distance_threshold") + def distance_threshold_must_be_positive(cls, v): + if v is not None and v <= 0: + raise ValueError("Route distance threshold must be greater than zero") + return v + + +class RouteMatch(BaseModel): + """Model representing a matched route with distance information.""" + + route: Optional[Route] = None + """The matched route.""" + distance: Optional[float] = Field(default=None) + """The distance of the match.""" + + +class DistanceAggregationMethod(Enum): + """Enumeration for distance aggregation methods.""" + + avg = "avg" + """Compute the average of the vector distances.""" + min = "min" + """Compute the minimum of the vector distances.""" + sum = "sum" + """Compute the sum of the vector distances.""" + + +class RoutingConfig(BaseModel): + """Configuration for routing behavior.""" + + distance_threshold: float = Field(default=0.5) + """The threshold for semantic distance.""" + max_k: int = Field(default=1) + """The maximum number of top matches to return.""" + aggregation_method: DistanceAggregationMethod = Field(default=DistanceAggregationMethod.avg) + """Aggregation method to use to classify queries.""" + + @validator("max_k") + def max_k_must_be_positive(cls, v): + if v <= 0: + raise ValueError("max_k must be a positive integer") + return v + + @validator("distance_threshold") + def distance_threshold_must_be_valid(cls, v): + if v <= 0 or v > 1: + raise ValueError("distance_threshold must be between 0 and 1") + return v diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 40cff192..80835d5b 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,27 +1,32 @@ -from pydantic.v1 import BaseModel, root_validator, Field, PrivateAttr -from typing import Any, List, Dict, Optional, Union +import hashlib +from typing import Any, Dict, List, Optional +import redis.commands.search.reducers as reducers +from pydantic.v1 import BaseModel, Field, PrivateAttr from redis import Redis from redis.commands.search.aggregation import AggregateRequest, AggregateResult -import redis.commands.search.reducers as reducers +from redisvl.extensions.router.schema import Route, RoutingConfig, RouteMatch, DistanceAggregationMethod from redisvl.index import SearchIndex -from redisvl.query import VectorQuery, RangeQuery -from redisvl.schema import IndexSchema, IndexInfo +from redisvl.query import RangeQuery +from redisvl.redis.utils import convert_bytes, make_dict +from redisvl.schema import IndexInfo, IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -from redisvl.extensions.router.routes import Route, RoutingConfig, RouteSortingMethod - -from redisvl.redis.utils import make_dict, convert_bytes - -import hashlib class SemanticRouterIndexSchema(IndexSchema): + """Customized index schema for SemanticRouter.""" @classmethod - def from_params(cls, name: str, vector_dims: int): - """Load the semantic router index schema from the router name and - vector dimensionality. + def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema': + """Create an index schema based on router name and vector dimensions. + + Args: + name (str): The name of the index. + vector_dims (int): The dimensions of the vectors. + + Returns: + SemanticRouterIndexSchema: The constructed index schema. """ return cls( index=IndexInfo(name=name, prefix=name), @@ -35,22 +40,24 @@ def from_params(cls, name: str, vector_dims: int): "algorithm": "flat", "dims": vector_dims, "distance_metric": "cosine", - "datatype": "float32" - } - } - ] + "datatype": "float32", + }, + }, + ], ) class SemanticRouter(BaseModel): + """Semantic Router for managing and querying route vectors.""" + name: str - """The name of the semantic router""" + """The name of the semantic router.""" routes: List[Route] - """List of Route objects""" + """List of Route objects.""" vectorizer: BaseVectorizer = Field(default_factory=HFTextVectorizer) - """The vectorizer used to embed route references""" + """The vectorizer used to embed route references.""" routing_config: RoutingConfig = Field(default_factory=RoutingConfig) - """Configuration for routing behavior""" + """Configuration for routing behavior.""" _index: SearchIndex = PrivateAttr() @@ -64,9 +71,9 @@ def __init__( vectorizer: BaseVectorizer = HFTextVectorizer(), routing_config: RoutingConfig = RoutingConfig(), redis_client: Optional[Redis] = None, - redis_url: str = "redis://localhost:6379", + redis_url: Optional[str] = None, overwrite: bool = False, - **kwargs + **kwargs, ): """Initialize the SemanticRouter. @@ -76,30 +83,25 @@ def __init__( vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to HFTextVectorizer(). routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to RoutingConfig(). redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. - redis_url (str, optional): Redis URL for connection. Defaults to "redis://localhost:6379". + redis_url (Optional[str], optional): Redis URL for connection. Defaults to None. overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. **kwargs: Additional arguments. """ - super().__init__( - name=name, - routes=routes, - vectorizer=vectorizer, - routing_config=routing_config - ) + super().__init__(name=name, routes=routes, vectorizer=vectorizer, routing_config=routing_config) self._initialize_index(redis_client, redis_url, overwrite) def _initialize_index( self, redis_client: Optional[Redis] = None, - redis_url: str = "redis://localhost:6379", + redis_url: Optional[str] = None, overwrite: bool = False, - **connection_kwargs + **connection_kwargs, ): """Initialize the search index and handle Redis connection. Args: redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. - redis_url (str, optional): Redis URL for connection. Defaults to "redis://localhost:6379". + redis_url (Optional[str], optional): Redis URL for connection. Defaults to None. overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. **connection_kwargs: Additional connection arguments. """ @@ -115,8 +117,27 @@ def _initialize_index( self._index.create(overwrite=overwrite) if not existed or overwrite: + # write the routes to Redis self._add_routes(self.routes) + @property + def route_names(self) -> List[str]: + """Get the list of route names. + + Returns: + List[str]: List of route names. + """ + return [route.name for route in self.routes] + + @property + def route_thresholds(self) -> Dict[str, float]: + """Get the distance thresholds for each route. + + Returns: + Dict[str, float]: Dictionary of route names and their distance thresholds. + """ + return {route.name: route.distance_threshold for route in self.routes} + def update_routing_config(self, routing_config: RoutingConfig): """Update the routing configuration. @@ -124,10 +145,9 @@ def update_routing_config(self, routing_config: RoutingConfig): routing_config (RoutingConfig): The new routing configuration. """ self.routing_config = routing_config - # self._accumulation_method = self._pick_accumulation_method() def _add_routes(self, routes: List[Route]): - """Add routes to the index. + """Add routes to the router and index. Args: routes (List[Route]): List of routes to be added. @@ -136,88 +156,204 @@ def _add_routes(self, routes: List[Route]): keys: List[str] = [] for route in routes: + if route.distance_threshold is None: + route.distance_threshold = self.routing_config.distance_threshold + # set route reference for reference in route.references: - route_references.append({ - "route_name": route.name, - "reference": reference, - "vector": self.vectorizer.embed(reference, as_buffer=True) - }) + route_references.append( + { + "route_name": route.name, + "reference": reference, + "vector": self.vectorizer.embed(reference, as_buffer=True), + } + ) reference_hash = hashlib.sha256(reference.encode("utf-8")).hexdigest() keys.append(f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}") + # set route if does not yet exist client side + if not self.get(route.name): + self.routes.append(route) + self._index.load(route_references, keys=keys) - def __call__( + def get(self, route_name: str) -> Optional[Route]: + """Get a route by its name. + + Args: + route_name (str): Name of the route. + + Returns: + Optional[Route]: The selected Route object or None if not found. + """ + return next((route for route in self.routes if route.name == route_name), None) + + def _process_route(self, result: Dict[str, Any]) -> RouteMatch: + """Process resulting route objects and metadata. + + Args: + result (Dict[str, Any]): Aggregation query result object. + + Returns: + RouteMatch: Processed route match with route object and distance. + """ + route_dict = make_dict(convert_bytes(result)) + route = self.get(route_dict["route_name"]) + return RouteMatch(route=route, distance=float(route_dict["distance"])) + + def _build_aggregate_request( self, - statement: str, - max_k: Optional[int] = None, - distance_threshold: Optional[float] = None, - sort_by: Optional[str] = None - ) -> List[Dict[str, Any]]: - """Query the semantic router with a given statement. + vector_range_query: RangeQuery, + aggregation_method: DistanceAggregationMethod, + max_k: int + ) -> AggregateRequest: + """Build the Redis aggregation request. Args: - statement (str): The input statement to be queried. - max_k (Optional[int]): The maximum number of top matches to return. - distance_threshold (Optional[float]): The threshold for semantic distance. - sort_by (Optional[str]): The technique used to sort the final route matches before truncating. + vector_range_query (RangeQuery): The query vector. + aggregation_method (DistanceAggregationMethod): The aggregation method. + max_k (int): The maximum number of top matches to return. Returns: - List[Dict[str, Any]]: The matching routes and their details. + AggregateRequest: The constructed aggregation request. """ - vector = self.vectorizer.embed(statement) - max_k = max_k if max_k is not None else self.routing_config.max_k - distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold - sort_by = RouteSortingMethod(sort_by) if sort_by is not None else self.routing_config.sort_by + if aggregation_method == DistanceAggregationMethod.min: + aggregation_func = reducers.min + elif aggregation_method == DistanceAggregationMethod.sum: + aggregation_func = reducers.sum + else: + aggregation_func = reducers.avg - # define the baseline range query to fetch relevant route references + aggregate_query = str(vector_range_query).split(" RETURN")[0] + aggregate_request = ( + AggregateRequest(aggregate_query) + .group_by("@route_name", aggregation_func("vector_distance").alias("distance")) + .sort_by("@distance", max=max_k) + .dialect(2) + ) + + return aggregate_request + + def _classify( + self, + vector: List[float], + distance_threshold: float, + aggregation_method: DistanceAggregationMethod + ) -> List[RouteMatch]: + """Classify a single query vector. + + Args: + vector (List[float]): The query vector. + distance_threshold (float): The distance threshold. + aggregation_method (DistanceAggregationMethod): The aggregation method. + + Returns: + List[RouteMatch]: List of route matches. + """ vector_range_query = RangeQuery( vector=vector, vector_field_name="vector", - distance_threshold=2, - return_fields=["route_name"] + distance_threshold=distance_threshold, + return_fields=["route_name"], ) - # build redis aggregation query - aggregate_query = str(vector_range_query).split(" RETURN")[0] - aggregate_request = ( - AggregateRequest(aggregate_query) - .group_by( - "@route_name", - reducers.avg("vector_distance").alias("avg_distance"), - reducers.min("vector_distance").alias("min_distance") - ) - .dialect(2) - ) + aggregate_request = self._build_aggregate_request(vector_range_query, aggregation_method, max_k=1) + route_matches: AggregateResult = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params) + return [self._process_route(route_match) for route_match in route_matches.rows] + + def _classify_many( + self, + vector: List[float], + max_k: int, + distance_threshold: float, + aggregation_method: DistanceAggregationMethod + ) -> List[RouteMatch]: + """Classify multiple query vectors. - # run the aggregation query in Redis - aggregate_result: AggregateResult = ( - self._index.client - .ft(self._index.name) - .aggregate(aggregate_request, vector_range_query.params) + Args: + vector (List[float]): The query vector. + max_k (int): The maximum number of top matches to return. + distance_threshold (float): The distance threshold. + aggregation_method (DistanceAggregationMethod): The aggregation method. + + Returns: + List[RouteMatch]: List of route matches. + """ + vector_range_query = RangeQuery( + vector=vector, + vector_field_name="vector", + distance_threshold=distance_threshold, + return_fields=["route_name"], ) + aggregate_request = self._build_aggregate_request(vector_range_query, aggregation_method, max_k) + route_matches: AggregateResult = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params) + return [self._process_route(route_match) for route_match in route_matches.rows] + + def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool: + """Check if a route match passes the distance threshold. - top_routes_and_scores = sorted([ - self._process_result(result) for result in aggregate_result.rows - ], key=lambda r: r[sort_by.value]) + Args: + route_match (Optional[RouteMatch]): The route match to check. + + Returns: + bool: True if the route match passes the threshold, False otherwise. + """ + return route_match is not None and route_match.distance <= route_match.route.distance_threshold - return top_routes_and_scores[:max_k] + def __call__( + self, + statement: Optional[str] = None, + vector: Optional[List[float]] = None, + distance_threshold: Optional[float] = None, + ) -> RouteMatch: + """Query the semantic router with a given statement or vector. + Args: + statement (Optional[str]): The input statement to be queried. + vector (Optional[List[float]]): The input vector to be queried. + distance_threshold (Optional[float]): The threshold for semantic distance. - def _process_result(self, result: Dict[str, Any]) -> Dict[str, Any]: - """Process resulting route objects and metadata. + Returns: + RouteMatch: The matching route. + """ + if not vector: + if not statement: + raise ValueError("Must provide a vector or statement to the router") + vector = self.vectorizer.embed(statement) + + distance_threshold = distance_threshold or self.routing_config.distance_threshold + route_matches = self._classify(vector, distance_threshold, self.routing_config.aggregation_method) + route_match = route_matches[0] if route_matches else None + + if route_match and self._pass_threshold(route_match): + return route_match + + return RouteMatch() + + def route_many( + self, + statement: Optional[str] = None, + vector: Optional[List[float]] = None, + max_k: Optional[int] = None, + distance_threshold: Optional[float] = None, + ) -> List[RouteMatch]: + """Query the semantic router with a given statement or vector for multiple matches. Args: - result: Aggregation query result object + statement (Optional[str]): The input statement to be queried. + vector (Optional[List[float]]): The input vector to be queried. + max_k (Optional[int]): The maximum number of top matches to return. + distance_threshold (Optional[float]): The threshold for semantic distance. Returns: - List[Dict[str, Any]]: Routes with their metadata. + List[RouteMatch]: The matching routes and their details. """ - result_dict = make_dict(convert_bytes(result)) - route_name = result_dict["route_name"] - route = next((r for r in self.routes if r.name == route_name), None) - return { - **route.dict(), - "avg_distance": float(result_dict["avg_distance"]), - "min_distance": float(result_dict["min_distance"]) - } + if not vector: + if not statement: + raise ValueError("Must provide a vector or statement to the router") + vector = self.vectorizer.embed(statement) + + distance_threshold = distance_threshold or self.routing_config.distance_threshold + max_k = max_k or self.routing_config.max_k + route_matches = self._classify_many(vector, max_k, distance_threshold, self.routing_config.aggregation_method) + + return [route_match for route_match in route_matches if self._pass_threshold(route_match)] diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index dd39e5a7..48d2833a 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,63 +1,111 @@ import pytest -from redisvl.extensions.router.routes import Route, RoutingConfig -from redisvl.extensions.router.semantic import SemanticRouter + +from redisvl.extensions.router.schema import Route, RoutingConfig +from redisvl.extensions.router import SemanticRouter @pytest.fixture def routes(): - politics = Route( - name="politics", - references=[ - "isn't politics the best thing ever", - "why don't you tell me about your political opinions" - ], - metadata={"priority": "1"} - ) - chitchat = Route( - name="chitchat", - references=[ - "hello", - "how's the weather today?", - "how are things going?" - ], - metadata={"priority": "2"} - ) - return [politics, chitchat] + return [ + Route(name="greeting", references=["hello", "hi"], metadata={"type": "greeting"}, distance_threshold=0.3), + Route(name="farewell", references=["bye", "goodbye"], metadata={"type": "farewell"}, distance_threshold=0.3) + ] @pytest.fixture -def semantic_router(redis_client, routes): - config = RoutingConfig(distance_threshold=1.0) +def semantic_router(client, routes): router = SemanticRouter( - name="topic-router", + name="test-router", routes=routes, - routing_config=config, - redis_client=redis_client, - overwrite=True + routing_config=RoutingConfig(distance_threshold=0.3, max_k=2), + redis_client=client, + overwrite=False ) - return router + yield router + router._index.delete(drop=True) -def test_semantic_router_match_politics(semantic_router): - result = semantic_router("I am thinking about running for Governor in the state of VA. What do I need to consider?") - assert result[0]['route'].name == "politics" +def test_initialize_router(semantic_router): + assert semantic_router.name == "test-router" + assert len(semantic_router.routes) == 2 + assert semantic_router.routing_config.distance_threshold == 0.3 + assert semantic_router.routing_config.max_k == 2 -def test_semantic_router_match_chitchat(semantic_router): - result = semantic_router("hello") - assert result[0]['route'].name == "chitchat" +def test_router_properties(semantic_router): + route_names = semantic_router.route_names + assert "greeting" in route_names + assert "farewell" in route_names + thresholds = semantic_router.route_thresholds + assert thresholds["greeting"] == 0.3 + assert thresholds["farewell"] == 0.3 -def test_semantic_router_no_match(semantic_router): - result = semantic_router("unrelated topic") - assert result == [] +def test_get_route(semantic_router): + route = semantic_router.get("greeting") + assert route is not None + assert route.name == "greeting" + assert "hello" in route.references + + +def test_get_non_existing_route(semantic_router): + route = semantic_router.get("non_existent_route") + assert route is None + + +def test_single_query(semantic_router): + match = semantic_router("hello") + assert match.route is not None + assert match.route.name == "greeting" + assert match.distance <= semantic_router.route_thresholds["greeting"] + + +def test_single_query_no_match(semantic_router): + match = semantic_router("unknown_phrase") + assert match.route is None -def test_update_routing_config(semantic_router): - new_config = RoutingConfig(distance_threshold=0.1, sort_by='avg_distance') +def test_multiple_query(semantic_router): + matches = semantic_router.route_many("hello", max_k=2) + assert len(matches) > 0 + assert matches[0].route.name == "greeting" + +def test_update_routing_config(semantic_router): + new_config = RoutingConfig(distance_threshold=0.5, max_k=1) semantic_router.update_routing_config(new_config) - result = semantic_router("hello world") - assert result == [] + assert semantic_router.routing_config.distance_threshold == 0.5 + assert semantic_router.routing_config.max_k == 1 + + +def test_vector_query(semantic_router): + vector = semantic_router.vectorizer.embed("goodbye") + match = semantic_router(vector=vector) + assert match.route is not None + assert match.route.name == "farewell" + + +def test_vector_query_no_match(semantic_router): + vector = [0.0] * semantic_router.vectorizer.dims # Random vector unlikely to match any route + match = semantic_router(vector=vector) + assert match.route is None + + +def test_additional_route(semantic_router): + new_routes = [ + Route( + name="politics", + references=["are you liberal or conservative?", "who will you vote for?", "political speech"], + metadata={"type": "greeting"}, + ) + ] + semantic_router._add_routes(new_routes) + + route = semantic_router.get("politics") + assert route is not None + assert route.name == "politics" + assert "political speech" in route.references - result = semantic_router("hello world", distance_threshold=0.3) - assert len(result) > 0 + match = semantic_router("political speech") + print(match, flush=True) + assert match is not None + assert match.route.name == "politics" diff --git a/tests/unit/test_route_schema.py b/tests/unit/test_route_schema.py new file mode 100644 index 00000000..2a8eeef3 --- /dev/null +++ b/tests/unit/test_route_schema.py @@ -0,0 +1,124 @@ +import pytest +from pydantic.v1 import ValidationError +from redisvl.extensions.router.schema import Route, RouteMatch, DistanceAggregationMethod, RoutingConfig + +def test_route_valid(): + route = Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0.3 + ) + assert route.name == "Test Route" + assert route.references == ["reference1", "reference2"] + assert route.metadata == {"key": "value"} + assert route.distance_threshold == 0.3 + +def test_route_empty_name(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0.3 + ) + assert "Route name must not be empty" in str(excinfo.value) + +def test_route_empty_references(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=[], + metadata={"key": "value"}, + distance_threshold=0.3 + ) + assert "References must not be empty" in str(excinfo.value) + +def test_route_non_empty_references(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=["reference1", ""], + metadata={"key": "value"}, + distance_threshold=0.3 + ) + assert "All references must be non-empty strings" in str(excinfo.value) + +def test_route_valid_no_threshold(): + route = Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"} + ) + assert route.name == "Test Route" + assert route.references == ["reference1", "reference2"] + assert route.metadata == {"key": "value"} + assert route.distance_threshold is None + +def test_route_invalid_threshold_zero(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0 + ) + assert "Route distance threshold must be greater than zero" in str(excinfo.value) + +def test_route_invalid_threshold_negative(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=-0.1 + ) + assert "Route distance threshold must be greater than zero" in str(excinfo.value) + +def test_route_match(): + route = Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0.3 + ) + route_match = RouteMatch( + route=route, + distance=0.25 + ) + assert route_match.route == route + assert route_match.distance == 0.25 + +def test_route_match_no_route(): + route_match = RouteMatch() + assert route_match.route is None + assert route_match.distance is None + +def test_distance_aggregation_method(): + assert DistanceAggregationMethod.avg == DistanceAggregationMethod("avg") + assert DistanceAggregationMethod.min == DistanceAggregationMethod("min") + assert DistanceAggregationMethod.sum == DistanceAggregationMethod("sum") + +def test_routing_config_valid(): + config = RoutingConfig( + distance_threshold=0.6, + max_k=5 + ) + assert config.distance_threshold == 0.6 + assert config.max_k == 5 + +def test_routing_config_invalid_max_k(): + with pytest.raises(ValidationError) as excinfo: + RoutingConfig( + distance_threshold=0.6, + max_k=0 + ) + assert "max_k must be a positive integer" in str(excinfo.value) + +def test_routing_config_invalid_distance_threshold(): + with pytest.raises(ValidationError) as excinfo: + RoutingConfig( + distance_threshold=1.5, + max_k=5 + ) + assert "distance_threshold must be between 0 and 1" in str(excinfo.value) diff --git a/tests/unit/test_routes.py b/tests/unit/test_routes.py deleted file mode 100644 index 10e45e48..00000000 --- a/tests/unit/test_routes.py +++ /dev/null @@ -1,52 +0,0 @@ -import pytest -from pydantic.v1 import ValidationError -from redisvl.extensions.router.routes import Route, RoutingConfig, RouteSortingMethod - - -def test_route_creation(): - route = Route( - name="test_route", - references=["test reference 1", "test reference 2"], - metadata={"priority": "1"} - ) - assert route.name == "test_route" - assert route.references == ["test reference 1", "test reference 2"] - assert route.metadata == {"priority": "1"} - - -def test_route_name_empty(): - with pytest.raises(ValidationError): - Route(name="", references=["test reference"]) - - -def test_route_references_empty(): - with pytest.raises(ValidationError): - Route(name="test_route", references=[]) - - -def test_route_references_non_empty_strings(): - with pytest.raises(ValidationError): - Route(name="test_route", references=["", "test reference"]) - - -def test_routing_config_creation(): - config = RoutingConfig( - distance_threshold=0.5, - max_k=1, - sort_by=RouteSortingMethod.avg_distance - ) - assert config.distance_threshold == 0.5 - assert config.max_k == 1 - assert config.sort_by == RouteSortingMethod.avg_distance - - -def test_routing_config_invalid_max_k(): - with pytest.raises(ValidationError): - RoutingConfig(distance_threshold=0.5, max_k=0) - - -def test_routing_config_invalid_distance_threshold(): - with pytest.raises(ValidationError): - RoutingConfig(distance_threshold=-0.1, max_k=1) - with pytest.raises(ValidationError): - RoutingConfig(distance_threshold=1.1, max_k=1) From 45b41bebcf967b65ad48a2ceff3c928742aeaa76 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 12:35:34 -0400 Subject: [PATCH 08/16] formatting and linting updates --- redisvl/extensions/router/schema.py | 4 +- redisvl/extensions/router/semantic.py | 86 +++++++++++++++++------ redisvl/schema/schema.py | 2 +- tests/integration/test_semantic_router.py | 30 ++++++-- tests/unit/test_route_schema.py | 57 ++++++++------- 5 files changed, 123 insertions(+), 56 deletions(-) diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 58ac2d85..62d86b20 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -64,7 +64,9 @@ class RoutingConfig(BaseModel): """The threshold for semantic distance.""" max_k: int = Field(default=1) """The maximum number of top matches to return.""" - aggregation_method: DistanceAggregationMethod = Field(default=DistanceAggregationMethod.avg) + aggregation_method: DistanceAggregationMethod = Field( + default=DistanceAggregationMethod.avg + ) """Aggregation method to use to classify queries.""" @validator("max_k") diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 80835d5b..59e1eacf 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,12 +1,17 @@ import hashlib -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type import redis.commands.search.reducers as reducers from pydantic.v1 import BaseModel, Field, PrivateAttr from redis import Redis -from redis.commands.search.aggregation import AggregateRequest, AggregateResult - -from redisvl.extensions.router.schema import Route, RoutingConfig, RouteMatch, DistanceAggregationMethod +from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer + +from redisvl.extensions.router.schema import ( + DistanceAggregationMethod, + Route, + RouteMatch, + RoutingConfig, +) from redisvl.index import SearchIndex from redisvl.query import RangeQuery from redisvl.redis.utils import convert_bytes, make_dict @@ -18,7 +23,7 @@ class SemanticRouterIndexSchema(IndexSchema): """Customized index schema for SemanticRouter.""" @classmethod - def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema': + def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": """Create an index schema based on router name and vector dimensions. Args: @@ -30,7 +35,7 @@ def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema' """ return cls( index=IndexInfo(name=name, prefix=name), - fields=[ + fields=[ # type: ignore {"name": "route_name", "type": "tag"}, {"name": "reference", "type": "text"}, { @@ -87,7 +92,12 @@ def __init__( overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. **kwargs: Additional arguments. """ - super().__init__(name=name, routes=routes, vectorizer=vectorizer, routing_config=routing_config) + super().__init__( + name=name, + routes=routes, + vectorizer=vectorizer, + routing_config=routing_config, + ) self._initialize_index(redis_client, redis_url, overwrite) def _initialize_index( @@ -130,7 +140,7 @@ def route_names(self) -> List[str]: return [route.name for route in self.routes] @property - def route_thresholds(self) -> Dict[str, float]: + def route_thresholds(self) -> Dict[str, Optional[float]]: """Get the distance thresholds for each route. Returns: @@ -168,7 +178,9 @@ def _add_routes(self, routes: List[Route]): } ) reference_hash = hashlib.sha256(reference.encode("utf-8")).hexdigest() - keys.append(f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}") + keys.append( + f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}" + ) # set route if does not yet exist client side if not self.get(route.name): @@ -204,7 +216,7 @@ def _build_aggregate_request( self, vector_range_query: RangeQuery, aggregation_method: DistanceAggregationMethod, - max_k: int + max_k: int, ) -> AggregateRequest: """Build the Redis aggregation request. @@ -216,6 +228,8 @@ def _build_aggregate_request( Returns: AggregateRequest: The constructed aggregation request. """ + aggregation_func: Type[Reducer] + if aggregation_method == DistanceAggregationMethod.min: aggregation_func = reducers.min elif aggregation_method == DistanceAggregationMethod.sum: @@ -226,7 +240,9 @@ def _build_aggregate_request( aggregate_query = str(vector_range_query).split(" RETURN")[0] aggregate_request = ( AggregateRequest(aggregate_query) - .group_by("@route_name", aggregation_func("vector_distance").alias("distance")) + .group_by( + "@route_name", aggregation_func("vector_distance").alias("distance") + ) .sort_by("@distance", max=max_k) .dialect(2) ) @@ -237,7 +253,7 @@ def _classify( self, vector: List[float], distance_threshold: float, - aggregation_method: DistanceAggregationMethod + aggregation_method: DistanceAggregationMethod, ) -> List[RouteMatch]: """Classify a single query vector. @@ -256,8 +272,12 @@ def _classify( return_fields=["route_name"], ) - aggregate_request = self._build_aggregate_request(vector_range_query, aggregation_method, max_k=1) - route_matches: AggregateResult = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params) + aggregate_request = self._build_aggregate_request( + vector_range_query, aggregation_method, max_k=1 + ) + route_matches: AggregateResult = self._index.client.ft( # type: ignore + self._index.name + ).aggregate(aggregate_request, vector_range_query.params) return [self._process_route(route_match) for route_match in route_matches.rows] def _classify_many( @@ -265,7 +285,7 @@ def _classify_many( vector: List[float], max_k: int, distance_threshold: float, - aggregation_method: DistanceAggregationMethod + aggregation_method: DistanceAggregationMethod, ) -> List[RouteMatch]: """Classify multiple query vectors. @@ -284,8 +304,12 @@ def _classify_many( distance_threshold=distance_threshold, return_fields=["route_name"], ) - aggregate_request = self._build_aggregate_request(vector_range_query, aggregation_method, max_k) - route_matches: AggregateResult = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params) + aggregate_request = self._build_aggregate_request( + vector_range_query, aggregation_method, max_k + ) + route_matches: AggregateResult = self._index.client.ft( # type: ignore + self._index.name + ).aggregate(aggregate_request, vector_range_query.params) return [self._process_route(route_match) for route_match in route_matches.rows] def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool: @@ -297,7 +321,11 @@ def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool: Returns: bool: True if the route match passes the threshold, False otherwise. """ - return route_match is not None and route_match.distance <= route_match.route.distance_threshold + if route_match: + if route_match.distance is not None and route_match.route is not None: + if route_match.route.distance_threshold: + return route_match.distance <= route_match.route.distance_threshold + return False def __call__( self, @@ -320,8 +348,12 @@ def __call__( raise ValueError("Must provide a vector or statement to the router") vector = self.vectorizer.embed(statement) - distance_threshold = distance_threshold or self.routing_config.distance_threshold - route_matches = self._classify(vector, distance_threshold, self.routing_config.aggregation_method) + distance_threshold = ( + distance_threshold or self.routing_config.distance_threshold + ) + route_matches = self._classify( + vector, distance_threshold, self.routing_config.aggregation_method + ) route_match = route_matches[0] if route_matches else None if route_match and self._pass_threshold(route_match): @@ -352,8 +384,16 @@ def route_many( raise ValueError("Must provide a vector or statement to the router") vector = self.vectorizer.embed(statement) - distance_threshold = distance_threshold or self.routing_config.distance_threshold + distance_threshold = ( + distance_threshold or self.routing_config.distance_threshold + ) max_k = max_k or self.routing_config.max_k - route_matches = self._classify_many(vector, max_k, distance_threshold, self.routing_config.aggregation_method) + route_matches = self._classify_many( + vector, max_k, distance_threshold, self.routing_config.aggregation_method + ) - return [route_match for route_match in route_matches if self._pass_threshold(route_match)] + return [ + route_match + for route_match in route_matches + if self._pass_threshold(route_match) + ] diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index 9659c79a..ed9cffd4 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -195,7 +195,7 @@ def validate_and_create_fields(cls, values): """ Validate uniqueness of field names and create valid field instances. """ - # Ensure index is a dictionary for validation + # Ensure index is a dictionary for validation index = values.get("index") if not isinstance(index, IndexInfo): index = IndexInfo(**index) diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 48d2833a..5cf46e4d 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,16 +1,27 @@ import pytest -from redisvl.extensions.router.schema import Route, RoutingConfig from redisvl.extensions.router import SemanticRouter +from redisvl.extensions.router.schema import Route, RoutingConfig @pytest.fixture def routes(): return [ - Route(name="greeting", references=["hello", "hi"], metadata={"type": "greeting"}, distance_threshold=0.3), - Route(name="farewell", references=["bye", "goodbye"], metadata={"type": "farewell"}, distance_threshold=0.3) + Route( + name="greeting", + references=["hello", "hi"], + metadata={"type": "greeting"}, + distance_threshold=0.3, + ), + Route( + name="farewell", + references=["bye", "goodbye"], + metadata={"type": "farewell"}, + distance_threshold=0.3, + ), ] + @pytest.fixture def semantic_router(client, routes): router = SemanticRouter( @@ -18,7 +29,7 @@ def semantic_router(client, routes): routes=routes, routing_config=RoutingConfig(distance_threshold=0.3, max_k=2), redis_client=client, - overwrite=False + overwrite=False, ) yield router router._index.delete(drop=True) @@ -70,6 +81,7 @@ def test_multiple_query(semantic_router): assert len(matches) > 0 assert matches[0].route.name == "greeting" + def test_update_routing_config(semantic_router): new_config = RoutingConfig(distance_threshold=0.5, max_k=1) semantic_router.update_routing_config(new_config) @@ -85,7 +97,9 @@ def test_vector_query(semantic_router): def test_vector_query_no_match(semantic_router): - vector = [0.0] * semantic_router.vectorizer.dims # Random vector unlikely to match any route + vector = [ + 0.0 + ] * semantic_router.vectorizer.dims # Random vector unlikely to match any route match = semantic_router(vector=vector) assert match.route is None @@ -94,7 +108,11 @@ def test_additional_route(semantic_router): new_routes = [ Route( name="politics", - references=["are you liberal or conservative?", "who will you vote for?", "political speech"], + references=[ + "are you liberal or conservative?", + "who will you vote for?", + "political speech", + ], metadata={"type": "greeting"}, ) ] diff --git a/tests/unit/test_route_schema.py b/tests/unit/test_route_schema.py index 2a8eeef3..0d42451e 100644 --- a/tests/unit/test_route_schema.py +++ b/tests/unit/test_route_schema.py @@ -1,124 +1,131 @@ import pytest from pydantic.v1 import ValidationError -from redisvl.extensions.router.schema import Route, RouteMatch, DistanceAggregationMethod, RoutingConfig + +from redisvl.extensions.router.schema import ( + DistanceAggregationMethod, + Route, + RouteMatch, + RoutingConfig, +) + def test_route_valid(): route = Route( name="Test Route", references=["reference1", "reference2"], metadata={"key": "value"}, - distance_threshold=0.3 + distance_threshold=0.3, ) assert route.name == "Test Route" assert route.references == ["reference1", "reference2"] assert route.metadata == {"key": "value"} assert route.distance_threshold == 0.3 + def test_route_empty_name(): with pytest.raises(ValidationError) as excinfo: Route( name="", references=["reference1", "reference2"], metadata={"key": "value"}, - distance_threshold=0.3 + distance_threshold=0.3, ) assert "Route name must not be empty" in str(excinfo.value) + def test_route_empty_references(): with pytest.raises(ValidationError) as excinfo: Route( name="Test Route", references=[], metadata={"key": "value"}, - distance_threshold=0.3 + distance_threshold=0.3, ) assert "References must not be empty" in str(excinfo.value) + def test_route_non_empty_references(): with pytest.raises(ValidationError) as excinfo: Route( name="Test Route", references=["reference1", ""], metadata={"key": "value"}, - distance_threshold=0.3 + distance_threshold=0.3, ) assert "All references must be non-empty strings" in str(excinfo.value) + def test_route_valid_no_threshold(): route = Route( name="Test Route", references=["reference1", "reference2"], - metadata={"key": "value"} + metadata={"key": "value"}, ) assert route.name == "Test Route" assert route.references == ["reference1", "reference2"] assert route.metadata == {"key": "value"} assert route.distance_threshold is None + def test_route_invalid_threshold_zero(): with pytest.raises(ValidationError) as excinfo: Route( name="Test Route", references=["reference1", "reference2"], metadata={"key": "value"}, - distance_threshold=0 + distance_threshold=0, ) assert "Route distance threshold must be greater than zero" in str(excinfo.value) + def test_route_invalid_threshold_negative(): with pytest.raises(ValidationError) as excinfo: Route( name="Test Route", references=["reference1", "reference2"], metadata={"key": "value"}, - distance_threshold=-0.1 + distance_threshold=-0.1, ) assert "Route distance threshold must be greater than zero" in str(excinfo.value) + def test_route_match(): route = Route( name="Test Route", references=["reference1", "reference2"], metadata={"key": "value"}, - distance_threshold=0.3 - ) - route_match = RouteMatch( - route=route, - distance=0.25 + distance_threshold=0.3, ) + route_match = RouteMatch(route=route, distance=0.25) assert route_match.route == route assert route_match.distance == 0.25 + def test_route_match_no_route(): route_match = RouteMatch() assert route_match.route is None assert route_match.distance is None + def test_distance_aggregation_method(): assert DistanceAggregationMethod.avg == DistanceAggregationMethod("avg") assert DistanceAggregationMethod.min == DistanceAggregationMethod("min") assert DistanceAggregationMethod.sum == DistanceAggregationMethod("sum") + def test_routing_config_valid(): - config = RoutingConfig( - distance_threshold=0.6, - max_k=5 - ) + config = RoutingConfig(distance_threshold=0.6, max_k=5) assert config.distance_threshold == 0.6 assert config.max_k == 5 + def test_routing_config_invalid_max_k(): with pytest.raises(ValidationError) as excinfo: - RoutingConfig( - distance_threshold=0.6, - max_k=0 - ) + RoutingConfig(distance_threshold=0.6, max_k=0) assert "max_k must be a positive integer" in str(excinfo.value) + def test_routing_config_invalid_distance_threshold(): with pytest.raises(ValidationError) as excinfo: - RoutingConfig( - distance_threshold=1.5, - max_k=5 - ) + RoutingConfig(distance_threshold=1.5, max_k=5) assert "distance_threshold must be between 0 and 1" in str(excinfo.value) From f0360bc2f165c4127bf20060179ba8d47078b414 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 12:56:26 -0400 Subject: [PATCH 09/16] handle old versions of redis --- redisvl/extensions/router/semantic.py | 35 +++++++++++++++++------ redisvl/redis/connection.py | 29 +++++++++++++++++++ tests/integration/test_connection.py | 30 +------------------ tests/integration/test_semantic_router.py | 31 +++++++++++++++++--- 4 files changed, 84 insertions(+), 41 deletions(-) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 59e1eacf..62e4ec29 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -5,6 +5,7 @@ from pydantic.v1 import BaseModel, Field, PrivateAttr from redis import Redis from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer +from redis.exceptions import ResponseError from redisvl.extensions.router.schema import ( DistanceAggregationMethod, @@ -275,10 +276,19 @@ def _classify( aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k=1 ) - route_matches: AggregateResult = self._index.client.ft( # type: ignore - self._index.name - ).aggregate(aggregate_request, vector_range_query.params) - return [self._process_route(route_match) for route_match in route_matches.rows] + try: + route_matches: AggregateResult = self._index.client.ft( # type: ignore + self._index.name + ).aggregate(aggregate_request, vector_range_query.params) + return [ + self._process_route(route_match) for route_match in route_matches.rows + ] + except ResponseError as e: + if "VSS is not yet supported on FT.AGGREGATE" in str(e): + raise RuntimeError( + "Semantic routing is only available on Redis version 7.x.x or greater" + ) + raise e def _classify_many( self, @@ -307,10 +317,19 @@ def _classify_many( aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k ) - route_matches: AggregateResult = self._index.client.ft( # type: ignore - self._index.name - ).aggregate(aggregate_request, vector_range_query.params) - return [self._process_route(route_match) for route_match in route_matches.rows] + try: + route_matches: AggregateResult = self._index.client.ft( # type: ignore + self._index.name + ).aggregate(aggregate_request, vector_range_query.params) + return [ + self._process_route(route_match) for route_match in route_matches.rows + ] + except ResponseError as e: + if "VSS is not yet supported on FT.AGGREGATE" in str(e): + raise RuntimeError( + "Semantic routing is only available on Redis version 7.x.x or greater" + ) + raise e def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool: """Check if a route match passes the distance threshold. diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 3e949dd3..888cb127 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -18,6 +18,35 @@ from redisvl.version import __version__ +def compare_versions(version1, version2): + """ + Compare two Redis version strings numerically. + + Parameters: + version1 (str): The first version string (e.g., "7.2.4"). + version2 (str): The second version string (e.g., "6.2.1"). + + Returns: + int: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2. + """ + v1_parts = list(map(int, version1.split("."))) + v2_parts = list(map(int, version2.split("."))) + + for v1, v2 in zip(v1_parts, v2_parts): + if v1 < v2: + return False + elif v1 > v2: + return True + + # If the versions are equal so far, compare the lengths of the version parts + if len(v1_parts) < len(v2_parts): + return False + elif len(v1_parts) > len(v2_parts): + return True + + return True + + def unpack_redis_modules(module_list: List[Dict[str, Any]]) -> Dict[str, Any]: """Unpack a list of Redis modules pulled from the MODULES LIST command.""" return {module["name"]: module["ver"] for module in module_list} diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index 30f1c3f3..99608ddc 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -7,6 +7,7 @@ from redisvl.redis.connection import ( RedisConnectionFactory, + compare_versions, convert_index_info_to_schema, get_address_from_env, unpack_redis_modules, @@ -18,35 +19,6 @@ EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})" -def compare_versions(version1, version2): - """ - Compare two Redis version strings numerically. - - Parameters: - version1 (str): The first version string (e.g., "7.2.4"). - version2 (str): The second version string (e.g., "6.2.1"). - - Returns: - int: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2. - """ - v1_parts = list(map(int, version1.split("."))) - v2_parts = list(map(int, version2.split("."))) - - for v1, v2 in zip(v1_parts, v2_parts): - if v1 < v2: - return False - elif v1 > v2: - return True - - # If the versions are equal so far, compare the lengths of the version parts - if len(v1_parts) < len(v2_parts): - return False - elif len(v1_parts) > len(v2_parts): - return True - - return True - - def test_get_address_from_env(redis_url): assert get_address_from_env() == redis_url diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 5cf46e4d..3ade7e1c 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -2,6 +2,7 @@ from redisvl.extensions.router import SemanticRouter from redisvl.extensions.router.schema import Route, RoutingConfig +from redisvl.redis.connection import compare_versions @pytest.fixture @@ -65,6 +66,10 @@ def test_get_non_existing_route(semantic_router): def test_single_query(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + match = semantic_router("hello") assert match.route is not None assert match.route.name == "greeting" @@ -72,11 +77,19 @@ def test_single_query(semantic_router): def test_single_query_no_match(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + match = semantic_router("unknown_phrase") assert match.route is None def test_multiple_query(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + matches = semantic_router.route_many("hello", max_k=2) assert len(matches) > 0 assert matches[0].route.name == "greeting" @@ -90,6 +103,10 @@ def test_update_routing_config(semantic_router): def test_vector_query(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + vector = semantic_router.vectorizer.embed("goodbye") match = semantic_router(vector=vector) assert match.route is not None @@ -97,6 +114,10 @@ def test_vector_query(semantic_router): def test_vector_query_no_match(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + vector = [ 0.0 ] * semantic_router.vectorizer.dims # Random vector unlikely to match any route @@ -123,7 +144,9 @@ def test_additional_route(semantic_router): assert route.name == "politics" assert "political speech" in route.references - match = semantic_router("political speech") - print(match, flush=True) - assert match is not None - assert match.route.name == "politics" + redis_version = semantic_router._index.client.info()["redis_version"] + if compare_versions(redis_version, "7.0.0"): + match = semantic_router("political speech") + print(match, flush=True) + assert match is not None + assert match.route.name == "politics" From 9b0153d4d0ec8901231144cd711edcbf4abdb506 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 12:56:34 -0400 Subject: [PATCH 10/16] docs --- docs/api/cache.rst | 2 +- docs/api/index.md | 1 + docs/api/query.rst | 5 +++-- docs/api/router.rst | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 docs/api/router.rst diff --git a/docs/api/cache.rst b/docs/api/cache.rst index 7e34ee1f..6d5b72f7 100644 --- a/docs/api/cache.rst +++ b/docs/api/cache.rst @@ -1,6 +1,6 @@ ******** -LLMCache +LLM Cache ******** SemanticCache diff --git a/docs/api/index.md b/docs/api/index.md index b2c45ab6..14371397 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -18,5 +18,6 @@ filter vectorizer reranker cache +router ``` diff --git a/docs/api/query.rst b/docs/api/query.rst index 9146160a..a06d8bad 100644 --- a/docs/api/query.rst +++ b/docs/api/query.rst @@ -3,11 +3,12 @@ Query ***** +.. _query_api: + + VectorQuery =========== -.. _query_api: - .. currentmodule:: redisvl.query diff --git a/docs/api/router.rst b/docs/api/router.rst new file mode 100644 index 00000000..e13e8ab9 --- /dev/null +++ b/docs/api/router.rst @@ -0,0 +1,39 @@ + +******** +Semantic Router +******** + +.. _semantic_router_api: + + +Semantic Router +============= + +.. currentmodule:: redisvl.extensions.router + +.. autoclass:: SemanticRouter + :show-inheritance: + :members: + :inherited-members: + + +Routing Config +=============== + +.. currentmodule:: redisvl.extensions.router + +.. autoclass:: RoutingConfig + :show-inheritance: + :members: + :inherited-members: + + +Route +===== + +.. currentmodule:: redisvl.extensions.router + +.. autoclass:: Route + :show-inheritance: + :members: + :inherited-members: From e5d6b043e109affc57d3268bab6b58d5cc1c306b Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 13:39:04 -0400 Subject: [PATCH 11/16] update docs --- docs/_static/js/sidebar.js | 9 ++++++-- docs/api/cache.rst | 4 ++-- docs/api/router.rst | 30 ++++++++++++++++++--------- redisvl/extensions/router/semantic.py | 15 ++++++++++---- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/docs/_static/js/sidebar.js b/docs/_static/js/sidebar.js index 7a52eb54..3d78af59 100644 --- a/docs/_static/js/sidebar.js +++ b/docs/_static/js/sidebar.js @@ -9,7 +9,7 @@ const toc = [ { title: "Query and Filter", path: "/user_guide/hybrid_queries_02.html" }, { title: "JSON vs Hash Storage", path: "/user_guide/hash_vs_json_05.html" }, { title: "Vectorizers", path: "/user_guide/vectorizers_04.html" }, - { title: "Rerankers", path: "/user_guide/rerankers_06.html"}, + { title: "Rerankers", path: "/user_guide/rerankers_06.html" }, { title: "Semantic Caching", path: "/user_guide/llmcache_03.html" }, ]}, { header: "API", toc: [ @@ -17,9 +17,14 @@ const toc = [ { title: "Search Index", path: "/api/searchindex.html" }, { title: "Query", path: "/api/query.html" }, { title: "Filter", path: "/api/filter.html" }, + ]}, + { header: "Utils", toc: [ { title: "Vectorizers", path: "/api/vectorizer.html" }, { title: "Rerankers", path: "/api/reranker.html" }, - { title: "LLMCache", path: "/api/cache.html" } + ]}, + { header: "Extensions", toc: [ + { title: "LLM Cache", path: "/api/cache.html" }, + { title: "Semantic Router", path: "/api/router.html" }, ]} ]; diff --git a/docs/api/cache.rst b/docs/api/cache.rst index 6d5b72f7..9c921c60 100644 --- a/docs/api/cache.rst +++ b/docs/api/cache.rst @@ -1,7 +1,7 @@ -******** +********* LLM Cache -******** +********* SemanticCache ============= diff --git a/docs/api/router.rst b/docs/api/router.rst index e13e8ab9..fbaab435 100644 --- a/docs/api/router.rst +++ b/docs/api/router.rst @@ -1,31 +1,27 @@ -******** +*************** Semantic Router -******** +*************** .. _semantic_router_api: Semantic Router -============= +=============== .. currentmodule:: redisvl.extensions.router .. autoclass:: SemanticRouter - :show-inheritance: :members: - :inherited-members: Routing Config -=============== +============== .. currentmodule:: redisvl.extensions.router .. autoclass:: RoutingConfig - :show-inheritance: :members: - :inherited-members: Route @@ -34,6 +30,20 @@ Route .. currentmodule:: redisvl.extensions.router .. autoclass:: Route - :show-inheritance: :members: - :inherited-members: + +Route Match +=========== + +.. currentmodule:: redisvl.extensions.router.schema + +.. autoclass:: RouteMatch + :members: + +Distance Aggregation Method +=========================== + +.. currentmodule:: redisvl.extensions.router.schema + +.. autoclass:: DistanceAggregationMethod + :members: diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 62e4ec29..66de42d4 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -74,8 +74,8 @@ def __init__( self, name: str, routes: List[Route], - vectorizer: BaseVectorizer = HFTextVectorizer(), - routing_config: RoutingConfig = RoutingConfig(), + vectorizer: Optional[BaseVectorizer] = None, + routing_config: Optional[RoutingConfig] = None, redis_client: Optional[Redis] = None, redis_url: Optional[str] = None, overwrite: bool = False, @@ -86,13 +86,20 @@ def __init__( Args: name (str): The name of the semantic router. routes (List[Route]): List of Route objects. - vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to HFTextVectorizer(). - routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to RoutingConfig(). + vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer. + routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig. redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. redis_url (Optional[str], optional): Redis URL for connection. Defaults to None. overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. **kwargs: Additional arguments. """ + # Set vectorizer default + if vectorizer is None: + vectorizer = HFTextVectorizer() + + if routing_config is None: + routing_config = RoutingConfig() + super().__init__( name=name, routes=routes, From b6f6d699b4dc2a935e789e563bc0f5a870ab9f3d Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 14:12:08 -0400 Subject: [PATCH 12/16] wip on docs --- redisvl/extensions/router/semantic.py | 32 ++++++++++++++++++--------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 66de42d4..6edbde54 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -174,9 +174,7 @@ def _add_routes(self, routes: List[Route]): keys: List[str] = [] for route in routes: - if route.distance_threshold is None: - route.distance_threshold = self.routing_config.distance_threshold - # set route reference + # set route references for reference in route.references: route_references.append( { @@ -338,19 +336,21 @@ def _classify_many( ) raise e - def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool: + def _pass_threshold(self, route_match: Optional[RouteMatch], distance_threshold: float) -> bool: """Check if a route match passes the distance threshold. Args: route_match (Optional[RouteMatch]): The route match to check. + distance_threshold (float): The fallback distance threshold to use if not assigned to a route. Returns: bool: True if the route match passes the threshold, False otherwise. """ if route_match: if route_match.distance is not None and route_match.route is not None: - if route_match.route.distance_threshold: - return route_match.distance <= route_match.route.distance_threshold + _distance_threshold = route_match.route.distance_threshold or distance_threshold + if _distance_threshold: + return route_match.distance <= _distance_threshold return False def __call__( @@ -358,6 +358,7 @@ def __call__( statement: Optional[str] = None, vector: Optional[List[float]] = None, distance_threshold: Optional[float] = None, + aggregation_method: Optional[DistanceAggregationMethod] = None ) -> RouteMatch: """Query the semantic router with a given statement or vector. @@ -365,6 +366,7 @@ def __call__( statement (Optional[str]): The input statement to be queried. vector (Optional[List[float]]): The input vector to be queried. distance_threshold (Optional[float]): The threshold for semantic distance. + aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances. Returns: RouteMatch: The matching route. @@ -377,12 +379,13 @@ def __call__( distance_threshold = ( distance_threshold or self.routing_config.distance_threshold ) + aggregation_method = aggregation_method or self.routing_config.aggregation_method route_matches = self._classify( - vector, distance_threshold, self.routing_config.aggregation_method + vector, distance_threshold, aggregation_method ) route_match = route_matches[0] if route_matches else None - if route_match and self._pass_threshold(route_match): + if route_match and self._pass_threshold(route_match, distance_threshold): return route_match return RouteMatch() @@ -393,6 +396,7 @@ def route_many( vector: Optional[List[float]] = None, max_k: Optional[int] = None, distance_threshold: Optional[float] = None, + aggregation_method: Optional[DistanceAggregationMethod] = None ) -> List[RouteMatch]: """Query the semantic router with a given statement or vector for multiple matches. @@ -401,6 +405,7 @@ def route_many( vector (Optional[List[float]]): The input vector to be queried. max_k (Optional[int]): The maximum number of top matches to return. distance_threshold (Optional[float]): The threshold for semantic distance. + aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances. Returns: List[RouteMatch]: The matching routes and their details. @@ -414,12 +419,19 @@ def route_many( distance_threshold or self.routing_config.distance_threshold ) max_k = max_k or self.routing_config.max_k + aggregation_method = aggregation_method or self.routing_config.aggregation_method route_matches = self._classify_many( - vector, max_k, distance_threshold, self.routing_config.aggregation_method + vector, max_k, distance_threshold, aggregation_method ) return [ route_match for route_match in route_matches - if self._pass_threshold(route_match) + if self._pass_threshold(route_match, distance_threshold) ] + + def delete(self): + self._index.delete(drop=True) + + def clear(self): + self._index.clear() From 6ddd05bc4444ba58a5196042aa525d39ac96be22 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 14:12:18 -0400 Subject: [PATCH 13/16] wip on user guide --- docs/user_guide/semantic_router_08.ipynb | 398 +++++++++++++++++++++++ 1 file changed, 398 insertions(+) create mode 100644 docs/user_guide/semantic_router_08.ipynb diff --git a/docs/user_guide/semantic_router_08.ipynb b/docs/user_guide/semantic_router_08.ipynb new file mode 100644 index 00000000..f00c5428 --- /dev/null +++ b/docs/user_guide/semantic_router_08.ipynb @@ -0,0 +1,398 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Routing\n", + "\n", + "RedisVL provides a `SemanticRouter` interface to utilize Redis' built-in search & aggregation in order to perform\n", + "KNN-style classification over a set of `Route` references to determine the best match.\n", + "\n", + "This notebook will go over how to use Redis as a Semantic Router for your applications" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the Routes\n", + "\n", + "Below we define 3 different routes. One for `technology`, one for `sports`, and\n", + "another for `entertainment`. Now for this example, the goal here is\n", + "surely topic \"classification\". But you can create routes and references for\n", + "almost anything.\n", + "\n", + "Each route has a set of references that cover the \"semantic surface area\" of the\n", + "route. The incoming query from a user needs to be semantically similar to one or\n", + "more of the references in order to \"match\" on the route." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.extensions.router import Route\n", + "\n", + "\n", + "# Define routes for the semantic router\n", + "technology = Route(\n", + " name=\"technology\",\n", + " references=[\n", + " \"what are the latest advancements in AI?\",\n", + " \"tell me about the newest gadgets\",\n", + " \"what's trending in tech?\"\n", + " ],\n", + " metadata={\"category\": \"tech\", \"priority\": 1}\n", + ")\n", + "\n", + "sports = Route(\n", + " name=\"sports\",\n", + " references=[\n", + " \"who won the game last night?\",\n", + " \"tell me about the upcoming sports events\",\n", + " \"what's the latest in the world of sports?\",\n", + " \"sports\",\n", + " \"basketball and football\"\n", + " ],\n", + " metadata={\"category\": \"sports\", \"priority\": 2}\n", + ")\n", + "\n", + "entertainment = Route(\n", + " name=\"entertainment\",\n", + " references=[\n", + " \"what are the top movies right now?\",\n", + " \"who won the best actor award?\",\n", + " \"what's new in the entertainment industry?\"\n", + " ],\n", + " metadata={\"category\": \"entertainment\", \"priority\": 3}\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize the SemanticRouter\n", + "\n", + "``SemanticRouter`` will automatically create an index within Redis upon initialization for the route references. By default, it uses the `HFTextVectorizer` to \n", + "generate embeddings for each route reference." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14:09:10 redisvl.index.index INFO Index already exists, overwriting.\n" + ] + } + ], + "source": [ + "import os\n", + "from redisvl.extensions.router import SemanticRouter\n", + "from redisvl.utils.vectorize import HFTextVectorizer\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "# Initialize the SemanticRouter\n", + "router = SemanticRouter(\n", + " name=\"topic-router\",\n", + " vectorizer=HFTextVectorizer(),\n", + " routes=[technology, sports, entertainment],\n", + " redis_url=\"redis://localhost:6379\",\n", + " overwrite=True # Blow away any other routing index with this name\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "HFTextVectorizer(model='sentence-transformers/all-mpnet-base-v2', dims=768)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "router.vectorizer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Index Information:\n", + "╭──────────────┬────────────────┬──────────────────┬─────────────────┬────────────╮\n", + "│ Index Name │ Storage Type │ Prefixes │ Index Options │ Indexing │\n", + "├──────────────┼────────────────┼──────────────────┼─────────────────┼────────────┤\n", + "│ topic-router │ HASH │ ['topic-router'] │ [] │ 0 │\n", + "╰──────────────┴────────────────┴──────────────────┴─────────────────┴────────────╯\n", + "Index Fields:\n", + "╭────────────┬─────────────┬────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", + "├────────────┼─────────────┼────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ route_name │ route_name │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ reference │ reference │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ vector │ vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n", + "╰────────────┴─────────────┴────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" + ] + } + ], + "source": [ + "# look at the index specification created for the semantic router\n", + "!rvl index info -i topic-router" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple routing" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.119614183903)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Query the router with a statement\n", + "route_match = router(\"Can you tell me about the latest in artificial intelligence?\")\n", + "route_match" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.554210186005)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Toggle the runtime distance threshold\n", + "route_match = router(\"Which basketball team will win the NBA finals?\", distance_threshold=0.7)\n", + "route_match" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also route a statement to many routes and order them by distance:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.758580672741),\n", + " RouteMatch(route=Route(name='entertainment', references=['what are the top movies right now?', 'who won the best actor award?', \"what's new in the entertainment industry?\"], metadata={'category': 'entertainment', 'priority': '3'}, distance_threshold=None), distance=0.812423805396),\n", + " RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.884235262871)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Perform multi-class classification with route_many() -- toggle the max_k and the distance_threshold\n", + "route_matches = router.route_many(\"Lebron James\", distance_threshold=1.0, max_k=3)\n", + "route_matches" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.663254022598),\n", + " RouteMatch(route=Route(name='entertainment', references=['what are the top movies right now?', 'who won the best actor award?', \"what's new in the entertainment industry?\"], metadata={'category': 'entertainment', 'priority': '3'}, distance_threshold=None), distance=0.712985336781),\n", + " RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.832674443722)]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Toggle the aggregation method -- note the different distances in the result\n", + "from redisvl.extensions.router.schema import DistanceAggregationMethod\n", + "\n", + "route_matches = router.route_many(\"Lebron James\", aggregation_method=DistanceAggregationMethod.min, distance_threshold=1.0, max_k=3)\n", + "route_matches" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the different route match distances. This is because we used the `min` aggregation method instead of the default `avg` approach." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Update the routing config" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.extensions.router import RoutingConfig\n", + "\n", + "router.update_routing_config(\n", + " RoutingConfig(distance_threshold=1.0, aggregation_method=DistanceAggregationMethod.min, max_k=3)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.663254022598),\n", + " RouteMatch(route=Route(name='entertainment', references=['what are the top movies right now?', 'who won the best actor award?', \"what's new in the entertainment industry?\"], metadata={'category': 'entertainment', 'priority': '3'}, distance_threshold=None), distance=0.712985336781),\n", + " RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.832674443722)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "route_matches = router.route_many(\"Lebron James\")\n", + "route_matches" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up the router" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'SearchIndex' object has no attribute 'clear'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[11], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Use clear to flush all routes from the index\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mrouter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclear\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/AppliedAI/redis-vl-python/redisvl/extensions/router/semantic.py:437\u001b[0m, in \u001b[0;36mSemanticRouter.clear\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclear\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 437\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_index\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclear\u001b[49m()\n", + "\u001b[0;31mAttributeError\u001b[0m: 'SearchIndex' object has no attribute 'clear'" + ] + } + ], + "source": [ + "# Use clear to flush all routes from the index\n", + "router.clear()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use delete to clear the index and remove it completely\n", + "router.delete()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rvl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 336b9d5860d9a2eae2e1c7fd86167f5b7242abfb Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 14:17:42 -0400 Subject: [PATCH 14/16] tweaks --- docs/_static/js/sidebar.js | 1 + docs/user_guide/index.md | 1 + docs/user_guide/semantic_router_08.ipynb | 29 +++-------------------- redisvl/extensions/router/semantic.py | 22 ++++++++++------- tests/integration/test_semantic_router.py | 2 +- 5 files changed, 20 insertions(+), 35 deletions(-) diff --git a/docs/_static/js/sidebar.js b/docs/_static/js/sidebar.js index 3d78af59..00b49dd5 100644 --- a/docs/_static/js/sidebar.js +++ b/docs/_static/js/sidebar.js @@ -11,6 +11,7 @@ const toc = [ { title: "Vectorizers", path: "/user_guide/vectorizers_04.html" }, { title: "Rerankers", path: "/user_guide/rerankers_06.html" }, { title: "Semantic Caching", path: "/user_guide/llmcache_03.html" }, + { title: "Semantic Routing", path: "/user_guide/semantic_router_08.html" }, ]}, { header: "API", toc: [ { title: "Schema", path: "/api/schema.html"}, diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 6f0d0b5f..9664a9a3 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -18,5 +18,6 @@ vectorizers_04 hash_vs_json_05 rerankers_06 session_manager_07 +semantic_router_08 ``` diff --git a/docs/user_guide/semantic_router_08.ipynb b/docs/user_guide/semantic_router_08.ipynb index f00c5428..5716ea63 100644 --- a/docs/user_guide/semantic_router_08.ipynb +++ b/docs/user_guide/semantic_router_08.ipynb @@ -90,7 +90,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "14:09:10 redisvl.index.index INFO Index already exists, overwriting.\n" + "14:13:26 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], @@ -136,16 +136,6 @@ "execution_count": 4, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -343,20 +333,7 @@ "cell_type": "code", "execution_count": 11, "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'SearchIndex' object has no attribute 'clear'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[11], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Use clear to flush all routes from the index\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mrouter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclear\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/AppliedAI/redis-vl-python/redisvl/extensions/router/semantic.py:437\u001b[0m, in \u001b[0;36mSemanticRouter.clear\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclear\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 437\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_index\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclear\u001b[49m()\n", - "\u001b[0;31mAttributeError\u001b[0m: 'SearchIndex' object has no attribute 'clear'" - ] - } - ], + "outputs": [], "source": [ "# Use clear to flush all routes from the index\n", "router.clear()" @@ -364,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 6edbde54..f0b89848 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -336,7 +336,9 @@ def _classify_many( ) raise e - def _pass_threshold(self, route_match: Optional[RouteMatch], distance_threshold: float) -> bool: + def _pass_threshold( + self, route_match: Optional[RouteMatch], distance_threshold: float + ) -> bool: """Check if a route match passes the distance threshold. Args: @@ -348,7 +350,9 @@ def _pass_threshold(self, route_match: Optional[RouteMatch], distance_threshold: """ if route_match: if route_match.distance is not None and route_match.route is not None: - _distance_threshold = route_match.route.distance_threshold or distance_threshold + _distance_threshold = ( + route_match.route.distance_threshold or distance_threshold + ) if _distance_threshold: return route_match.distance <= _distance_threshold return False @@ -358,7 +362,7 @@ def __call__( statement: Optional[str] = None, vector: Optional[List[float]] = None, distance_threshold: Optional[float] = None, - aggregation_method: Optional[DistanceAggregationMethod] = None + aggregation_method: Optional[DistanceAggregationMethod] = None, ) -> RouteMatch: """Query the semantic router with a given statement or vector. @@ -379,10 +383,10 @@ def __call__( distance_threshold = ( distance_threshold or self.routing_config.distance_threshold ) - aggregation_method = aggregation_method or self.routing_config.aggregation_method - route_matches = self._classify( - vector, distance_threshold, aggregation_method + aggregation_method = ( + aggregation_method or self.routing_config.aggregation_method ) + route_matches = self._classify(vector, distance_threshold, aggregation_method) route_match = route_matches[0] if route_matches else None if route_match and self._pass_threshold(route_match, distance_threshold): @@ -396,7 +400,7 @@ def route_many( vector: Optional[List[float]] = None, max_k: Optional[int] = None, distance_threshold: Optional[float] = None, - aggregation_method: Optional[DistanceAggregationMethod] = None + aggregation_method: Optional[DistanceAggregationMethod] = None, ) -> List[RouteMatch]: """Query the semantic router with a given statement or vector for multiple matches. @@ -419,7 +423,9 @@ def route_many( distance_threshold or self.routing_config.distance_threshold ) max_k = max_k or self.routing_config.max_k - aggregation_method = aggregation_method or self.routing_config.aggregation_method + aggregation_method = ( + aggregation_method or self.routing_config.aggregation_method + ) route_matches = self._classify_many( vector, max_k, distance_threshold, aggregation_method ) diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 3ade7e1c..75947698 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -33,7 +33,7 @@ def semantic_router(client, routes): overwrite=False, ) yield router - router._index.delete(drop=True) + router.delete() def test_initialize_router(semantic_router): From 4692aeecf75630e1be92f54c091af48107207a88 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 17 Jul 2024 17:00:36 -0400 Subject: [PATCH 15/16] updates to router after some feedback --- docs/api/router.rst | 2 + docs/user_guide/semantic_router_08.ipynb | 72 ++++++----- redisvl/extensions/llmcache/base.py | 5 +- redisvl/extensions/router/schema.py | 6 +- redisvl/extensions/router/semantic.py | 150 +++++++++++++++------- redisvl/redis/utils.py | 6 + tests/integration/test_semantic_router.py | 24 ++-- tests/unit/test_route_schema.py | 12 +- 8 files changed, 181 insertions(+), 96 deletions(-) diff --git a/docs/api/router.rst b/docs/api/router.rst index fbaab435..191d3e0f 100644 --- a/docs/api/router.rst +++ b/docs/api/router.rst @@ -32,6 +32,7 @@ Route .. autoclass:: Route :members: + Route Match =========== @@ -40,6 +41,7 @@ Route Match .. autoclass:: RouteMatch :members: + Distance Aggregation Method =========================== diff --git a/docs/user_guide/semantic_router_08.ipynb b/docs/user_guide/semantic_router_08.ipynb index 5716ea63..0fc01beb 100644 --- a/docs/user_guide/semantic_router_08.ipynb +++ b/docs/user_guide/semantic_router_08.ipynb @@ -85,15 +85,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "14:13:26 redisvl.index.index INFO Index already exists, overwriting.\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "from redisvl.extensions.router import SemanticRouter\n", @@ -179,7 +171,7 @@ { "data": { "text/plain": [ - "RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.119614183903)" + "RouteMatch(name='technology', distance=0.119614183903)" ] }, "execution_count": 5, @@ -201,7 +193,7 @@ { "data": { "text/plain": [ - "RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.554210186005)" + "RouteMatch(name=None, distance=None)" ] }, "execution_count": 6, @@ -209,6 +201,28 @@ "output_type": "execute_result" } ], + "source": [ + "# Query the router with a statement and return a miss\n", + "route_match = router(\"are aliens real?\")\n", + "route_match" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteMatch(name='sports', distance=0.554210186005)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Toggle the runtime distance threshold\n", "route_match = router(\"Which basketball team will win the NBA finals?\", distance_threshold=0.7)\n", @@ -224,18 +238,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.758580672741),\n", - " RouteMatch(route=Route(name='entertainment', references=['what are the top movies right now?', 'who won the best actor award?', \"what's new in the entertainment industry?\"], metadata={'category': 'entertainment', 'priority': '3'}, distance_threshold=None), distance=0.812423805396),\n", - " RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.884235262871)]" + "[RouteMatch(name='sports', distance=0.758580672741),\n", + " RouteMatch(name='entertainment', distance=0.812423805396),\n", + " RouteMatch(name='technology', distance=0.884235262871)]" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -248,18 +262,18 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.663254022598),\n", - " RouteMatch(route=Route(name='entertainment', references=['what are the top movies right now?', 'who won the best actor award?', \"what's new in the entertainment industry?\"], metadata={'category': 'entertainment', 'priority': '3'}, distance_threshold=None), distance=0.712985336781),\n", - " RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.832674443722)]" + "[RouteMatch(name='sports', distance=0.663254022598),\n", + " RouteMatch(name='entertainment', distance=0.712985336781),\n", + " RouteMatch(name='technology', distance=0.832674443722)]" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -288,7 +302,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -301,18 +315,18 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[RouteMatch(route=Route(name='sports', references=['who won the game last night?', 'tell me about the upcoming sports events', \"what's the latest in the world of sports?\", 'sports', 'basketball and football'], metadata={'category': 'sports', 'priority': '2'}, distance_threshold=None), distance=0.663254022598),\n", - " RouteMatch(route=Route(name='entertainment', references=['what are the top movies right now?', 'who won the best actor award?', \"what's new in the entertainment industry?\"], metadata={'category': 'entertainment', 'priority': '3'}, distance_threshold=None), distance=0.712985336781),\n", - " RouteMatch(route=Route(name='technology', references=['what are the latest advancements in AI?', 'tell me about the newest gadgets', \"what's trending in tech?\"], metadata={'category': 'tech', 'priority': '1'}, distance_threshold=None), distance=0.832674443722)]" + "[RouteMatch(name='sports', distance=0.663254022598),\n", + " RouteMatch(name='entertainment', distance=0.712985336781),\n", + " RouteMatch(name='technology', distance=0.832674443722)]" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -331,7 +345,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -341,7 +355,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index 5b02b0f5..b0c92e7c 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,7 +1,8 @@ -import hashlib import json from typing import Any, Dict, List, Optional +from redisvl.redis.utils import hashify + class BaseLLMCache: def __init__(self, ttl: Optional[int] = None): @@ -54,7 +55,7 @@ def store( def hash_input(self, prompt: str): """Hashes the input using SHA256.""" - return hashlib.sha256(prompt.encode("utf-8")).hexdigest() + return hashify(prompt) def serialize(self, metadata: Dict[str, Any]) -> str: """Serlize the input into a string.""" diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 62d86b20..c01f9254 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -40,10 +40,10 @@ def distance_threshold_must_be_positive(cls, v): class RouteMatch(BaseModel): """Model representing a matched route with distance information.""" - route: Optional[Route] = None - """The matched route.""" + name: Optional[str] = None + """The matched route name.""" distance: Optional[float] = Field(default=None) - """The distance of the match.""" + """The vector distance between the statement and the matched route.""" class DistanceAggregationMethod(Enum): diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index f0b89848..a9487dc9 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -15,10 +15,13 @@ ) from redisvl.index import SearchIndex from redisvl.query import RangeQuery -from redisvl.redis.utils import convert_bytes, make_dict +from redisvl.redis.utils import convert_bytes, hashify, make_dict from redisvl.schema import IndexInfo, IndexSchema +from redisvl.utils.log import get_logger from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +logger = get_logger(__name__) + class SemanticRouterIndexSchema(IndexSchema): """Customized index schema for SemanticRouter.""" @@ -164,6 +167,10 @@ def update_routing_config(self, routing_config: RoutingConfig): """ self.routing_config = routing_config + def _route_ref_key(self, route_name: str, reference: str) -> str: + reference_hash = hashify(reference) + return f"{self._index.prefix}:{route_name}:{reference_hash}" + def _add_routes(self, routes: List[Route]): """Add routes to the router and index. @@ -183,10 +190,7 @@ def _add_routes(self, routes: List[Route]): "vector": self.vectorizer.embed(reference, as_buffer=True), } ) - reference_hash = hashlib.sha256(reference.encode("utf-8")).hexdigest() - keys.append( - f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}" - ) + keys.append(self._route_ref_key(route.name, reference)) # set route if does not yet exist client side if not self.get(route.name): @@ -212,11 +216,12 @@ def _process_route(self, result: Dict[str, Any]) -> RouteMatch: result (Dict[str, Any]): Aggregation query result object. Returns: - RouteMatch: Processed route match with route object and distance. + RouteMatch: Processed route match with route name and distance. """ route_dict = make_dict(convert_bytes(result)) - route = self.get(route_dict["route_name"]) - return RouteMatch(route=route, distance=float(route_dict["distance"])) + return RouteMatch( + name=route_dict["route_name"], distance=float(route_dict["distance"]) + ) def _build_aggregate_request( self, @@ -255,13 +260,13 @@ def _build_aggregate_request( return aggregate_request - def _classify( + def _classify_route( self, vector: List[float], distance_threshold: float, aggregation_method: DistanceAggregationMethod, - ) -> List[RouteMatch]: - """Classify a single query vector. + ) -> RouteMatch: + """Classify to a single route using a vector. Args: vector (List[float]): The query vector. @@ -269,7 +274,7 @@ def _classify( aggregation_method (DistanceAggregationMethod): The aggregation method. Returns: - List[RouteMatch]: List of route matches. + RouteMatch: Top matching route. """ vector_range_query = RangeQuery( vector=vector, @@ -281,13 +286,11 @@ def _classify( aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k=1 ) + try: - route_matches: AggregateResult = self._index.client.ft( # type: ignore + aggregation_result: AggregateResult = self._index.client.ft( # type: ignore self._index.name ).aggregate(aggregate_request, vector_range_query.params) - return [ - self._process_route(route_match) for route_match in route_matches.rows - ] except ResponseError as e: if "VSS is not yet supported on FT.AGGREGATE" in str(e): raise RuntimeError( @@ -295,14 +298,36 @@ def _classify( ) raise e - def _classify_many( + # process aggregation results into route matches + route_matches = [ + self._process_route(route_match) for route_match in aggregation_result.rows + ] + + # process route matches + if route_matches: + top_route_match = route_matches[0] + if top_route_match.name is not None: + if route := self.get(top_route_match.name): + # use the matched route's distance threshold + _distance_threshold = route.distance_threshold or distance_threshold + if self._pass_threshold(top_route_match, _distance_threshold): + return top_route_match + else: + raise ValueError( + f"{top_route_match.name} not a supported route for the {self.name} semantic router." + ) + + # fallback to empty route match if no hits + return RouteMatch() + + def _classify_multi_route( self, vector: List[float], max_k: int, distance_threshold: float, aggregation_method: DistanceAggregationMethod, ) -> List[RouteMatch]: - """Classify multiple query vectors. + """Classify to multiple routes, up to max_k (int), using a vector. Args: vector (List[float]): The query vector. @@ -311,7 +336,7 @@ def _classify_many( aggregation_method (DistanceAggregationMethod): The aggregation method. Returns: - List[RouteMatch]: List of route matches. + RouteMatch: Top matching route. """ vector_range_query = RangeQuery( vector=vector, @@ -322,13 +347,11 @@ def _classify_many( aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k ) + try: - route_matches: AggregateResult = self._index.client.ft( # type: ignore + aggregation_result: AggregateResult = self._index.client.ft( # type: ignore self._index.name ).aggregate(aggregate_request, vector_range_query.params) - return [ - self._process_route(route_match) for route_match in route_matches.rows - ] except ResponseError as e: if "VSS is not yet supported on FT.AGGREGATE" in str(e): raise RuntimeError( @@ -336,6 +359,30 @@ def _classify_many( ) raise e + # process aggregation results into route matches + route_matches = [ + self._process_route(route_match) for route_match in aggregation_result.rows + ] + + # process route matches + top_route_matches: List[RouteMatch] = [] + if route_matches: + for route_match in route_matches: + if route_match.name is not None: + if route := self.get(route_match.name): + # use the matched route's distance threshold + _distance_threshold = ( + route.distance_threshold or distance_threshold + ) + if self._pass_threshold(route_match, _distance_threshold): + top_route_matches.append(route_match) + else: + raise ValueError( + f"{route_match.name} not a supported route for the {self.name} semantic router." + ) + + return top_route_matches + def _pass_threshold( self, route_match: Optional[RouteMatch], distance_threshold: float ) -> bool: @@ -348,13 +395,9 @@ def _pass_threshold( Returns: bool: True if the route match passes the threshold, False otherwise. """ - if route_match: - if route_match.distance is not None and route_match.route is not None: - _distance_threshold = ( - route_match.route.distance_threshold or distance_threshold - ) - if _distance_threshold: - return route_match.distance <= _distance_threshold + if route_match and distance_threshold: + if route_match.distance is not None: + return route_match.distance <= distance_threshold return False def __call__( @@ -380,19 +423,19 @@ def __call__( raise ValueError("Must provide a vector or statement to the router") vector = self.vectorizer.embed(statement) + # override routing config distance_threshold = ( distance_threshold or self.routing_config.distance_threshold ) aggregation_method = ( aggregation_method or self.routing_config.aggregation_method ) - route_matches = self._classify(vector, distance_threshold, aggregation_method) - route_match = route_matches[0] if route_matches else None - - if route_match and self._pass_threshold(route_match, distance_threshold): - return route_match - return RouteMatch() + # perform route classification + top_route_match = self._classify_route( + vector, distance_threshold, aggregation_method + ) + return top_route_match def route_many( self, @@ -419,6 +462,7 @@ def route_many( raise ValueError("Must provide a vector or statement to the router") vector = self.vectorizer.embed(statement) + # override routing config defaults distance_threshold = ( distance_threshold or self.routing_config.distance_threshold ) @@ -426,18 +470,36 @@ def route_many( aggregation_method = ( aggregation_method or self.routing_config.aggregation_method ) - route_matches = self._classify_many( + + # classify routes + top_route_matches = self._classify_multi_route( vector, max_k, distance_threshold, aggregation_method ) + return top_route_matches - return [ - route_match - for route_match in route_matches - if self._pass_threshold(route_match, distance_threshold) - ] + def remove_route(self, route_name: str) -> None: + """Remove a route and all references from the semantic router. + + Args: + route_name (str): Name of the route to remove. + """ + route = self.get(route_name) + if route is None: + logger.warning(f"Route {route_name} is not found in the SemanticRouter") + else: + self._index.drop_keys( + [ + self._route_ref_key(route.name, reference) + for reference in route.references + ] + ) + self.routes = [route for route in self.routes if route.name != route_name] - def delete(self): + def delete(self) -> None: + """Delete the semantic router index.""" self._index.delete(drop=True) - def clear(self): + def clear(self) -> None: + """Flush all routes from the semantic router index.""" self._index.clear() + self.routes = [] diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 29ad5f78..a421022b 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -1,3 +1,4 @@ +import hashlib from typing import Any, Dict, List import numpy as np @@ -37,3 +38,8 @@ def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: """Convert bytes into into a list of floats.""" return np.frombuffer(buffer, dtype=dtype).tolist() + + +def hashify(content: str) -> str: + """Create a secure hash of some arbitrary input text.""" + return hashlib.sha256(content.encode("utf-8")).hexdigest() diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 75947698..e8bb2b38 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -71,8 +71,7 @@ def test_single_query(semantic_router): pytest.skip("Not using a late enough version of Redis") match = semantic_router("hello") - assert match.route is not None - assert match.route.name == "greeting" + assert match.name == "greeting" assert match.distance <= semantic_router.route_thresholds["greeting"] @@ -82,7 +81,7 @@ def test_single_query_no_match(semantic_router): pytest.skip("Not using a late enough version of Redis") match = semantic_router("unknown_phrase") - assert match.route is None + assert match.name is None def test_multiple_query(semantic_router): @@ -92,7 +91,7 @@ def test_multiple_query(semantic_router): matches = semantic_router.route_many("hello", max_k=2) assert len(matches) > 0 - assert matches[0].route.name == "greeting" + assert matches[0].name == "greeting" def test_update_routing_config(semantic_router): @@ -109,8 +108,7 @@ def test_vector_query(semantic_router): vector = semantic_router.vectorizer.embed("goodbye") match = semantic_router(vector=vector) - assert match.route is not None - assert match.route.name == "farewell" + assert match.name == "farewell" def test_vector_query_no_match(semantic_router): @@ -122,10 +120,10 @@ def test_vector_query_no_match(semantic_router): 0.0 ] * semantic_router.vectorizer.dims # Random vector unlikely to match any route match = semantic_router(vector=vector) - assert match.route is None + assert match.name is None -def test_additional_route(semantic_router): +def test_add_route(semantic_router): new_routes = [ Route( name="politics", @@ -149,4 +147,12 @@ def test_additional_route(semantic_router): match = semantic_router("political speech") print(match, flush=True) assert match is not None - assert match.route.name == "politics" + assert match.name == "politics" + + +def test_remove_routes(semantic_router): + semantic_router.remove_route("greeting") + assert semantic_router.get("greeting") is None + + semantic_router.remove_route("unknown_route") + assert semantic_router.get("unknown_route") is None diff --git a/tests/unit/test_route_schema.py b/tests/unit/test_route_schema.py index 0d42451e..f1ad5cb5 100644 --- a/tests/unit/test_route_schema.py +++ b/tests/unit/test_route_schema.py @@ -90,20 +90,14 @@ def test_route_invalid_threshold_negative(): def test_route_match(): - route = Route( - name="Test Route", - references=["reference1", "reference2"], - metadata={"key": "value"}, - distance_threshold=0.3, - ) - route_match = RouteMatch(route=route, distance=0.25) - assert route_match.route == route + route_match = RouteMatch(name="test", distance=0.25) + assert route_match.name == "test" assert route_match.distance == 0.25 def test_route_match_no_route(): route_match = RouteMatch() - assert route_match.route is None + assert route_match.name is None assert route_match.distance is None From b070c2437982ac1a6395e0bd289e40bcc9c44fc6 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 18 Jul 2024 09:14:36 -0400 Subject: [PATCH 16/16] cleanup --- redisvl/extensions/router/semantic.py | 53 +++------------------------ redisvl/index/index.py | 11 ++---- 2 files changed, 9 insertions(+), 55 deletions(-) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index a9487dc9..93621036 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -118,14 +118,7 @@ def _initialize_index( overwrite: bool = False, **connection_kwargs, ): - """Initialize the search index and handle Redis connection. - - Args: - redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. - redis_url (Optional[str], optional): Redis URL for connection. Defaults to None. - overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. - **connection_kwargs: Additional connection arguments. - """ + """Initialize the search index and handle Redis connection.""" schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) self._index = SearchIndex(schema=schema) @@ -168,6 +161,7 @@ def update_routing_config(self, routing_config: RoutingConfig): self.routing_config = routing_config def _route_ref_key(self, route_name: str, reference: str) -> str: + """Generate the route reference key.""" reference_hash = hashify(reference) return f"{self._index.prefix}:{route_name}:{reference_hash}" @@ -210,14 +204,7 @@ def get(self, route_name: str) -> Optional[Route]: return next((route for route in self.routes if route.name == route_name), None) def _process_route(self, result: Dict[str, Any]) -> RouteMatch: - """Process resulting route objects and metadata. - - Args: - result (Dict[str, Any]): Aggregation query result object. - - Returns: - RouteMatch: Processed route match with route name and distance. - """ + """Process resulting route objects and metadata.""" route_dict = make_dict(convert_bytes(result)) return RouteMatch( name=route_dict["route_name"], distance=float(route_dict["distance"]) @@ -229,16 +216,7 @@ def _build_aggregate_request( aggregation_method: DistanceAggregationMethod, max_k: int, ) -> AggregateRequest: - """Build the Redis aggregation request. - - Args: - vector_range_query (RangeQuery): The query vector. - aggregation_method (DistanceAggregationMethod): The aggregation method. - max_k (int): The maximum number of top matches to return. - - Returns: - AggregateRequest: The constructed aggregation request. - """ + """Build the Redis aggregation request.""" aggregation_func: Type[Reducer] if aggregation_method == DistanceAggregationMethod.min: @@ -266,16 +244,7 @@ def _classify_route( distance_threshold: float, aggregation_method: DistanceAggregationMethod, ) -> RouteMatch: - """Classify to a single route using a vector. - - Args: - vector (List[float]): The query vector. - distance_threshold (float): The distance threshold. - aggregation_method (DistanceAggregationMethod): The aggregation method. - - Returns: - RouteMatch: Top matching route. - """ + """Classify to a single route using a vector.""" vector_range_query = RangeQuery( vector=vector, vector_field_name="vector", @@ -327,17 +296,7 @@ def _classify_multi_route( distance_threshold: float, aggregation_method: DistanceAggregationMethod, ) -> List[RouteMatch]: - """Classify to multiple routes, up to max_k (int), using a vector. - - Args: - vector (List[float]): The query vector. - max_k (int): The maximum number of top matches to return. - distance_threshold (float): The distance threshold. - aggregation_method (DistanceAggregationMethod): The aggregation method. - - Returns: - RouteMatch: Top matching route. - """ + """Classify to multiple routes, up to max_k (int), using a vector.""" vector_range_query = RangeQuery( vector=vector, vector_field_name="vector", diff --git a/redisvl/index/index.py b/redisvl/index/index.py index a479f8ed..b70fcb6a 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -22,7 +22,7 @@ import redis.asyncio as aredis from redis.commands.search.indexDefinition import IndexDefinition -from redisvl.index.storage import HashStorage, JsonStorage +from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage from redisvl.query import BaseQuery, CountQuery, FilterQuery from redisvl.query.filter import FilterExpression from redisvl.redis.connection import ( @@ -176,14 +176,9 @@ def __init__( elif redis_url is not None: self.connect(redis_url, **connection_args) - # set up index storage layer - # self._storage = self._STORAGE_MAP[self.schema.index.storage_type]( - # prefix=self.schema.index.prefix, - # key_separator=self.schema.index.key_separator, - # ) - @property - def _storage(self): + def _storage(self) -> BaseStorage: + """The storage type for the index schema.""" return self._STORAGE_MAP[self.schema.index.storage_type]( prefix=self.schema.index.prefix, key_separator=self.schema.index.key_separator,