|
1 | 1 | from typing import Any, Dict, List, Optional, Tuple, Union
|
2 | 2 |
|
3 |
| -from sentence_transformers import CrossEncoder |
| 3 | +from pydantic.v1 import PrivateAttr |
4 | 4 |
|
5 | 5 | from redisvl.utils.rerank.base import BaseReranker
|
6 | 6 |
|
@@ -31,25 +31,44 @@ class HFCrossEncoderReranker(BaseReranker):
|
31 | 31 | )
|
32 | 32 | """
|
33 | 33 |
|
| 34 | + _client: Any = PrivateAttr() |
| 35 | + |
def __init__(
    self,
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
    limit: int = 3,
    return_score: bool = True,
    **kwargs: Any,
) -> None:
    """
    Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.

    Parameters:
        model (str): The name or path of the cross-encoder model to use for reranking.
            Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
        limit (int): The maximum number of results to return after reranking. Must be a positive integer.
        return_score (bool): Whether to return scores alongside the reranked results.
        **kwargs: Extra keyword arguments forwarded to the CrossEncoder
            constructor. The legacy ``model_name`` keyword is accepted here
            as an alias for ``model`` and is never forwarded.
    """
    default_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"

    # Always remove the legacy alias from kwargs. Because ``model`` has a
    # non-empty default, a short-circuiting ``model or kwargs.pop(...)``
    # would never pop it, and the leftover ``model_name`` would be passed
    # on to CrossEncoder together with the positional model argument.
    legacy_model_name = kwargs.pop("model_name", None)

    # An explicitly chosen ``model`` wins; the alias only applies when
    # ``model`` was left at its default (or is falsy).
    if legacy_model_name and (not model or model == default_model):
        model = legacy_model_name

    super().__init__(
        model=model, rank_by=None, limit=limit, return_score=return_score
    )
    self._initialize_client(**kwargs)
| 57 | + |
def _initialize_client(self, **kwargs: Any) -> None:
    """
    Setup the huggingface cross-encoder client using optional kwargs.

    Parameters:
        **kwargs: Extra keyword arguments passed through to the
            CrossEncoder constructor after the model name.

    Raises:
        ImportError: If the sentence-transformers library is not installed.
    """
    # Imported lazily so the package can be imported without the optional
    # sentence-transformers dependency installed.
    try:
        from sentence_transformers import CrossEncoder
    except ImportError as err:
        # Adjacent string literals are used instead of a backslash
        # continuation, which would embed the next line's indentation
        # in the message.
        raise ImportError(
            "HFCrossEncoder reranker requires the sentence-transformers library. "
            "Please install with `pip install sentence-transformers`"
        ) from err

    self._client = CrossEncoder(self.model, **kwargs)
53 | 72 |
|
54 | 73 | def rank(
|
55 | 74 | self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
|
@@ -97,7 +116,7 @@ def rank(
|
97 | 116 | texts = [str(doc) for doc in docs]
|
98 | 117 | doc_subset = [{"content": doc} for doc in docs]
|
99 | 118 |
|
100 |
| - scores = self.model.predict([(query, text) for text in texts]) |
| 119 | + scores = self._client.predict([(query, text) for text in texts]) |
101 | 120 | scores = [float(score) for score in scores]
|
102 | 121 | docs_with_scores = list(zip(doc_subset, scores))
|
103 | 122 | docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
|
0 commit comments