11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
-
14
+ import enum
15
15
import logging
16
16
import re
17
- from typing import TYPE_CHECKING , Any , Collection , Iterable , List , Optional , Set , Tuple
17
+ from collections import deque
18
+ from dataclasses import dataclass
19
+ from typing import (
20
+ TYPE_CHECKING ,
21
+ Any ,
22
+ Collection ,
23
+ Iterable ,
24
+ List ,
25
+ Optional ,
26
+ Set ,
27
+ Tuple ,
28
+ Union ,
29
+ )
18
30
19
31
import attr
20
32
27
39
LoggingTransaction ,
28
40
)
29
41
from synapse .storage .databases .main .events_worker import EventRedactBehaviour
30
- from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
42
+ from synapse .storage .engines import PostgresEngine , Sqlite3Engine
31
43
from synapse .types import JsonDict
32
44
33
45
if TYPE_CHECKING :
@@ -421,8 +433,6 @@ async def search_msgs(
421
433
"""
422
434
clauses = []
423
435
424
- search_query = _parse_query (self .database_engine , search_term )
425
-
426
436
args : List [Any ] = []
427
437
428
438
# Make sure we don't explode because the person is in too many rooms.
@@ -444,20 +454,24 @@ async def search_msgs(
444
454
count_clauses = clauses
445
455
446
456
if isinstance (self .database_engine , PostgresEngine ):
457
+ search_query = search_term
458
+ tsquery_func = self .database_engine .tsquery_func
447
459
sql = (
448
- "SELECT ts_rank_cd(vector, to_tsquery ('english', ?)) AS rank,"
460
+ f "SELECT ts_rank_cd(vector, { tsquery_func } ('english', ?)) AS rank,"
449
461
" room_id, event_id"
450
462
" FROM event_search"
451
- " WHERE vector @@ to_tsquery ('english', ?)"
463
+ f " WHERE vector @@ { tsquery_func } ('english', ?)"
452
464
)
453
465
args = [search_query , search_query ] + args
454
466
455
467
count_sql = (
456
468
"SELECT room_id, count(*) as count FROM event_search"
457
- " WHERE vector @@ to_tsquery ('english', ?)"
469
+ f " WHERE vector @@ { tsquery_func } ('english', ?)"
458
470
)
459
471
count_args = [search_query ] + count_args
460
472
elif isinstance (self .database_engine , Sqlite3Engine ):
473
+ search_query = _parse_query_for_sqlite (search_term )
474
+
461
475
sql = (
462
476
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
463
477
" FROM event_search"
@@ -469,7 +483,7 @@ async def search_msgs(
469
483
"SELECT room_id, count(*) as count FROM event_search"
470
484
" WHERE value MATCH ?"
471
485
)
472
- count_args = [search_term ] + count_args
486
+ count_args = [search_query ] + count_args
473
487
else :
474
488
# This should be unreachable.
475
489
raise Exception ("Unrecognized database engine" )
@@ -501,7 +515,9 @@ async def search_msgs(
501
515
502
516
highlights = None
503
517
if isinstance (self .database_engine , PostgresEngine ):
504
- highlights = await self ._find_highlights_in_postgres (search_query , events )
518
+ highlights = await self ._find_highlights_in_postgres (
519
+ search_query , events , tsquery_func
520
+ )
505
521
506
522
count_sql += " GROUP BY room_id"
507
523
@@ -510,7 +526,6 @@ async def search_msgs(
510
526
)
511
527
512
528
count = sum (row ["count" ] for row in count_results if row ["room_id" ] in room_ids )
513
-
514
529
return {
515
530
"results" : [
516
531
{"event" : event_map [r ["event_id" ]], "rank" : r ["rank" ]}
@@ -542,9 +557,6 @@ async def search_rooms(
542
557
Each match as a dictionary.
543
558
"""
544
559
clauses = []
545
-
546
- search_query = _parse_query (self .database_engine , search_term )
547
-
548
560
args : List [Any ] = []
549
561
550
562
# Make sure we don't explode because the person is in too many rooms.
@@ -582,20 +594,23 @@ async def search_rooms(
582
594
args .extend ([origin_server_ts , origin_server_ts , stream ])
583
595
584
596
if isinstance (self .database_engine , PostgresEngine ):
597
+ search_query = search_term
598
+ tsquery_func = self .database_engine .tsquery_func
585
599
sql = (
586
- "SELECT ts_rank_cd(vector, to_tsquery ('english', ?)) as rank,"
600
+ f "SELECT ts_rank_cd(vector, { tsquery_func } ('english', ?)) as rank,"
587
601
" origin_server_ts, stream_ordering, room_id, event_id"
588
602
" FROM event_search"
589
- " WHERE vector @@ to_tsquery ('english', ?) AND "
603
+ f " WHERE vector @@ { tsquery_func } ('english', ?) AND "
590
604
)
591
605
args = [search_query , search_query ] + args
592
606
593
607
count_sql = (
594
608
"SELECT room_id, count(*) as count FROM event_search"
595
- " WHERE vector @@ to_tsquery ('english', ?) AND "
609
+ f " WHERE vector @@ { tsquery_func } ('english', ?) AND "
596
610
)
597
611
count_args = [search_query ] + count_args
598
612
elif isinstance (self .database_engine , Sqlite3Engine ):
613
+
599
614
# We use CROSS JOIN here to ensure we use the right indexes.
600
615
# https://sqlite.org/optoverview.html#crossjoin
601
616
#
@@ -614,13 +629,14 @@ async def search_rooms(
614
629
" CROSS JOIN events USING (event_id)"
615
630
" WHERE "
616
631
)
632
+ search_query = _parse_query_for_sqlite (search_term )
617
633
args = [search_query ] + args
618
634
619
635
count_sql = (
620
636
"SELECT room_id, count(*) as count FROM event_search"
621
637
" WHERE value MATCH ? AND "
622
638
)
623
- count_args = [search_term ] + count_args
639
+ count_args = [search_query ] + count_args
624
640
else :
625
641
# This should be unreachable.
626
642
raise Exception ("Unrecognized database engine" )
@@ -660,7 +676,9 @@ async def search_rooms(
660
676
661
677
highlights = None
662
678
if isinstance (self .database_engine , PostgresEngine ):
663
- highlights = await self ._find_highlights_in_postgres (search_query , events )
679
+ highlights = await self ._find_highlights_in_postgres (
680
+ search_query , events , tsquery_func
681
+ )
664
682
665
683
count_sql += " GROUP BY room_id"
666
684
@@ -686,7 +704,7 @@ async def search_rooms(
686
704
}
687
705
688
706
async def _find_highlights_in_postgres (
689
- self , search_query : str , events : List [EventBase ]
707
+ self , search_query : str , events : List [EventBase ], tsquery_func : str
690
708
) -> Set [str ]:
691
709
"""Given a list of events and a search term, return a list of words
692
710
that match from the content of the event.
@@ -697,6 +715,7 @@ async def _find_highlights_in_postgres(
697
715
Args:
698
716
search_query
699
717
events: A list of events
718
+ tsquery_func: The tsquery_* function to use when making queries
700
719
701
720
Returns:
702
721
A set of strings.
@@ -729,7 +748,7 @@ def f(txn: LoggingTransaction) -> Set[str]:
729
748
while stop_sel in value :
730
749
stop_sel += ">"
731
750
732
- query = "SELECT ts_headline(?, to_tsquery ('english', ?), %s)" % (
751
+ query = f "SELECT ts_headline(?, { tsquery_func } ('english', ?), %s)" % (
733
752
_to_postgres_options (
734
753
{
735
754
"StartSel" : start_sel ,
@@ -760,20 +779,128 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
760
779
return "'%s'" % ("," .join ("%s=%s" % (k , v ) for k , v in options_dict .items ()),)
761
780
762
781
763
- def _parse_query (database_engine : BaseDatabaseEngine , search_term : str ) -> str :
764
- """Takes a plain unicode string from the user and converts it into a form
765
- that can be passed to database.
766
- We use this so that we can add prefix matching, which isn't something
767
- that is supported by default.
782
+ @dataclass
783
+ class Phrase :
784
+ phrase : List [str ]
785
+
786
+
787
+ class SearchToken (enum .Enum ):
788
+ Not = enum .auto ()
789
+ Or = enum .auto ()
790
+ And = enum .auto ()
791
+
792
+
793
+ Token = Union [str , Phrase , SearchToken ]
794
+ TokenList = List [Token ]
795
+
796
+
797
+ def _is_stop_word (word : str ) -> bool :
798
+ # TODO Pull these out of the dictionary:
799
+ # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
800
+ return word in {"the" , "a" , "you" , "me" , "and" , "but" }
801
+
802
+
803
+ def _tokenize_query (query : str ) -> TokenList :
804
+ """
805
+ Convert the user-supplied `query` into a TokenList, which can be translated into
806
+ some DB-specific syntax.
807
+
808
+ The following constructs are supported:
809
+
810
+ - phrase queries using "double quotes"
811
+ - case-insensitive `or` and `and` operators
812
+ - negation of a keyword via unary `-`
813
+ - unary hyphen to denote NOT e.g. 'include -exclude'
814
+
815
+ The following differs from websearch_to_tsquery:
816
+
817
+ - Stop words are not removed.
818
+ - Unclosed phrases are treated differently.
819
+
820
+ """
821
+ tokens : TokenList = []
822
+
823
+ # Find phrases.
824
+ in_phrase = False
825
+ parts = deque (query .split ('"' ))
826
+ for i , part in enumerate (parts ):
827
+ # The contents inside double quotes is treated as a phrase, a trailing
828
+ # double quote is not implied.
829
+ in_phrase = bool (i % 2 ) and i != (len (parts ) - 1 )
830
+
831
+ # Pull out the individual words, discarding any non-word characters.
832
+ words = deque (re .findall (r"([\w\-]+)" , part , re .UNICODE ))
833
+
834
+ # Phrases have simplified handling of words.
835
+ if in_phrase :
836
+ # Skip stop words.
837
+ phrase = [word for word in words if not _is_stop_word (word )]
838
+
839
+ # Consecutive words are implicitly ANDed together.
840
+ if tokens and tokens [- 1 ] not in (SearchToken .Not , SearchToken .Or ):
841
+ tokens .append (SearchToken .And )
842
+
843
+ # Add the phrase.
844
+ tokens .append (Phrase (phrase ))
845
+ continue
846
+
847
+ # Otherwise, not in a phrase.
848
+ while words :
849
+ word = words .popleft ()
850
+
851
+ if word .startswith ("-" ):
852
+ tokens .append (SearchToken .Not )
853
+
854
+ # If there's more word, put it back to be processed again.
855
+ word = word [1 :]
856
+ if word :
857
+ words .appendleft (word )
858
+ elif word .lower () == "or" :
859
+ tokens .append (SearchToken .Or )
860
+ else :
861
+ # Skip stop words.
862
+ if _is_stop_word (word ):
863
+ continue
864
+
865
+ # Consecutive words are implicitly ANDed together.
866
+ if tokens and tokens [- 1 ] not in (SearchToken .Not , SearchToken .Or ):
867
+ tokens .append (SearchToken .And )
868
+
869
+ # Add the search term.
870
+ tokens .append (word )
871
+
872
+ return tokens
873
+
874
+
875
+ def _tokens_to_sqlite_match_query (tokens : TokenList ) -> str :
876
+ """
877
+ Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
878
+ Assume sqlite was compiled with enhanced query syntax.
879
+
880
+ Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
768
881
"""
882
+ match_query = []
883
+ for token in tokens :
884
+ if isinstance (token , str ):
885
+ match_query .append (token )
886
+ elif isinstance (token , Phrase ):
887
+ match_query .append ('"' + " " .join (token .phrase ) + '"' )
888
+ elif token == SearchToken .Not :
889
+ # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
890
+ # term has already been added before this.
891
+ match_query .append (" NOT " )
892
+ elif token == SearchToken .Or :
893
+ match_query .append (" OR " )
894
+ elif token == SearchToken .And :
895
+ match_query .append (" AND " )
896
+ else :
897
+ raise ValueError (f"unknown token { token } " )
898
+
899
+ return "" .join (match_query )
769
900
770
- # Pull out the individual words, discarding any non-word characters.
771
- results = re .findall (r"([\w\-]+)" , search_term , re .UNICODE )
772
901
773
- if isinstance (database_engine , PostgresEngine ):
774
- return " & " .join (result + ":*" for result in results )
775
- elif isinstance (database_engine , Sqlite3Engine ):
776
- return " & " .join (result + "*" for result in results )
777
- else :
778
- # This should be unreachable.
779
- raise Exception ("Unrecognized database engine" )
902
+ def _parse_query_for_sqlite (search_term : str ) -> str :
903
+ """Takes a plain unicode string from the user and converts it into a form
904
+ that can be passed to sqllite's matchinfo().
905
+ """
906
+ return _tokens_to_sqlite_match_query (_tokenize_query (search_term ))
0 commit comments