Skip to content

Commit bb95f1c

Browse files
committed
revert back to string input
1 parent 5e1cadb commit bb95f1c

File tree

3 files changed

+3
-29
lines changed

3 files changed

+3
-29
lines changed

redis/commands/search/aggregation.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
1-
from enum import Enum
21
from typing import List, Union
32

43
FIELDNAME = object()
54

65

7-
class Scorers(Enum):
8-
TFIDF = "TFIDF"
9-
TFIDF_DOCNORM = "TFIDF.DOCNORM"
10-
BM25 = "BM25"
11-
DISMAX = "DISMAX"
12-
DOCSCORE = "DOCSCORE"
13-
HAMMING = "HAMMING"
14-
15-
166
class Limit:
177
def __init__(self, offset: int = 0, count: int = 0) -> None:
188
self.offset = offset
@@ -122,7 +112,7 @@ def __init__(self, query: str = "*") -> None:
122112
self._cursor = []
123113
self._dialect = None
124114
self._add_scores = False
125-
self._scorer = Scorers.TFIDF.value
115+
self._scorer = "TFIDF"
126116

127117
def load(self, *fields: List[str]) -> "AggregateRequest":
128118
"""
@@ -311,15 +301,15 @@ def add_scores(self) -> "AggregateRequest":
311301
self._add_scores = True
312302
return self
313303

314-
def scorer(self, scorer: Scorers) -> "AggregateRequest":
304+
def scorer(self, scorer: str) -> "AggregateRequest":
315305
"""
316306
Use a different scoring function to evaluate document relevance.
317307
Default is `TFIDF`.
318308
319309
:param scorer: The scoring function to use
320310
(e.g. `TFIDF.DOCNORM` or `BM25`)
321311
"""
322-
self._scorer = Scorers(scorer).value
312+
self._scorer = scorer
323313
return self
324314

325315
def verbatim(self) -> "AggregateRequest":

tests/test_asyncio/test_search.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,14 +1611,6 @@ async def test_aggregations_hybrid_scoring(decoded_r: redis.Redis):
16111611
len(row) == 6
16121612

16131613

1614-
@pytest.mark.redismod
1615-
@skip_ifmodversion_lt("2.10.05", "search")
1616-
async def test_invalid_scorer():
1617-
1618-
with pytest.raises(ValueError):
1619-
aggregations.AggregateRequest("*").scorer("blah")
1620-
1621-
16221614
@pytest.mark.redismod
16231615
@skip_if_redis_enterprise()
16241616
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,14 +1521,6 @@ async def test_aggregations_hybrid_scoring(client):
15211521
len(row) == 6
15221522

15231523

1524-
@pytest.mark.redismod
1525-
@skip_ifmodversion_lt("2.10.05", "search")
1526-
async def test_invalid_scorer():
1527-
1528-
with pytest.raises(ValueError):
1529-
aggregations.AggregateRequest("*").scorer("blah")
1530-
1531-
15321524
@pytest.mark.redismod
15331525
@skip_ifmodversion_lt("2.0.0", "search")
15341526
def test_index_definition(client):

0 commit comments

Comments
 (0)