Skip to content

Commit 3c74dee

Browse files
Fix sentence transformers reranker import (#231)
Sentence transformers was not being dynamically imported into the reranker module for hugging face, causing dependency issues for anyone using a reranker. Fixes #229
1 parent 7982f7d commit 3c74dee

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

redisvl/utils/rerank/cohere.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
8383
from cohere import AsyncClient, Client
8484
except ImportError:
8585
raise ImportError(
86-
"Cohere vectorizer requires the cohere library. \
86+
"Cohere reranker requires the cohere library. \
8787
Please install with `pip install cohere`"
8888
)
8989

redisvl/utils/rerank/hf_cross_encoder.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Optional, Tuple, Union
22

3-
from sentence_transformers import CrossEncoder
3+
from pydantic.v1 import PrivateAttr
44

55
from redisvl.utils.rerank.base import BaseReranker
66

@@ -31,25 +31,44 @@ class HFCrossEncoderReranker(BaseReranker):
3131
)
3232
"""
3333

34+
_client: Any = PrivateAttr()
35+
3436
def __init__(
3537
self,
36-
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
38+
model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
3739
limit: int = 3,
3840
return_score: bool = True,
41+
**kwargs,
3942
) -> None:
4043
"""
4144
Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.
4245
4346
Parameters:
44-
model_name (str): The name or path of the cross-encoder model to use for reranking.
47+
model (str): The name or path of the cross-encoder model to use for reranking.
4548
Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
4649
limit (int): The maximum number of results to return after reranking. Must be a positive integer.
4750
return_score (bool): Whether to return scores alongside the reranked results.
4851
"""
52+
model = model or kwargs.pop("model_name", None)
4953
super().__init__(
50-
model=model_name, rank_by=None, limit=limit, return_score=return_score
54+
model=model, rank_by=None, limit=limit, return_score=return_score
5155
)
52-
self.model: CrossEncoder = CrossEncoder(model_name)
56+
self._initialize_client(**kwargs)
57+
58+
def _initialize_client(self, **kwargs):
59+
"""
60+
Setup the huggingface cross-encoder client using optional kwargs.
61+
"""
62+
# Dynamic import of the sentence-transformers module
63+
try:
64+
from sentence_transformers import CrossEncoder
65+
except ImportError:
66+
raise ImportError(
67+
"HFCrossEncoder reranker requires the sentence-transformers library. \
68+
Please install with `pip install sentence-transformers`"
69+
)
70+
71+
self._client = CrossEncoder(self.model, **kwargs)
5372

5473
def rank(
5574
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
@@ -97,7 +116,7 @@ def rank(
97116
texts = [str(doc) for doc in docs]
98117
doc_subset = [{"content": doc} for doc in docs]
99118

100-
scores = self.model.predict([(query, text) for text in texts])
119+
scores = self._client.predict([(query, text) for text in texts])
101120
scores = [float(score) for score in scores]
102121
docs_with_scores = list(zip(doc_subset, scores))
103122
docs_with_scores.sort(key=lambda x: x[1], reverse=True)

0 commit comments

Comments
 (0)