Commit 5e845f2 (#150)

Reranker implementation using Cross Encoders (HuggingFace / SentenceTransformers)

Inspired by [LangChain's Cross Encoder Reranker](https://python.langchain.com/v0.1/docs/integrations/document_transformers/cross_encoder_reranker/)

1 parent c7e90ea commit 5e845f2

File tree

5 files changed: +353 −43 lines

docs/user_guide/rerankers_06.ipynb (+116 −28)

@@ -9,7 +9,10 @@
     "\n",
     "In this notebook, we will show how to use RedisVL to rerank search results\n",
     "(documents or chunks or records) based on the input query. Today RedisVL\n",
-    "supports reranking through the [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
+    "supports reranking through: \n",
+    "\n",
+    "- A re-ranker that uses pre-trained [Cross-Encoders](https://sbert.net/examples/applications/cross-encoder/README.html) which can use models from [Hugging Face cross encoder models](https://huggingface.co/cross-encoder) or Hugging Face models that implement a cross encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)).\n",
+    "- The [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
     "\n",
     "Before running this notebook, be sure to:\n",
     "1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
@@ -26,8 +29,10 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
-    "metadata": {},
+    "execution_count": 27,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
      "# import necessary modules\n",
@@ -48,8 +53,10 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
-    "metadata": {},
+    "execution_count": 28,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
      "query = \"What is the capital of the United States?\"\n",
@@ -75,24 +82,93 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "### Init the Reranker\n",
+     "### Using the Cross-Encoder Reranker\n",
      "\n",
-     "Initialize the reranker. Install the cohere library and provide the right Cohere API Key."
+     "To use the cross-encoder reranker we initialize an instance of `HFCrossEncoderReranker` passing a suitable model (if no model is provided, the `cross-encoder/ms-marco-MiniLM-L-6-v2` model is used): "
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": null,
+    "execution_count": 29,
+    "metadata": {
+     "metadata": {}
+    },
+    "outputs": [],
+    "source": [
+     "from redisvl.utils.rerank import HFCrossEncoderReranker\n",
+     "\n",
+     "cross_encoder_reranker = HFCrossEncoderReranker(\"BAAI/bge-reranker-base\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
     "metadata": {},
+    "source": [
+     "### Rerank documents with HFCrossEncoderReranker\n",
+     "\n",
+     "With the obtained reranker instance we can rerank and truncate the list of\n",
+     "documents based on relevance to the initial query."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 30,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
-     "#!pip install cohere"
+     "results, scores = cross_encoder_reranker.rank(query=query, docs=docs)"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 31,
+    "metadata": {
+     "metadata": {}
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "0.07461125403642654 -- {'content': 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.'}\n",
+       "0.05220315232872963 -- {'content': 'Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.'}\n",
+       "0.3802368640899658 -- {'content': 'Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.'}\n"
+      ]
+     }
+    ],
+    "source": [
+     "for result, score in zip(results, scores):\n",
+     "    print(score, \" -- \", result)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
     "metadata": {},
+    "source": [
+     "### Using the Cohere Reranker\n",
+     "\n",
+     "To initialize the Cohere reranker you'll need to install the cohere library and provide the right Cohere API Key."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 32,
+    "metadata": {
+     "metadata": {}
+    },
+    "outputs": [],
+    "source": [
+     "#!pip install cohere"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 33,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
      "import getpass\n",
@@ -103,38 +179,44 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
-    "metadata": {},
+    "execution_count": 34,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
      "from redisvl.utils.rerank import CohereReranker\n",
      "\n",
-     "reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
+     "cohere_reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "### Rerank documents\n",
+     "### Rerank documents with CohereReranker\n",
      "\n",
-     "Below we will use the `CohereReranker` to rerank and also truncate the list of\n",
+     "Below we will use the `CohereReranker` to rerank and truncate the list of\n",
      "documents above based on relevance to the initial query."
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
-    "metadata": {},
+    "execution_count": 35,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
-     "results, scores = reranker.rank(query=query, docs=docs)"
+     "results, scores = cohere_reranker.rank(query=query, docs=docs)"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
-    "metadata": {},
+    "execution_count": 36,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [
     {
      "name": "stdout",
@@ -162,8 +244,10 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 8,
-    "metadata": {},
+    "execution_count": 37,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
      "docs = [\n",
@@ -192,17 +276,21 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 10,
-    "metadata": {},
+    "execution_count": 38,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [],
     "source": [
-     "results, scores = reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
+     "results, scores = cohere_reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 11,
-    "metadata": {},
+    "execution_count": 39,
+    "metadata": {
+     "metadata": {}
+    },
     "outputs": [
     {
      "name": "stdout",
@@ -236,7 +324,7 @@
     "name": "python",
     "nbconvert_exporter": "python",
     "pygments_lexer": "ipython3",
-    "version": "3.10.14"
+    "version": "3.11.9"
    },
    "orig_nbformat": 4,
    "vscode": {
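The notebook's `rank(query=query, docs=docs)` call boils down to scoring each (query, document) pair with a cross-encoder, sorting by score, and truncating to `limit`. A minimal dependency-free sketch of that semantics follows; `rank_with_scores` and `overlap_score` are hypothetical illustration-only names, and the word-overlap stub merely stands in for a real `CrossEncoder.predict` call so the example runs without `sentence_transformers`:

```python
from typing import Any, Callable, Dict, List, Tuple

def rank_with_scores(
    query: str,
    docs: List[Dict[str, Any]],
    score_fn: Callable[[str, str], float],
    limit: int = 3,
) -> Tuple[List[Dict[str, Any]], List[float]]:
    """Score each (query, content) pair, sort descending, and truncate,
    keeping the returned scores aligned with the reordered documents."""
    scored = [(doc, float(score_fn(query, str(doc["content"])))) for doc in docs]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    top = scored[:limit]
    return [doc for doc, _ in top], [score for _, score in top]

# Stub scorer: counts shared lowercase words. A real cross-encoder would
# call model.predict([(query, text), ...]) on the sentence pairs instead.
def overlap_score(query: str, text: str) -> float:
    return float(len(set(query.lower().split()) & set(text.lower().split())))

query = "What is the capital of the United States?"
docs = [
    {"content": "Charlotte Amalie is the capital of the United States Virgin Islands."},
    {"content": "Washington, D.C. is the capital of the United States."},
    {"content": "Carson City is the capital city of Nevada."},
]
results, scores = rank_with_scores(query, docs, overlap_score, limit=2)
```

Sorting the `(doc, score)` pairs together is what keeps each returned score attached to the document it belongs to after reordering.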
redisvl/utils/rerank/__init__.py (+2 −4)

@@ -1,7 +1,5 @@
 from redisvl.utils.rerank.base import BaseReranker
 from redisvl.utils.rerank.cohere import CohereReranker
+from redisvl.utils.rerank.hf_cross_encoder import HFCrossEncoderReranker
 
-__all__ = [
-    "BaseReranker",
-    "CohereReranker",
-]
+__all__ = ["BaseReranker", "CohereReranker", "HFCrossEncoderReranker"]
redisvl/utils/rerank/hf_cross_encoder.py (new file, +129)

from typing import Any, Dict, List, Tuple, Union

from sentence_transformers import CrossEncoder

from redisvl.utils.rerank.base import BaseReranker


class HFCrossEncoderReranker(BaseReranker):
    """
    The HFCrossEncoderReranker class uses cross-encoder models from Hugging Face
    to rerank documents based on an input query.

    This reranker loads a cross-encoder model using the `CrossEncoder` class
    from the `sentence_transformers` library. It requires the
    `sentence_transformers` library to be installed.

    .. code-block:: python

        from redisvl.utils.rerank import HFCrossEncoderReranker

        # set up the HFCrossEncoderReranker with a specific model
        reranker = HFCrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", limit=3)
        # rerank raw search results based on user input/query
        results, scores = reranker.rank(
            query="your input query text here",
            docs=[
                {"content": "document 1"},
                {"content": "document 2"},
                {"content": "document 3"}
            ]
        )
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        limit: int = 3,
        return_score: bool = True,
    ) -> None:
        """
        Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.

        Parameters:
            model_name (str): The name or path of the cross-encoder model to use for reranking.
                Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
            limit (int): The maximum number of results to return after reranking. Must be a positive integer.
            return_score (bool): Whether to return scores alongside the reranked results.
        """
        super().__init__(
            model=model_name, rank_by=None, limit=limit, return_score=return_score
        )
        self.model: CrossEncoder = CrossEncoder(model_name)

    def rank(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
        """
        Rerank documents based on the provided query using the loaded cross-encoder model.

        This method processes the user's query and the provided documents to rerank them
        in a manner that is potentially more relevant to the query's context.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
                either as dictionaries or strings.

        Returns:
            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
                The reranked list of documents and optionally associated scores.
        """
        limit = kwargs.get("limit", self.limit)
        return_score = kwargs.get("return_score", self.return_score)

        if not query:
            raise ValueError("query cannot be empty")

        if not isinstance(query, str):
            raise TypeError("query must be a string")

        if not isinstance(docs, list):
            raise TypeError("docs must be a list")

        if not docs:
            return [] if not return_score else ([], [])

        if all(isinstance(doc, dict) for doc in docs):
            texts = [
                str(doc["content"])
                for doc in docs
                if isinstance(doc, dict) and "content" in doc
            ]
            doc_subset = [
                doc for doc in docs if isinstance(doc, dict) and "content" in doc
            ]
        else:
            texts = [str(doc) for doc in docs]
            doc_subset = [{"content": doc} for doc in docs]

        scores = self.model.predict([(query, text) for text in texts])
        scores = [float(score) for score in scores]
        docs_with_scores = list(zip(doc_subset, scores))
        docs_with_scores.sort(key=lambda x: x[1], reverse=True)
        reranked_docs = [doc for doc, _ in docs_with_scores[:limit]]
        # keep the returned scores aligned with the reordered documents
        scores = [score for _, score in docs_with_scores[:limit]]

        if return_score:
            return reranked_docs, scores
        return reranked_docs

    async def arank(
        self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
    ) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
        """
        Asynchronously rerank documents based on the provided query using the loaded cross-encoder model.

        This method processes the user's query and the provided documents to rerank them
        in a manner that is potentially more relevant to the query's context.

        Parameters:
            query (str): The user's search query.
            docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
                either as dictionaries or strings.

        Returns:
            Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
                The reranked list of documents and optionally associated scores.
        """
        return self.rank(query, docs, **kwargs)
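The input handling in `rank` accepts either dicts carrying a `content` key or plain strings, producing both the texts to score and the documents to return. That dispatch can be isolated into a small helper; `normalize_docs` is a hypothetical name for illustration, not part of the commit:

```python
from typing import Any, Dict, List, Tuple, Union

def normalize_docs(
    docs: Union[List[Dict[str, Any]], List[str]]
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Return (texts to score, docs to return), mirroring rank()'s dispatch:
    when every doc is a dict, only those carrying a "content" key are kept
    (others are silently dropped); otherwise each item is stringified and
    wrapped as {"content": ...}."""
    if all(isinstance(doc, dict) for doc in docs):
        kept = [doc for doc in docs if "content" in doc]
        return [str(doc["content"]) for doc in kept], kept
    return [str(doc) for doc in docs], [{"content": str(doc)} for doc in docs]

texts, kept = normalize_docs([{"content": "a"}, {"other": "x"}, {"content": "b"}])
str_texts, str_kept = normalize_docs(["p", "q"])
```

Note that a mixed list of dicts and strings falls through to the string branch, which matches the `all(isinstance(doc, dict) ...)` check in `rank` above.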
