Skip to content

Commit d9766da

Browse files
wip updates using aggregations
1 parent 69ea73e commit d9766da

File tree

3 files changed

+124
-93
lines changed

3 files changed

+124
-93
lines changed

redisvl/extensions/router/semantic.py

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
from pydantic.v1 import BaseModel, root_validator, Field, PrivateAttr
22
from typing import Any, List, Dict, Optional, Union
3+
34
from redis import Redis
5+
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
6+
import redis.commands.search.reducers as reducers
7+
48
from redisvl.index import SearchIndex
59
from redisvl.query import VectorQuery, RangeQuery
610
from redisvl.schema import IndexSchema, IndexInfo
711
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
812
from redisvl.extensions.router.routes import Route, RoutingConfig, AccumulationMethod
913

14+
from redisvl.redis.utils import make_dict, convert_bytes
15+
1016
import hashlib
1117

1218

@@ -44,7 +50,7 @@ class SemanticRouter(BaseModel):
4450
"""Configuration for routing behavior"""
4551

4652
_index: SearchIndex = PrivateAttr()
47-
_accumulation_method: AccumulationMethod = PrivateAttr()
53+
# _accumulation_method: AccumulationMethod = PrivateAttr()
4854

4955
class Config:
5056
arbitrary_types_allowed = True
@@ -79,7 +85,7 @@ def __init__(
7985
routing_config=routing_config
8086
)
8187
self._initialize_index(redis_client, redis_url, overwrite)
82-
self._accumulation_method = self._pick_accumulation_method()
88+
# self._accumulation_method = self._pick_accumulation_method()
8389

8490
def _initialize_index(
8591
self,
@@ -110,19 +116,19 @@ def _initialize_index(
110116
if not existed or overwrite:
111117
self._add_routes(self.routes)
112118

113-
def _pick_accumulation_method(self) -> AccumulationMethod:
114-
"""Pick the accumulation method based on the routing configuration."""
115-
if self.routing_config.accumulation_method != AccumulationMethod.auto:
116-
return self.routing_config.accumulation_method
119+
# def _pick_accumulation_method(self) -> AccumulationMethod:
120+
# """Pick the accumulation method based on the routing configuration."""
121+
# if self.routing_config.accumulation_method != AccumulationMethod.auto:
122+
# return self.routing_config.accumulation_method
117123

118-
num_route_references = [len(route.references) for route in self.routes]
119-
avg_num_references = sum(num_route_references) / len(num_route_references)
120-
variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references)
124+
# num_route_references = [len(route.references) for route in self.routes]
125+
# avg_num_references = sum(num_route_references) / len(num_route_references)
126+
# variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references)
121127

122-
if variance < 1: # TODO: Arbitrary threshold for low variance
123-
return AccumulationMethod.sum
124-
else:
125-
return AccumulationMethod.avg
128+
# if variance < 1: # TODO: Arbitrary threshold for low variance
129+
# return AccumulationMethod.sum
130+
# else:
131+
# return AccumulationMethod.avg
126132

127133
def update_routing_config(self, routing_config: RoutingConfig):
128134
"""Update the routing configuration.
@@ -131,7 +137,7 @@ def update_routing_config(self, routing_config: RoutingConfig):
131137
routing_config (RoutingConfig): The new routing configuration.
132138
"""
133139
self.routing_config = routing_config
134-
self._accumulation_method = self._pick_accumulation_method()
140+
# self._accumulation_method = self._pick_accumulation_method()
135141

136142
def _add_routes(self, routes: List[Route]):
137143
"""Add routes to the index.
@@ -174,64 +180,41 @@ def __call__(
174180
max_k = max_k if max_k is not None else self.routing_config.max_k
175181
distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold
176182

177-
# get the total number of route references in the index
178-
num_route_references = sum(
179-
[len(route.references) for route in self.routes]
180-
)
183+
# # get the total number of route references in the index
184+
# num_route_references = sum(
185+
# [len(route.references) for route in self.routes]
186+
# )
181187
# define the baseline range query to fetch relevant route references
182-
query = RangeQuery(
188+
vector_range_query = RangeQuery(
183189
vector=vector,
184190
vector_field_name="vector",
185-
distance_threshold=distance_threshold,
186-
return_fields=["route_name", "reference"],
187-
# max number of results from range query
188-
num_results=num_route_references
191+
distance_threshold=2,
192+
return_fields=["route_name"]
189193
)
190-
# execute query and accumulate results
191-
route_references = self._index.query(query)
192-
top_routes_and_scores = self._reduce_scores(route_references, max_k)
193-
top_routes = self._fetch_routes(top_routes_and_scores)
194194

195-
return top_routes
195+
# build redis aggregation query
196+
aggregate_query = str(vector_range_query).split(" RETURN")[0]
197+
aggregate_request = (
198+
AggregateRequest(aggregate_query)
199+
.group_by(
200+
"@route_name",
201+
reducers.avg("vector_distance").alias("avg"),
202+
reducers.min("vector_distance").alias("score")
203+
)
204+
.apply(avg_score="1 - @avg", score="1 - @score")
205+
.dialect(2)
206+
)
196207

197-
def _reduce_scores(
198-
self,
199-
route_references: List[Dict[str, Any]],
200-
max_k: int
201-
) -> List[Dict[str, Any]]:
202-
"""Group by route name and reduce scores to return max_k routes overall.
208+
top_routes_and_scores = []
209+
aggregate_results = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params)
203210

204-
Args:
205-
route_references: List of route references with scores.
206-
max_k: The number of top results to return.
211+
for result in aggregate_results.rows:
212+
top_routes_and_scores.append(make_dict(convert_bytes(result)))
213+
214+
top_routes = self._fetch_routes(top_routes_and_scores)
215+
216+
return top_routes
207217

208-
Returns:
209-
List[Dict[str, Any]]: Accumulated scores for the top routes.
210-
"""
211-
# TODO: eventually this should be replaced by an AggregationQuery class
212-
scores_by_route = {}
213-
for ref in route_references:
214-
route_name = ref['route_name']
215-
score = ref['vector_distance']
216-
if route_name not in scores_by_route:
217-
scores_by_route[route_name] = []
218-
scores_by_route[route_name].append(float(score))
219-
220-
accumulated_scores = []
221-
for route_name, scores in scores_by_route.items():
222-
if self._accumulation_method == AccumulationMethod.sum:
223-
accumulated_score = sum(scores)
224-
elif self._accumulation_method == AccumulationMethod.avg:
225-
accumulated_score = sum(scores) / len(scores)
226-
else:
227-
# simple strategy
228-
accumulated_score = scores[0] # take the first score
229-
230-
accumulated_scores.append({"route_name": route_name, "score": accumulated_score})
231-
232-
# Sort by score in descending order and return the max_k results
233-
accumulated_scores.sort(key=lambda x: x["score"], reverse=False)
234-
return accumulated_scores[:max_k]
235218

236219
def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
237220
"""Fetch route objects and metadata based on top matches.
@@ -250,6 +233,7 @@ def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dic
250233
results.append({
251234
**route.dict(),
252235
"score": route_info["score"],
236+
"avg_score": route_info["avg_score"]
253237
})
254238

255239
return results

redisvl/extensions/router/test.ipynb

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18+
"from redisvl.extensions.router.routes import Route\n",
19+
"\n",
1820
"# Define individual routes manually with metadata\n",
1921
"politics = Route(\n",
2022
" name=\"politics\",\n",
@@ -63,13 +65,13 @@
6365
"name": "stdout",
6466
"output_type": "stream",
6567
"text": [
66-
"23:37:26 redisvl.index.index INFO Index already exists, overwriting.\n"
68+
"13:38:58 redisvl.index.index INFO Index already exists, overwriting.\n"
6769
]
6870
}
6971
],
7072
"source": [
71-
"from redisvl.extensions.router.semantic import SemanticRouter\n",
7273
"import redis\n",
74+
"from redisvl.extensions.router.semantic import SemanticRouter\n",
7375
"\n",
7476
"# Create SemanticRouter named \"topic-router\"\n",
7577
"redis_client = redis.Redis()\n",
@@ -85,48 +87,67 @@
8587
},
8688
{
8789
"cell_type": "code",
88-
"execution_count": 7,
90+
"execution_count": 6,
8991
"metadata": {},
90-
"outputs": [],
92+
"outputs": [
93+
{
94+
"data": {
95+
"text/plain": [
96+
"[{'name': 'chitchat',\n",
97+
" 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n",
98+
" 'metadata': {'priority': '2'},\n",
99+
" 'score': '0.0641258955002',\n",
100+
" 'avg_score': '0.0481971502304'},\n",
101+
" {'name': 'politics',\n",
102+
" 'references': [\"isn't politics the best thing ever\",\n",
103+
" \"why don't you tell me about your political opinions\"],\n",
104+
" 'metadata': {'priority': '1'},\n",
105+
" 'score': '0.298070549965',\n",
106+
" 'avg_score': '0.207850039005'}]"
107+
]
108+
},
109+
"execution_count": 6,
110+
"metadata": {},
111+
"output_type": "execute_result"
112+
}
113+
],
91114
"source": [
92-
"assert topic_router.routes == routes\n",
93-
"assert topic_router.name == \"topic-router\"\n",
94-
"assert topic_router.name == topic_router._index.name == topic_router._index.prefix\n",
95-
"assert topic_router.routing_config == config"
115+
"topic_router(\"I am thinking about running for Governor in the state of VA. What do I need to consider?\")"
96116
]
97117
},
98118
{
99119
"cell_type": "code",
100-
"execution_count": 8,
120+
"execution_count": 7,
101121
"metadata": {},
102122
"outputs": [
103123
{
104124
"data": {
105125
"text/plain": [
106-
"[{'name': 'politics',\n",
126+
"[{'name': 'chitchat',\n",
127+
" 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n",
128+
" 'metadata': {'priority': '2'},\n",
129+
" 'score': '0.12840616703',\n",
130+
" 'avg_score': '0.112765411536'},\n",
131+
" {'name': 'politics',\n",
107132
" 'references': [\"isn't politics the best thing ever\",\n",
108133
" \"why don't you tell me about your political opinions\"],\n",
109134
" 'metadata': {'priority': '1'},\n",
110-
" 'score': 0.3825837373735},\n",
111-
" {'name': 'chitchat',\n",
112-
" 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n",
113-
" 'metadata': {'priority': '2'},\n",
114-
" 'score': 0.8872345884643332}]"
135+
" 'score': '0.764727830887',\n",
136+
" 'avg_score': '0.617416262627'}]"
115137
]
116138
},
117-
"execution_count": 8,
139+
"execution_count": 7,
118140
"metadata": {},
119141
"output_type": "execute_result"
120142
}
121143
],
122144
"source": [
123-
"# Query topic-router with behavior based on the config\n",
124145
"topic_router(\"don't you love politics?\")"
125146
]
126147
},
127148
{
128149
"cell_type": "code",
129-
"execution_count": 9,
150+
"execution_count": 8,
130151
"metadata": {},
131152
"outputs": [
132153
{
@@ -135,15 +156,17 @@
135156
"[{'name': 'chitchat',\n",
136157
" 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n",
137158
" 'metadata': {'priority': '2'},\n",
138-
" 'score': 0.5357088247936667},\n",
159+
" 'score': '0.54086458683',\n",
160+
" 'avg_score': '0.464291175207'},\n",
139161
" {'name': 'politics',\n",
140162
" 'references': [\"isn't politics the best thing ever\",\n",
141163
" \"why don't you tell me about your political opinions\"],\n",
142164
" 'metadata': {'priority': '1'},\n",
143-
" 'score': 0.8782881200315}]"
165+
" 'score': '0.156601548195',\n",
166+
" 'avg_score': '0.121711879969'}]"
144167
]
145168
},
146-
"execution_count": 9,
169+
"execution_count": 8,
147170
"metadata": {},
148171
"output_type": "execute_result"
149172
}
@@ -163,7 +186,7 @@
163186
},
164187
{
165188
"cell_type": "code",
166-
"execution_count": 13,
189+
"execution_count": 9,
167190
"metadata": {},
168191
"outputs": [],
169192
"source": [
@@ -174,16 +197,26 @@
174197
},
175198
{
176199
"cell_type": "code",
177-
"execution_count": 14,
200+
"execution_count": 10,
178201
"metadata": {},
179202
"outputs": [
180203
{
181204
"data": {
182205
"text/plain": [
183-
"[]"
206+
"[{'name': 'chitchat',\n",
207+
" 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n",
208+
" 'metadata': {'priority': '2'},\n",
209+
" 'score': '0.756013274193',\n",
210+
" 'avg_score': '0.423087894917'},\n",
211+
" {'name': 'politics',\n",
212+
" 'references': [\"isn't politics the best thing ever\",\n",
213+
" \"why don't you tell me about your political opinions\"],\n",
214+
" 'metadata': {'priority': '1'},\n",
215+
" 'score': '0.175542235374',\n",
216+
" 'avg_score': '0.138914197683'}]"
184217
]
185218
},
186-
"execution_count": 14,
219+
"execution_count": 10,
187220
"metadata": {},
188221
"output_type": "execute_result"
189222
}
@@ -195,7 +228,7 @@
195228
},
196229
{
197230
"cell_type": "code",
198-
"execution_count": 15,
231+
"execution_count": 11,
199232
"metadata": {},
200233
"outputs": [
201234
{
@@ -204,10 +237,17 @@
204237
"[{'name': 'chitchat',\n",
205238
" 'references': ['hello', \"how's the weather today?\", 'how are things going?'],\n",
206239
" 'metadata': {'priority': '2'},\n",
207-
" 'score': 0.243986725807}]"
240+
" 'score': '0.756013274193',\n",
241+
" 'avg_score': '0.423087894917'},\n",
242+
" {'name': 'politics',\n",
243+
" 'references': [\"isn't politics the best thing ever\",\n",
244+
" \"why don't you tell me about your political opinions\"],\n",
245+
" 'metadata': {'priority': '1'},\n",
246+
" 'score': '0.175542235374',\n",
247+
" 'avg_score': '0.138914197683'}]"
208248
]
209249
},
210-
"execution_count": 15,
250+
"execution_count": 11,
211251
"metadata": {},
212252
"output_type": "execute_result"
213253
}

redisvl/index/index.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,14 @@ def __init__(
176176
self.connect(redis_url, **connection_args)
177177

178178
# set up index storage layer
179-
self._storage = self._STORAGE_MAP[self.schema.index.storage_type](
179+
# self._storage = self._STORAGE_MAP[self.schema.index.storage_type](
180+
# prefix=self.schema.index.prefix,
181+
# key_separator=self.schema.index.key_separator,
182+
# )
183+
184+
@property
185+
def _storage(self):
186+
return self._STORAGE_MAP[self.schema.index.storage_type](
180187
prefix=self.schema.index.prefix,
181188
key_separator=self.schema.index.key_separator,
182189
)

0 commit comments

Comments
 (0)