diff --git a/docs/_static/js/sidebar.js b/docs/_static/js/sidebar.js index 7a52eb54..00b49dd5 100644 --- a/docs/_static/js/sidebar.js +++ b/docs/_static/js/sidebar.js @@ -9,17 +9,23 @@ const toc = [ { title: "Query and Filter", path: "/user_guide/hybrid_queries_02.html" }, { title: "JSON vs Hash Storage", path: "/user_guide/hash_vs_json_05.html" }, { title: "Vectorizers", path: "/user_guide/vectorizers_04.html" }, - { title: "Rerankers", path: "/user_guide/rerankers_06.html"}, + { title: "Rerankers", path: "/user_guide/rerankers_06.html" }, { title: "Semantic Caching", path: "/user_guide/llmcache_03.html" }, + { title: "Semantic Routing", path: "/user_guide/semantic_router_08.html" }, ]}, { header: "API", toc: [ { title: "Schema", path: "/api/schema.html"}, { title: "Search Index", path: "/api/searchindex.html" }, { title: "Query", path: "/api/query.html" }, { title: "Filter", path: "/api/filter.html" }, + ]}, + { header: "Utils", toc: [ { title: "Vectorizers", path: "/api/vectorizer.html" }, { title: "Rerankers", path: "/api/reranker.html" }, - { title: "LLMCache", path: "/api/cache.html" } + ]}, + { header: "Extensions", toc: [ + { title: "LLM Cache", path: "/api/cache.html" }, + { title: "Semantic Router", path: "/api/router.html" }, ]} ]; diff --git a/docs/api/cache.rst b/docs/api/cache.rst index 7e34ee1f..9c921c60 100644 --- a/docs/api/cache.rst +++ b/docs/api/cache.rst @@ -1,7 +1,7 @@ -******** -LLMCache -******** +********* +LLM Cache +********* SemanticCache ============= diff --git a/docs/api/index.md b/docs/api/index.md index b2c45ab6..14371397 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -18,5 +18,6 @@ filter vectorizer reranker cache +router ``` diff --git a/docs/api/query.rst b/docs/api/query.rst index 9146160a..a06d8bad 100644 --- a/docs/api/query.rst +++ b/docs/api/query.rst @@ -3,11 +3,12 @@ Query ***** +.. _query_api: + + VectorQuery =========== -.. _query_api: - .. 
currentmodule:: redisvl.query diff --git a/docs/api/router.rst b/docs/api/router.rst new file mode 100644 index 00000000..191d3e0f --- /dev/null +++ b/docs/api/router.rst @@ -0,0 +1,51 @@ + +*************** +Semantic Router +*************** + +.. _semantic_router_api: + + +Semantic Router +=============== + +.. currentmodule:: redisvl.extensions.router + +.. autoclass:: SemanticRouter + :members: + + +Routing Config +============== + +.. currentmodule:: redisvl.extensions.router + +.. autoclass:: RoutingConfig + :members: + + +Route +===== + +.. currentmodule:: redisvl.extensions.router + +.. autoclass:: Route + :members: + + +Route Match +=========== + +.. currentmodule:: redisvl.extensions.router.schema + +.. autoclass:: RouteMatch + :members: + + +Distance Aggregation Method +=========================== + +.. currentmodule:: redisvl.extensions.router.schema + +.. autoclass:: DistanceAggregationMethod + :members: diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 6f0d0b5f..9664a9a3 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -18,5 +18,6 @@ vectorizers_04 hash_vs_json_05 rerankers_06 session_manager_07 +semantic_router_08 ``` diff --git a/docs/user_guide/semantic_router_08.ipynb b/docs/user_guide/semantic_router_08.ipynb new file mode 100644 index 00000000..0fc01beb --- /dev/null +++ b/docs/user_guide/semantic_router_08.ipynb @@ -0,0 +1,389 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Routing\n", + "\n", + "RedisVL provides a `SemanticRouter` interface to utilize Redis' built-in search & aggregation in order to perform\n", + "KNN-style classification over a set of `Route` references to determine the best match.\n", + "\n", + "This notebook will go over how to use Redis as a Semantic Router for your applications" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the Routes\n", + "\n", + "Below we define 3 different routes. 
One for `technology`, one for `sports`, and\n", + "another for `entertainment`. In this example, the goal is\n", + "topic \"classification\", but you can create routes and references for\n", + "almost anything.\n", + "\n", + "Each route has a set of references that cover the \"semantic surface area\" of the\n", + "route. The incoming query from a user needs to be semantically similar to one or\n", + "more of the references in order to \"match\" on the route." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.extensions.router import Route\n", + "\n", + "\n", + "# Define routes for the semantic router\n", + "technology = Route(\n", + " name=\"technology\",\n", + " references=[\n", + " \"what are the latest advancements in AI?\",\n", + " \"tell me about the newest gadgets\",\n", + " \"what's trending in tech?\"\n", + " ],\n", + " metadata={\"category\": \"tech\", \"priority\": 1}\n", + ")\n", + "\n", + "sports = Route(\n", + " name=\"sports\",\n", + " references=[\n", + " \"who won the game last night?\",\n", + " \"tell me about the upcoming sports events\",\n", + " \"what's the latest in the world of sports?\",\n", + " \"sports\",\n", + " \"basketball and football\"\n", + " ],\n", + " metadata={\"category\": \"sports\", \"priority\": 2}\n", + ")\n", + "\n", + "entertainment = Route(\n", + " name=\"entertainment\",\n", + " references=[\n", + " \"what are the top movies right now?\",\n", + " \"who won the best actor award?\",\n", + " \"what's new in the entertainment industry?\"\n", + " ],\n", + " metadata={\"category\": \"entertainment\", \"priority\": 3}\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize the SemanticRouter\n", + "\n", + "``SemanticRouter`` will automatically create an index within Redis upon initialization for the route references.
By default, it uses the `HFTextVectorizer` to \n", + "generate embeddings for each route reference." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from redisvl.extensions.router import SemanticRouter\n", + "from redisvl.utils.vectorize import HFTextVectorizer\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "# Initialize the SemanticRouter\n", + "router = SemanticRouter(\n", + " name=\"topic-router\",\n", + " vectorizer=HFTextVectorizer(),\n", + " routes=[technology, sports, entertainment],\n", + " redis_url=\"redis://localhost:6379\",\n", + " overwrite=True # Blow away any other routing index with this name\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "HFTextVectorizer(model='sentence-transformers/all-mpnet-base-v2', dims=768)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "router.vectorizer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Index Information:\n", + "╭──────────────┬────────────────┬──────────────────┬─────────────────┬────────────╮\n", + "│ Index Name │ Storage Type │ Prefixes │ Index Options │ Indexing │\n", + "├──────────────┼────────────────┼──────────────────┼─────────────────┼────────────┤\n", + "│ topic-router │ HASH │ ['topic-router'] │ [] │ 0 │\n", + "╰──────────────┴────────────────┴──────────────────┴─────────────────┴────────────╯\n", + "Index Fields:\n", + "╭────────────┬─────────────┬────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option 
Value │ Field Option │ Option Value │\n", + "├────────────┼─────────────┼────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ route_name │ route_name │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ reference │ reference │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ vector │ vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n", + "╰────────────┴─────────────┴────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" + ] + } + ], + "source": [ + "# look at the index specification created for the semantic router\n", + "!rvl index info -i topic-router" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple routing" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteMatch(name='technology', distance=0.119614183903)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Query the router with a statement\n", + "route_match = router(\"Can you tell me about the latest in artificial intelligence?\")\n", + "route_match" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteMatch(name=None, distance=None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Query the router with a statement and return a miss\n", + "route_match = router(\"are aliens real?\")\n", + "route_match" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RouteMatch(name='sports', distance=0.554210186005)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "# Toggle the runtime distance threshold\n", + "route_match = router(\"Which basketball team will win the NBA finals?\", distance_threshold=0.7)\n", + "route_match" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also route a statement to many routes and order them by distance:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[RouteMatch(name='sports', distance=0.758580672741),\n", + " RouteMatch(name='entertainment', distance=0.812423805396),\n", + " RouteMatch(name='technology', distance=0.884235262871)]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Perform multi-class classification with route_many() -- toggle the max_k and the distance_threshold\n", + "route_matches = router.route_many(\"Lebron James\", distance_threshold=1.0, max_k=3)\n", + "route_matches" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[RouteMatch(name='sports', distance=0.663254022598),\n", + " RouteMatch(name='entertainment', distance=0.712985336781),\n", + " RouteMatch(name='technology', distance=0.832674443722)]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Toggle the aggregation method -- note the different distances in the result\n", + "from redisvl.extensions.router.schema import DistanceAggregationMethod\n", + "\n", + "route_matches = router.route_many(\"Lebron James\", aggregation_method=DistanceAggregationMethod.min, distance_threshold=1.0, max_k=3)\n", + "route_matches" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the different route match distances. This is because we used the `min` aggregation method instead of the default `avg` approach." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Update the routing config" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.extensions.router import RoutingConfig\n", + "\n", + "router.update_routing_config(\n", + " RoutingConfig(distance_threshold=1.0, aggregation_method=DistanceAggregationMethod.min, max_k=3)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[RouteMatch(name='sports', distance=0.663254022598),\n", + " RouteMatch(name='entertainment', distance=0.712985336781),\n", + " RouteMatch(name='technology', distance=0.832674443722)]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "route_matches = router.route_many(\"Lebron James\")\n", + "route_matches" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up the router" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Use clear to flush all routes from the index\n", + "router.clear()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Use delete to clear the index and remove it completely\n", + "router.delete()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "rvl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index 5b02b0f5..b0c92e7c 100644 --- 
a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,7 +1,8 @@ -import hashlib import json from typing import Any, Dict, List, Optional +from redisvl.redis.utils import hashify + class BaseLLMCache: def __init__(self, ttl: Optional[int] = None): @@ -54,7 +55,7 @@ def store( def hash_input(self, prompt: str): """Hashes the input using SHA256.""" - return hashlib.sha256(prompt.encode("utf-8")).hexdigest() + return hashify(prompt) def serialize(self, metadata: Dict[str, Any]) -> str: """Serlize the input into a string.""" diff --git a/redisvl/extensions/router/__init__.py b/redisvl/extensions/router/__init__.py new file mode 100644 index 00000000..25249f3b --- /dev/null +++ b/redisvl/extensions/router/__init__.py @@ -0,0 +1,4 @@ +from redisvl.extensions.router.schema import Route, RoutingConfig +from redisvl.extensions.router.semantic import SemanticRouter + +__all__ = ["SemanticRouter", "Route", "RoutingConfig"] diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py new file mode 100644 index 00000000..c01f9254 --- /dev/null +++ b/redisvl/extensions/router/schema.py @@ -0,0 +1,82 @@ +from enum import Enum +from typing import Dict, List, Optional + +from pydantic.v1 import BaseModel, Field, validator + + +class Route(BaseModel): + """Model representing a routing path with associated metadata and thresholds.""" + + name: str + """The name of the route.""" + references: List[str] + """List of reference phrases for the route.""" + metadata: Dict[str, str] = Field(default={}) + """Metadata associated with the route.""" + distance_threshold: Optional[float] = Field(default=None) + """Distance threshold for matching the route.""" + + @validator("name") + def name_must_not_be_empty(cls, v): + if not v or not v.strip(): + raise ValueError("Route name must not be empty") + return v + + @validator("references") + def references_must_not_be_empty(cls, v): + if not v: + raise ValueError("References must not be 
empty") + if any(not ref.strip() for ref in v): + raise ValueError("All references must be non-empty strings") + return v + + @validator("distance_threshold") + def distance_threshold_must_be_positive(cls, v): + if v is not None and v <= 0: + raise ValueError("Route distance threshold must be greater than zero") + return v + + +class RouteMatch(BaseModel): + """Model representing a matched route with distance information.""" + + name: Optional[str] = None + """The matched route name.""" + distance: Optional[float] = Field(default=None) + """The vector distance between the statement and the matched route.""" + + +class DistanceAggregationMethod(Enum): + """Enumeration for distance aggregation methods.""" + + avg = "avg" + """Compute the average of the vector distances.""" + min = "min" + """Compute the minimum of the vector distances.""" + sum = "sum" + """Compute the sum of the vector distances.""" + + +class RoutingConfig(BaseModel): + """Configuration for routing behavior.""" + + distance_threshold: float = Field(default=0.5) + """The threshold for semantic distance.""" + max_k: int = Field(default=1) + """The maximum number of top matches to return.""" + aggregation_method: DistanceAggregationMethod = Field( + default=DistanceAggregationMethod.avg + ) + """Aggregation method to use to classify queries.""" + + @validator("max_k") + def max_k_must_be_positive(cls, v): + if v <= 0: + raise ValueError("max_k must be a positive integer") + return v + + @validator("distance_threshold") + def distance_threshold_must_be_valid(cls, v): + if v <= 0 or v > 1: + raise ValueError("distance_threshold must be between 0 and 1") + return v diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py new file mode 100644 index 00000000..93621036 --- /dev/null +++ b/redisvl/extensions/router/semantic.py @@ -0,0 +1,464 @@ +import hashlib +from typing import Any, Dict, List, Optional, Type + +import redis.commands.search.reducers as reducers +from 
pydantic.v1 import BaseModel, Field, PrivateAttr +from redis import Redis +from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer +from redis.exceptions import ResponseError + +from redisvl.extensions.router.schema import ( + DistanceAggregationMethod, + Route, + RouteMatch, + RoutingConfig, +) +from redisvl.index import SearchIndex +from redisvl.query import RangeQuery +from redisvl.redis.utils import convert_bytes, hashify, make_dict +from redisvl.schema import IndexInfo, IndexSchema +from redisvl.utils.log import get_logger +from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer + +logger = get_logger(__name__) + + +class SemanticRouterIndexSchema(IndexSchema): + """Customized index schema for SemanticRouter.""" + + @classmethod + def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": + """Create an index schema based on router name and vector dimensions. + + Args: + name (str): The name of the index. + vector_dims (int): The dimensions of the vectors. + + Returns: + SemanticRouterIndexSchema: The constructed index schema. 
+ """ + return cls( + index=IndexInfo(name=name, prefix=name), + fields=[ # type: ignore + {"name": "route_name", "type": "tag"}, + {"name": "reference", "type": "text"}, + { + "name": "vector", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": vector_dims, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + ) + + +class SemanticRouter(BaseModel): + """Semantic Router for managing and querying route vectors.""" + + name: str + """The name of the semantic router.""" + routes: List[Route] + """List of Route objects.""" + vectorizer: BaseVectorizer = Field(default_factory=HFTextVectorizer) + """The vectorizer used to embed route references.""" + routing_config: RoutingConfig = Field(default_factory=RoutingConfig) + """Configuration for routing behavior.""" + + _index: SearchIndex = PrivateAttr() + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + name: str, + routes: List[Route], + vectorizer: Optional[BaseVectorizer] = None, + routing_config: Optional[RoutingConfig] = None, + redis_client: Optional[Redis] = None, + redis_url: Optional[str] = None, + overwrite: bool = False, + **kwargs, + ): + """Initialize the SemanticRouter. + + Args: + name (str): The name of the semantic router. + routes (List[Route]): List of Route objects. + vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer. + routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig. + redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None. + redis_url (Optional[str], optional): Redis URL for connection. Defaults to None. + overwrite (bool, optional): Whether to overwrite existing index. Defaults to False. + **kwargs: Additional arguments. 
+ """ + # Set vectorizer default + if vectorizer is None: + vectorizer = HFTextVectorizer() + + if routing_config is None: + routing_config = RoutingConfig() + + super().__init__( + name=name, + routes=routes, + vectorizer=vectorizer, + routing_config=routing_config, + ) + self._initialize_index(redis_client, redis_url, overwrite) + + def _initialize_index( + self, + redis_client: Optional[Redis] = None, + redis_url: Optional[str] = None, + overwrite: bool = False, + **connection_kwargs, + ): + """Initialize the search index and handle Redis connection.""" + schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) + self._index = SearchIndex(schema=schema) + + if redis_client: + self._index.set_client(redis_client) + else: + self._index.connect(redis_url=redis_url, **connection_kwargs) + + existed = self._index.exists() + self._index.create(overwrite=overwrite) + + if not existed or overwrite: + # write the routes to Redis + self._add_routes(self.routes) + + @property + def route_names(self) -> List[str]: + """Get the list of route names. + + Returns: + List[str]: List of route names. + """ + return [route.name for route in self.routes] + + @property + def route_thresholds(self) -> Dict[str, Optional[float]]: + """Get the distance thresholds for each route. + + Returns: + Dict[str, float]: Dictionary of route names and their distance thresholds. + """ + return {route.name: route.distance_threshold for route in self.routes} + + def update_routing_config(self, routing_config: RoutingConfig): + """Update the routing configuration. + + Args: + routing_config (RoutingConfig): The new routing configuration. + """ + self.routing_config = routing_config + + def _route_ref_key(self, route_name: str, reference: str) -> str: + """Generate the route reference key.""" + reference_hash = hashify(reference) + return f"{self._index.prefix}:{route_name}:{reference_hash}" + + def _add_routes(self, routes: List[Route]): + """Add routes to the router and index. 
+ + Args: + routes (List[Route]): List of routes to be added. + """ + route_references: List[Dict[str, Any]] = [] + keys: List[str] = [] + + for route in routes: + # set route references + for reference in route.references: + route_references.append( + { + "route_name": route.name, + "reference": reference, + "vector": self.vectorizer.embed(reference, as_buffer=True), + } + ) + keys.append(self._route_ref_key(route.name, reference)) + + # set route if does not yet exist client side + if not self.get(route.name): + self.routes.append(route) + + self._index.load(route_references, keys=keys) + + def get(self, route_name: str) -> Optional[Route]: + """Get a route by its name. + + Args: + route_name (str): Name of the route. + + Returns: + Optional[Route]: The selected Route object or None if not found. + """ + return next((route for route in self.routes if route.name == route_name), None) + + def _process_route(self, result: Dict[str, Any]) -> RouteMatch: + """Process resulting route objects and metadata.""" + route_dict = make_dict(convert_bytes(result)) + return RouteMatch( + name=route_dict["route_name"], distance=float(route_dict["distance"]) + ) + + def _build_aggregate_request( + self, + vector_range_query: RangeQuery, + aggregation_method: DistanceAggregationMethod, + max_k: int, + ) -> AggregateRequest: + """Build the Redis aggregation request.""" + aggregation_func: Type[Reducer] + + if aggregation_method == DistanceAggregationMethod.min: + aggregation_func = reducers.min + elif aggregation_method == DistanceAggregationMethod.sum: + aggregation_func = reducers.sum + else: + aggregation_func = reducers.avg + + aggregate_query = str(vector_range_query).split(" RETURN")[0] + aggregate_request = ( + AggregateRequest(aggregate_query) + .group_by( + "@route_name", aggregation_func("vector_distance").alias("distance") + ) + .sort_by("@distance", max=max_k) + .dialect(2) + ) + + return aggregate_request + + def _classify_route( + self, + vector: List[float], + 
distance_threshold: float, + aggregation_method: DistanceAggregationMethod, + ) -> RouteMatch: + """Classify to a single route using a vector.""" + vector_range_query = RangeQuery( + vector=vector, + vector_field_name="vector", + distance_threshold=distance_threshold, + return_fields=["route_name"], + ) + + aggregate_request = self._build_aggregate_request( + vector_range_query, aggregation_method, max_k=1 + ) + + try: + aggregation_result: AggregateResult = self._index.client.ft( # type: ignore + self._index.name + ).aggregate(aggregate_request, vector_range_query.params) + except ResponseError as e: + if "VSS is not yet supported on FT.AGGREGATE" in str(e): + raise RuntimeError( + "Semantic routing is only available on Redis version 7.x.x or greater" + ) + raise e + + # process aggregation results into route matches + route_matches = [ + self._process_route(route_match) for route_match in aggregation_result.rows + ] + + # process route matches + if route_matches: + top_route_match = route_matches[0] + if top_route_match.name is not None: + if route := self.get(top_route_match.name): + # use the matched route's distance threshold + _distance_threshold = route.distance_threshold or distance_threshold + if self._pass_threshold(top_route_match, _distance_threshold): + return top_route_match + else: + raise ValueError( + f"{top_route_match.name} not a supported route for the {self.name} semantic router." 
+ ) + + # fallback to empty route match if no hits + return RouteMatch() + + def _classify_multi_route( + self, + vector: List[float], + max_k: int, + distance_threshold: float, + aggregation_method: DistanceAggregationMethod, + ) -> List[RouteMatch]: + """Classify to multiple routes, up to max_k (int), using a vector.""" + vector_range_query = RangeQuery( + vector=vector, + vector_field_name="vector", + distance_threshold=distance_threshold, + return_fields=["route_name"], + ) + aggregate_request = self._build_aggregate_request( + vector_range_query, aggregation_method, max_k + ) + + try: + aggregation_result: AggregateResult = self._index.client.ft( # type: ignore + self._index.name + ).aggregate(aggregate_request, vector_range_query.params) + except ResponseError as e: + if "VSS is not yet supported on FT.AGGREGATE" in str(e): + raise RuntimeError( + "Semantic routing is only available on Redis version 7.x.x or greater" + ) + raise e + + # process aggregation results into route matches + route_matches = [ + self._process_route(route_match) for route_match in aggregation_result.rows + ] + + # process route matches + top_route_matches: List[RouteMatch] = [] + if route_matches: + for route_match in route_matches: + if route_match.name is not None: + if route := self.get(route_match.name): + # use the matched route's distance threshold + _distance_threshold = ( + route.distance_threshold or distance_threshold + ) + if self._pass_threshold(route_match, _distance_threshold): + top_route_matches.append(route_match) + else: + raise ValueError( + f"{route_match.name} not a supported route for the {self.name} semantic router." + ) + + return top_route_matches + + def _pass_threshold( + self, route_match: Optional[RouteMatch], distance_threshold: float + ) -> bool: + """Check if a route match passes the distance threshold. + + Args: + route_match (Optional[RouteMatch]): The route match to check. 
+ distance_threshold (float): The fallback distance threshold to use if not assigned to a route. + + Returns: + bool: True if the route match passes the threshold, False otherwise. + """ + if route_match and distance_threshold: + if route_match.distance is not None: + return route_match.distance <= distance_threshold + return False + + def __call__( + self, + statement: Optional[str] = None, + vector: Optional[List[float]] = None, + distance_threshold: Optional[float] = None, + aggregation_method: Optional[DistanceAggregationMethod] = None, + ) -> RouteMatch: + """Query the semantic router with a given statement or vector. + + Args: + statement (Optional[str]): The input statement to be queried. + vector (Optional[List[float]]): The input vector to be queried. + distance_threshold (Optional[float]): The threshold for semantic distance. + aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances. + + Returns: + RouteMatch: The matching route. + """ + if not vector: + if not statement: + raise ValueError("Must provide a vector or statement to the router") + vector = self.vectorizer.embed(statement) + + # override routing config + distance_threshold = ( + distance_threshold or self.routing_config.distance_threshold + ) + aggregation_method = ( + aggregation_method or self.routing_config.aggregation_method + ) + + # perform route classification + top_route_match = self._classify_route( + vector, distance_threshold, aggregation_method + ) + return top_route_match + + def route_many( + self, + statement: Optional[str] = None, + vector: Optional[List[float]] = None, + max_k: Optional[int] = None, + distance_threshold: Optional[float] = None, + aggregation_method: Optional[DistanceAggregationMethod] = None, + ) -> List[RouteMatch]: + """Query the semantic router with a given statement or vector for multiple matches. + + Args: + statement (Optional[str]): The input statement to be queried. 
+ vector (Optional[List[float]]): The input vector to be queried. + max_k (Optional[int]): The maximum number of top matches to return. + distance_threshold (Optional[float]): The threshold for semantic distance. + aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances. + + Returns: + List[RouteMatch]: The matching routes and their details. + """ + if not vector: + if not statement: + raise ValueError("Must provide a vector or statement to the router") + vector = self.vectorizer.embed(statement) + + # override routing config defaults + distance_threshold = ( + distance_threshold or self.routing_config.distance_threshold + ) + max_k = max_k or self.routing_config.max_k + aggregation_method = ( + aggregation_method or self.routing_config.aggregation_method + ) + + # classify routes + top_route_matches = self._classify_multi_route( + vector, max_k, distance_threshold, aggregation_method + ) + return top_route_matches + + def remove_route(self, route_name: str) -> None: + """Remove a route and all references from the semantic router. + + Args: + route_name (str): Name of the route to remove. 
+ """ + route = self.get(route_name) + if route is None: + logger.warning(f"Route {route_name} is not found in the SemanticRouter") + else: + self._index.drop_keys( + [ + self._route_ref_key(route.name, reference) + for reference in route.references + ] + ) + self.routes = [route for route in self.routes if route.name != route_name] + + def delete(self) -> None: + """Delete the semantic router index.""" + self._index.delete(drop=True) + + def clear(self) -> None: + """Flush all routes from the semantic router index.""" + self._index.clear() + self.routes = [] diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 2187c759..b70fcb6a 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -22,7 +22,7 @@ import redis.asyncio as aredis from redis.commands.search.indexDefinition import IndexDefinition -from redisvl.index.storage import HashStorage, JsonStorage +from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage from redisvl.query import BaseQuery, CountQuery, FilterQuery from redisvl.query.filter import FilterExpression from redisvl.redis.connection import ( @@ -176,8 +176,10 @@ def __init__( elif redis_url is not None: self.connect(redis_url, **connection_args) - # set up index storage layer - self._storage = self._STORAGE_MAP[self.schema.index.storage_type]( + @property + def _storage(self) -> BaseStorage: + """The storage type for the index schema.""" + return self._STORAGE_MAP[self.schema.index.storage_type]( prefix=self.schema.index.prefix, key_separator=self.schema.index.key_separator, ) diff --git a/redisvl/redis/connection.py b/redisvl/redis/connection.py index 3e949dd3..888cb127 100644 --- a/redisvl/redis/connection.py +++ b/redisvl/redis/connection.py @@ -18,6 +18,35 @@ from redisvl.version import __version__ +def compare_versions(version1, version2): + """ + Compare two Redis version strings numerically. + + Parameters: + version1 (str): The first version string (e.g., "7.2.4"). 
+ version2 (str): The second version string (e.g., "6.2.1"). + + Returns: + bool: True if version1 >= version2, otherwise False (a shorter version that is a prefix of the other, e.g. "7.2" vs "7.2.0", counts as lower). + """ + v1_parts = list(map(int, version1.split("."))) + v2_parts = list(map(int, version2.split("."))) + + for v1, v2 in zip(v1_parts, v2_parts): + if v1 < v2: + return False + elif v1 > v2: + return True + + # If the versions are equal so far, compare the lengths of the version parts + if len(v1_parts) < len(v2_parts): + return False + elif len(v1_parts) > len(v2_parts): + return True + + return True + + def unpack_redis_modules(module_list: List[Dict[str, Any]]) -> Dict[str, Any]: """Unpack a list of Redis modules pulled from the MODULES LIST command.""" return {module["name"]: module["ver"] for module in module_list} diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 29ad5f78..a421022b 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -1,3 +1,4 @@ +import hashlib from typing import Any, Dict, List import numpy as np @@ -37,3 +38,8 @@ def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: """Convert bytes into into a list of floats.""" return np.frombuffer(buffer, dtype=dtype).tolist() + + +def hashify(content: str) -> str: + """Create a secure hash of some arbitrary input text.""" + return hashlib.sha256(content.encode("utf-8")).hexdigest() diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index d165cff4..ed9cffd4 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -195,7 +195,11 @@ def validate_and_create_fields(cls, values): """ Validate uniqueness of field names and create valid field instances. 
""" - index = IndexInfo(**values.get("index")) + # Ensure index is a dictionary for validation + index = values.get("index") + if not isinstance(index, IndexInfo): + index = IndexInfo(**index) + input_fields = values.get("fields", []) prepared_fields: Dict[str, BaseField] = {} # Handle old fields format temporarily diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index 30f1c3f3..99608ddc 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -7,6 +7,7 @@ from redisvl.redis.connection import ( RedisConnectionFactory, + compare_versions, convert_index_info_to_schema, get_address_from_env, unpack_redis_modules, @@ -18,35 +19,6 @@ EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})" -def compare_versions(version1, version2): - """ - Compare two Redis version strings numerically. - - Parameters: - version1 (str): The first version string (e.g., "7.2.4"). - version2 (str): The second version string (e.g., "6.2.1"). - - Returns: - int: -1 if version1 < version2, 0 if version1 == version2, 1 if version1 > version2. 
- """ - v1_parts = list(map(int, version1.split("."))) - v2_parts = list(map(int, version2.split("."))) - - for v1, v2 in zip(v1_parts, v2_parts): - if v1 < v2: - return False - elif v1 > v2: - return True - - # If the versions are equal so far, compare the lengths of the version parts - if len(v1_parts) < len(v2_parts): - return False - elif len(v1_parts) > len(v2_parts): - return True - - return True - - def test_get_address_from_env(redis_url): assert get_address_from_env() == redis_url diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py new file mode 100644 index 00000000..e8bb2b38 --- /dev/null +++ b/tests/integration/test_semantic_router.py @@ -0,0 +1,158 @@ +import pytest + +from redisvl.extensions.router import SemanticRouter +from redisvl.extensions.router.schema import Route, RoutingConfig +from redisvl.redis.connection import compare_versions + + +@pytest.fixture +def routes(): + return [ + Route( + name="greeting", + references=["hello", "hi"], + metadata={"type": "greeting"}, + distance_threshold=0.3, + ), + Route( + name="farewell", + references=["bye", "goodbye"], + metadata={"type": "farewell"}, + distance_threshold=0.3, + ), + ] + + +@pytest.fixture +def semantic_router(client, routes): + router = SemanticRouter( + name="test-router", + routes=routes, + routing_config=RoutingConfig(distance_threshold=0.3, max_k=2), + redis_client=client, + overwrite=False, + ) + yield router + router.delete() + + +def test_initialize_router(semantic_router): + assert semantic_router.name == "test-router" + assert len(semantic_router.routes) == 2 + assert semantic_router.routing_config.distance_threshold == 0.3 + assert semantic_router.routing_config.max_k == 2 + + +def test_router_properties(semantic_router): + route_names = semantic_router.route_names + assert "greeting" in route_names + assert "farewell" in route_names + + thresholds = semantic_router.route_thresholds + assert thresholds["greeting"] == 0.3 + assert 
thresholds["farewell"] == 0.3 + + +def test_get_route(semantic_router): + route = semantic_router.get("greeting") + assert route is not None + assert route.name == "greeting" + assert "hello" in route.references + + +def test_get_non_existing_route(semantic_router): + route = semantic_router.get("non_existent_route") + assert route is None + + +def test_single_query(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + match = semantic_router("hello") + assert match.name == "greeting" + assert match.distance <= semantic_router.route_thresholds["greeting"] + + +def test_single_query_no_match(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + match = semantic_router("unknown_phrase") + assert match.name is None + + +def test_multiple_query(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + matches = semantic_router.route_many("hello", max_k=2) + assert len(matches) > 0 + assert matches[0].name == "greeting" + + +def test_update_routing_config(semantic_router): + new_config = RoutingConfig(distance_threshold=0.5, max_k=1) + semantic_router.update_routing_config(new_config) + assert semantic_router.routing_config.distance_threshold == 0.5 + assert semantic_router.routing_config.max_k == 1 + + +def test_vector_query(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + vector = semantic_router.vectorizer.embed("goodbye") + match = semantic_router(vector=vector) + assert 
match.name == "farewell" + + +def test_vector_query_no_match(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + vector = [ + 0.0 + ] * semantic_router.vectorizer.dims # All-zero vector, not expected to match any route + match = semantic_router(vector=vector) + assert match.name is None + + +def test_add_route(semantic_router): + new_routes = [ + Route( + name="politics", + references=[ + "are you liberal or conservative?", + "who will you vote for?", + "political speech", + ], + metadata={"type": "greeting"}, + ) + ] + semantic_router._add_routes(new_routes) + + route = semantic_router.get("politics") + assert route is not None + assert route.name == "politics" + assert "political speech" in route.references + + redis_version = semantic_router._index.client.info()["redis_version"] + if compare_versions(redis_version, "7.0.0"): + match = semantic_router("political speech") + print(match, flush=True) + assert match is not None + assert match.name == "politics" + + +def test_remove_routes(semantic_router): + semantic_router.remove_route("greeting") + assert semantic_router.get("greeting") is None + + semantic_router.remove_route("unknown_route") + assert semantic_router.get("unknown_route") is None diff --git a/tests/unit/test_route_schema.py b/tests/unit/test_route_schema.py new file mode 100644 index 00000000..f1ad5cb5 --- /dev/null +++ b/tests/unit/test_route_schema.py @@ -0,0 +1,125 @@ +import pytest +from pydantic.v1 import ValidationError + +from redisvl.extensions.router.schema import ( + DistanceAggregationMethod, + Route, + RouteMatch, + RoutingConfig, +) + + +def test_route_valid(): + route = Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0.3, + ) + assert route.name == "Test Route" + assert route.references == ["reference1", "reference2"] + 
assert route.metadata == {"key": "value"} + assert route.distance_threshold == 0.3 + + +def test_route_empty_name(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0.3, + ) + assert "Route name must not be empty" in str(excinfo.value) + + +def test_route_empty_references(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=[], + metadata={"key": "value"}, + distance_threshold=0.3, + ) + assert "References must not be empty" in str(excinfo.value) + + +def test_route_non_empty_references(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=["reference1", ""], + metadata={"key": "value"}, + distance_threshold=0.3, + ) + assert "All references must be non-empty strings" in str(excinfo.value) + + +def test_route_valid_no_threshold(): + route = Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + ) + assert route.name == "Test Route" + assert route.references == ["reference1", "reference2"] + assert route.metadata == {"key": "value"} + assert route.distance_threshold is None + + +def test_route_invalid_threshold_zero(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=0, + ) + assert "Route distance threshold must be greater than zero" in str(excinfo.value) + + +def test_route_invalid_threshold_negative(): + with pytest.raises(ValidationError) as excinfo: + Route( + name="Test Route", + references=["reference1", "reference2"], + metadata={"key": "value"}, + distance_threshold=-0.1, + ) + assert "Route distance threshold must be greater than zero" in str(excinfo.value) + + +def test_route_match(): + route_match = RouteMatch(name="test", distance=0.25) + assert route_match.name == "test" + assert 
route_match.distance == 0.25 + + +def test_route_match_no_route(): + route_match = RouteMatch() + assert route_match.name is None + assert route_match.distance is None + + +def test_distance_aggregation_method(): + assert DistanceAggregationMethod.avg == DistanceAggregationMethod("avg") + assert DistanceAggregationMethod.min == DistanceAggregationMethod("min") + assert DistanceAggregationMethod.sum == DistanceAggregationMethod("sum") + + +def test_routing_config_valid(): + config = RoutingConfig(distance_threshold=0.6, max_k=5) + assert config.distance_threshold == 0.6 + assert config.max_k == 5 + + +def test_routing_config_invalid_max_k(): + with pytest.raises(ValidationError) as excinfo: + RoutingConfig(distance_threshold=0.6, max_k=0) + assert "max_k must be a positive integer" in str(excinfo.value) + + +def test_routing_config_invalid_distance_threshold(): + with pytest.raises(ValidationError) as excinfo: + RoutingConfig(distance_threshold=1.5, max_k=5) + assert "distance_threshold must be between 0 and 1" in str(excinfo.value)