Skip to content

Commit e0f6ed1

Browse files
committed
Add score unit test
1 parent 47bc62b commit e0f6ed1

File tree

2 files changed

+216
-11
lines changed

2 files changed

+216
-11
lines changed

django_mongodb_backend/expressions/search.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def search_operator(self, compiler, connection):
260260
"path": self.path.as_mql(compiler, connection, as_path=True),
261261
}
262262
if self.score is not None:
263-
params["score"] = self.score.definitions
263+
params["score"] = self.score.as_mql(compiler, connection)
264264
return {"exists": params}
265265

266266

@@ -601,7 +601,7 @@ def search_operator(self, compiler, connection):
601601
"query": self.query.value,
602602
}
603603
if self.score:
604-
params["score"] = self.score.query.as_mql(compiler, connection)
604+
params["score"] = self.score.as_mql(compiler, connection)
605605
if self.allow_analyzed_field is not None:
606606
params["allowAnalyzedField"] = self.allow_analyzed_field.value
607607
return {"wildcard": params}
@@ -991,10 +991,10 @@ class SearchScoreOption(Expression):
991991
"""Class to mutate scoring on a search operation"""
992992

993993
def __init__(self, definitions=None):
994-
self.definitions = definitions
994+
self._definitions = definitions
995995

996996
def as_mql(self, compiler, connection):
997-
return self.definitions
997+
return self._definitions
998998

999999

10001000
class SearchTextLookup(Lookup):

tests/queries_/test_search.py

Lines changed: 212 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
SearchPhrase,
2121
SearchRange,
2222
SearchRegex,
23+
SearchScoreOption,
2324
SearchText,
2425
SearchVector,
2526
SearchWildcard,
@@ -70,18 +71,16 @@ def inner_wait_loop(predicate: Callable):
7071
class SearchUtilsMixin(TransactionTestCase):
7172
available_apps = []
7273

73-
@staticmethod
74-
def _get_collection(model):
74+
def _get_collection(self, model):
7575
return connection.database.get_collection(model._meta.db_table)
7676

77-
@staticmethod
78-
def create_search_index(model, index_name, definition, type="search"):
79-
collection = SearchUtilsMixin._get_collection(model)
77+
def create_search_index(self, model, index_name, definition, type="search"):
78+
collection = self._get_collection(model)
8079
idx = SearchIndexModel(definition=definition, name=index_name, type=type)
8180
collection.create_search_index(idx)
8281

8382
def _tear_down(self, model):
84-
collection = SearchUtilsMixin._get_collection(model)
83+
collection = self._get_collection(model)
8584
for search_indexes in collection.list_search_indexes():
8685
collection.drop_search_index(search_indexes["name"])
8786
collection.delete_many({})
@@ -95,7 +94,12 @@ def setUp(self):
9594
self.create_search_index(
9695
Article,
9796
"equals_headline_index",
98-
{"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}},
97+
{
98+
"mappings": {
99+
"dynamic": False,
100+
"fields": {"headline": {"type": "token"}, "number": {"type": "number"}},
101+
}
102+
},
99103
)
100104
self.article = Article.objects.create(headline="cross", number=1, body="body")
101105
Article.objects.create(headline="other thing", number=2, body="body")
@@ -108,6 +112,44 @@ def test_search_equals(self):
108112
qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross"))
109113
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
110114

115+
def test_boost_score(self):
116+
boost_score = SearchScoreOption({"boost": {"value": 3}})
117+
118+
qs = Article.objects.annotate(
119+
score=SearchEquals(path="headline", value="cross", score=boost_score)
120+
)
121+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
122+
scored = qs.first()
123+
self.assertGreaterEqual(scored.score, 3.0)
124+
125+
def test_constant_score(self):
126+
constant_score = SearchScoreOption({"constant": {"value": 10}})
127+
qs = Article.objects.annotate(
128+
score=SearchEquals(path="headline", value="cross", score=constant_score)
129+
)
130+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
131+
scored = qs.first()
132+
self.assertAlmostEqual(scored.score, 10.0, places=2)
133+
134+
def test_function_score(self):
135+
function_score = SearchScoreOption(
136+
{
137+
"function": {
138+
"path": {
139+
"value": "number",
140+
"undefined": 0,
141+
},
142+
}
143+
}
144+
)
145+
146+
qs = Article.objects.annotate(
147+
score=SearchEquals(path="headline", value="cross", score=function_score)
148+
)
149+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
150+
scored = qs.first()
151+
self.assertAlmostEqual(scored.score, 1.0, places=2)
152+
111153

112154
@skipUnlessDBFeature("supports_atlas_search")
113155
class SearchAutocompleteTest(SearchUtilsMixin):
@@ -173,6 +215,21 @@ def test_search_autocomplete_embedded_model(self):
173215
)
174216
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
175217

218+
def test_constant_score(self):
219+
constant_score = SearchScoreOption({"constant": {"value": 10}})
220+
qs = Article.objects.annotate(
221+
score=SearchAutocomplete(
222+
path="headline",
223+
query="crossing",
224+
token_order="sequential", # noqa: S106
225+
fuzzy={"maxEdits": 2},
226+
score=constant_score,
227+
)
228+
)
229+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
230+
scored = qs.first()
231+
self.assertAlmostEqual(scored.score, 10.0, places=2)
232+
176233

177234
@skipUnlessDBFeature("supports_atlas_search")
178235
class SearchExistsTest(SearchUtilsMixin):
@@ -184,10 +241,21 @@ def setUp(self):
184241
)
185242
self.article = Article.objects.create(headline="ignored", number=3, body="something")
186243

244+
def tearDown(self):
245+
self._tear_down(Article)
246+
super().tearDown()
247+
187248
def test_search_exists(self):
188249
qs = Article.objects.annotate(score=SearchExists(path="body"))
189250
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
190251

252+
def test_constant_score(self):
253+
constant_score = SearchScoreOption({"constant": {"value": 10}})
254+
qs = Article.objects.annotate(score=SearchExists(path="body", score=constant_score))
255+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
256+
scored = qs.first()
257+
self.assertAlmostEqual(scored.score, 10.0, places=2)
258+
191259

192260
@skipUnlessDBFeature("supports_atlas_search")
193261
class SearchInTest(SearchUtilsMixin):
@@ -208,6 +276,15 @@ def test_search_in(self):
208276
qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"]))
209277
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
210278

279+
def test_constant_score(self):
280+
constant_score = SearchScoreOption({"constant": {"value": 10}})
281+
qs = Article.objects.annotate(
282+
score=SearchIn(path="headline", value=["cross", "river"], score=constant_score)
283+
)
284+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
285+
scored = qs.first()
286+
self.assertAlmostEqual(scored.score, 10.0, places=2)
287+
211288

212289
@skipUnlessDBFeature("supports_atlas_search")
213290
class SearchPhraseTest(SearchUtilsMixin):
@@ -230,6 +307,15 @@ def test_search_phrase(self):
230307
qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown"))
231308
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
232309

310+
def test_constant_score(self):
311+
constant_score = SearchScoreOption({"constant": {"value": 10}})
312+
qs = Article.objects.annotate(
313+
score=SearchPhrase(path="body", query="quick brown", score=constant_score)
314+
)
315+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
316+
scored = qs.first()
317+
self.assertAlmostEqual(scored.score, 10.0, places=2)
318+
233319

234320
@skipUnlessDBFeature("supports_atlas_search")
235321
class SearchRangeTest(SearchUtilsMixin):
@@ -250,6 +336,15 @@ def test_search_range(self):
250336
qs = Article.objects.annotate(score=SearchRange(path="number", gte=10, lt=30))
251337
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20]))
252338

339+
def test_constant_score(self):
340+
constant_score = SearchScoreOption({"constant": {"value": 10}})
341+
qs = Article.objects.annotate(
342+
score=SearchRange(path="number", gte=10, lt=30, score=constant_score)
343+
)
344+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20]))
345+
scored = qs.first()
346+
self.assertAlmostEqual(scored.score, 10.0, places=2)
347+
253348

254349
@skipUnlessDBFeature("supports_atlas_search")
255350
class SearchRegexTest(SearchUtilsMixin):
@@ -277,6 +372,17 @@ def test_search_regex(self):
277372
)
278373
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
279374

375+
def test_constant_score(self):
376+
constant_score = SearchScoreOption({"constant": {"value": 10}})
377+
qs = Article.objects.annotate(
378+
score=SearchRegex(
379+
path="headline", query="hello.*", allow_analyzed_field=True, score=constant_score
380+
)
381+
)
382+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
383+
scored = qs.first()
384+
self.assertAlmostEqual(scored.score, 10.0, places=2)
385+
280386

281387
@skipUnlessDBFeature("supports_atlas_search")
282388
class SearchTextTest(SearchUtilsMixin):
@@ -311,6 +417,21 @@ def test_search_text_with_fuzzy_and_criteria(self):
311417
)
312418
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
313419

420+
def test_constant_score(self):
421+
constant_score = SearchScoreOption({"constant": {"value": 10}})
422+
qs = Article.objects.annotate(
423+
score=SearchText(
424+
path="body",
425+
query="lazzy",
426+
fuzzy={"maxEdits": 2},
427+
match_criteria="all",
428+
score=constant_score,
429+
)
430+
)
431+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
432+
scored = qs.first()
433+
self.assertAlmostEqual(scored.score, 10.0, places=2)
434+
314435

315436
@skipUnlessDBFeature("supports_atlas_search")
316437
class SearchWildcardTest(SearchUtilsMixin):
@@ -336,6 +457,15 @@ def test_search_wildcard(self):
336457
qs = Article.objects.annotate(score=SearchWildcard(path="headline", query="dark-*"))
337458
self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs))
338459

460+
def test_constant_score(self):
461+
constant_score = SearchScoreOption({"constant": {"value": 10}})
462+
qs = Article.objects.annotate(
463+
score=SearchWildcard(path="headline", query="dark-*", score=constant_score)
464+
)
465+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
466+
scored = qs.first()
467+
self.assertAlmostEqual(scored.score, 10.0, places=2)
468+
339469

340470
@skipUnlessDBFeature("supports_atlas_search")
341471
class SearchGeoShapeTest(SearchUtilsMixin):
@@ -371,6 +501,21 @@ def test_search_geo_shape(self):
371501
)
372502
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
373503

504+
def test_constant_score(self):
505+
polygon = {
506+
"type": "Polygon",
507+
"coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]],
508+
}
509+
constant_score = SearchScoreOption({"constant": {"value": 10}})
510+
qs = Article.objects.annotate(
511+
score=SearchGeoShape(
512+
path="location", relation="within", geometry=polygon, score=constant_score
513+
)
514+
)
515+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
516+
scored = qs.first()
517+
self.assertAlmostEqual(scored.score, 10.0, places=2)
518+
374519

375520
@skipUnlessDBFeature("supports_atlas_search")
376521
class SearchGeoWithinTest(SearchUtilsMixin):
@@ -405,6 +550,24 @@ def test_search_geo_within(self):
405550
)
406551
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
407552

553+
def test_constant_score(self):
554+
polygon = {
555+
"type": "Polygon",
556+
"coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]],
557+
}
558+
constant_score = SearchScoreOption({"constant": {"value": 10}})
559+
qs = Article.objects.annotate(
560+
score=SearchGeoWithin(
561+
path="location",
562+
kind="geometry",
563+
geo_object=polygon,
564+
score=constant_score,
565+
)
566+
)
567+
self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article]))
568+
scored = qs.first()
569+
self.assertAlmostEqual(scored.score, 10.0, places=2)
570+
408571

409572
@skipUnlessDBFeature("supports_atlas_search")
410573
@unittest.expectedFailure
@@ -523,6 +686,48 @@ def test_operations(self):
523686
lambda: self.assertCountEqual(qs.all(), [self.mars_mission, self.exoplanet])
524687
)
525688

689+
def test_mixed_scores(self):
690+
boost_score = SearchScoreOption({"boost": {"value": 5}})
691+
constant_score = SearchScoreOption({"constant": {"value": 20}})
692+
function_score = SearchScoreOption(
693+
{"function": {"path": {"value": "number", "undefined": 0}}}
694+
)
695+
696+
must_expr = SearchEquals(path="headline", value="space exploration", score=boost_score)
697+
should_expr = SearchPhrase(path="body", query="exoplanets", score=constant_score)
698+
must_not_expr = SearchPhrase(path="body", query="icy moons", score=function_score)
699+
700+
compound = CompoundExpression(
701+
must=[must_expr],
702+
must_not=[must_not_expr],
703+
should=[should_expr],
704+
)
705+
qs = Article.objects.annotate(score=compound).order_by("-score")
706+
self.wait_for_assertion(
707+
lambda: self.assertListEqual(list(qs.all()), [self.exoplanet, self.mars_mission])
708+
)
709+
# Exoplanet should rank first because of the constant 20 bump.
710+
self.assertEqual(qs.first(), self.exoplanet)
711+
712+
def test_operationss_with_function_score(self):
713+
function_score = SearchScoreOption(
714+
{"function": {"path": {"value": "number", "undefined": 0}}}
715+
)
716+
717+
expr = SearchEquals(
718+
path="headline",
719+
value="space exploration",
720+
score=function_score,
721+
) & ~SearchEquals(path="number", value=3)
722+
723+
qs = Article.objects.annotate(score=expr).order_by("-score")
724+
725+
self.wait_for_assertion(
726+
lambda: self.assertListEqual(list(qs), [self.exoplanet, self.mars_mission])
727+
)
728+
# Returns mars_mission (score≈1) and exoplanet (score≈2) then; exoplanet first.
729+
self.assertEqual(qs.first(), self.exoplanet)
730+
526731
def test_multiple_search(self):
527732
msg = (
528733
"Only one $search operation is allowed per query. Received 2 search expressions. "

0 commit comments

Comments
 (0)