Skip to content

Commit 6ba4641

Browse files
authored
Support CASESENSITIVE for TAG fields (#2112)
* Support CASESENSITIVE for TAG fields * add wait fot index + update all the callings to use getattr() instead of the string "idx"
1 parent e5e265d commit 6ba4641

File tree

2 files changed

+49
-15
lines changed

2 files changed

+49
-15
lines changed

redis/commands/search/field.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,20 @@ class TagField(Field):
105105
"""
106106

107107
SEPARATOR = "SEPARATOR"
108+
CASESENSITIVE = "CASESENSITIVE"
108109

109-
def __init__(self, name: str, separator: str = ",", **kwargs):
110-
Field.__init__(
111-
self, name, args=[Field.TAG, self.SEPARATOR, separator], **kwargs
112-
)
110+
def __init__(
111+
self,
112+
name: str,
113+
separator: str = ",",
114+
case_sensitive: bool = False,
115+
**kwargs,
116+
):
117+
args = [Field.TAG, self.SEPARATOR, separator]
118+
if case_sensitive:
119+
args.append(self.CASESENSITIVE)
120+
121+
Field.__init__(self, name, args=args, **kwargs)
113122

114123

115124
class VectorField(Field):

tests/test_search.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def client(modclient):
107107
def test_client(client):
108108
num_docs = 500
109109
createIndex(client.ft(), num_docs=num_docs)
110-
waitForIndex(client, "idx")
110+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
111111
# verify info
112112
info = client.ft().info()
113113
for k in [
@@ -252,7 +252,7 @@ def test_replace(client):
252252

253253
client.ft().add_document("doc1", txt="foo bar")
254254
client.ft().add_document("doc2", txt="foo bar")
255-
waitForIndex(client, "idx")
255+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
256256

257257
res = client.ft().search("foo bar")
258258
assert 2 == res.total
@@ -272,7 +272,7 @@ def test_stopwords(client):
272272
client.ft().create_index((TextField("txt"),), stopwords=["foo", "bar", "baz"])
273273
client.ft().add_document("doc1", txt="foo bar")
274274
client.ft().add_document("doc2", txt="hello world")
275-
waitForIndex(client, "idx")
275+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
276276

277277
q1 = Query("foo bar").no_content()
278278
q2 = Query("foo bar hello world").no_content()
@@ -287,7 +287,7 @@ def test_filters(client):
287287
client.ft().add_document("doc1", txt="foo bar", num=3.141, loc="-0.441,51.458")
288288
client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2")
289289

290-
waitForIndex(client, "idx")
290+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
291291
# Test numerical filter
292292
q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content()
293293
q2 = (
@@ -456,7 +456,7 @@ def test_no_index(client):
456456
client.ft().add_document(
457457
"doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2"
458458
)
459-
waitForIndex(client, "idx")
459+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
460460

461461
res = client.ft().search(Query("@text:aa*"))
462462
assert 0 == res.total
@@ -498,7 +498,7 @@ def test_partial(client):
498498
client.ft().add_document("doc2", f1="f1_val", f2="f2_val")
499499
client.ft().add_document("doc1", f3="f3_val", partial=True)
500500
client.ft().add_document("doc2", f3="f3_val", replace=True)
501-
waitForIndex(client, "idx")
501+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
502502

503503
# Search for f3 value. All documents should have it
504504
res = client.ft().search("@f3:f3_val")
@@ -516,7 +516,7 @@ def test_no_create(client):
516516
client.ft().add_document("doc2", f1="f1_val", f2="f2_val")
517517
client.ft().add_document("doc1", f3="f3_val", no_create=True)
518518
client.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True)
519-
waitForIndex(client, "idx")
519+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
520520

521521
# Search for f3 value. All documents should have it
522522
res = client.ft().search("@f3:f3_val")
@@ -546,7 +546,7 @@ def test_explaincli(client):
546546
@pytest.mark.redismod
547547
def test_summarize(client):
548548
createIndex(client.ft())
549-
waitForIndex(client, "idx")
549+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
550550

551551
q = Query("king henry").paging(0, 1)
552552
q.highlight(fields=("play", "txt"), tags=("<b>", "</b>"))
@@ -654,7 +654,7 @@ def test_tags(client):
654654

655655
client.ft().add_document("doc1", txt="fooz barz", tags=tags)
656656
client.ft().add_document("doc2", txt="noodles", tags=tags2)
657-
waitForIndex(client, "idx")
657+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
658658

659659
q = Query("@tags:{foo}")
660660
res = client.ft().search(q)
@@ -714,7 +714,7 @@ def test_spell_check(client):
714714

715715
client.ft().add_document("doc1", f1="some valid content", f2="this is sample text")
716716
client.ft().add_document("doc2", f1="very important", f2="lorem ipsum")
717-
waitForIndex(client, "idx")
717+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
718718

719719
# test spellcheck
720720
res = client.ft().spellcheck("impornant")
@@ -1304,6 +1304,31 @@ def test_fields_as_name(client):
13041304
assert "25" == total[0].just_a_number
13051305

13061306

1307+
@pytest.mark.redismod
1308+
def test_casesensitive(client):
1309+
# create index
1310+
SCHEMA = (TagField("t", case_sensitive=False),)
1311+
client.ft().create_index(SCHEMA)
1312+
client.ft().client.hset("1", "t", "HELLO")
1313+
client.ft().client.hset("2", "t", "hello")
1314+
1315+
res = client.ft().search("@t:{HELLO}").docs
1316+
1317+
assert 2 == len(res)
1318+
assert "1" == res[0].id
1319+
assert "2" == res[1].id
1320+
1321+
# create casesensitive index
1322+
client.ft().dropindex()
1323+
SCHEMA = (TagField("t", case_sensitive=True),)
1324+
client.ft().create_index(SCHEMA)
1325+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
1326+
1327+
res = client.ft().search("@t:{HELLO}").docs
1328+
assert 1 == len(res)
1329+
assert "1" == res[0].id
1330+
1331+
13071332
@pytest.mark.redismod
13081333
@skip_ifmodversion_lt("2.2.0", "search")
13091334
def test_search_return_fields(client):
@@ -1321,7 +1346,7 @@ def test_search_return_fields(client):
13211346
NumericField("$.flt"),
13221347
)
13231348
client.ft().create_index(SCHEMA, definition=definition)
1324-
waitForIndex(client, "idx")
1349+
waitForIndex(client, getattr(client.ft(), "index_name", "idx"))
13251350

13261351
total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs
13271352
assert 1 == len(total)

0 commit comments

Comments
 (0)