diff --git a/docs/python-api.rst b/docs/python-api.rst index 4b446bb7..2333d7f4 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -2308,6 +2308,9 @@ The ``.search()`` method also accepts the following optional parameters: ``where_args`` dictionary Arguments to use for ``:param`` placeholders in the extra WHERE clause +``include_rank`` bool + If set, a ``rank`` column will be included with the BM25 ranking score - for FTS5 tables only. + ``quote`` bool Apply :ref:`FTS quoting rules ` to the search query, disabling advanced query syntax in a way that avoids surprising errors. diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index 6332dc26..27099d0c 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -2641,6 +2641,7 @@ def search( offset: Optional[int] = None, where: Optional[str] = None, where_args: Optional[Union[Iterable, dict]] = None, + include_rank: bool = False, quote: bool = False, ) -> Generator[dict, None, None]: """ @@ -2654,6 +2655,7 @@ def search( :param offset: Optional integer SQL offset. :param where: Extra SQL fragment for the WHERE clause :param where_args: Arguments to use for :param placeholders in the extra WHERE clause + :param include_rank: Select the search rank column in the final query :param quote: Apply quoting to disable any special characters in the search query See :ref:`python_api_fts_search`. 
@@ -2673,6 +2675,7 @@ def search( limit=limit, offset=offset, where=where, + include_rank=include_rank, ), args, ) diff --git a/tests/test_fts.py b/tests/test_fts.py index ecfcff9d..a35b9eef 100644 --- a/tests/test_fts.py +++ b/tests/test_fts.py @@ -1,6 +1,7 @@ import pytest from sqlite_utils import Database from sqlite_utils.utils import sqlite3 +from unittest.mock import ANY search_records = [ { @@ -126,6 +127,32 @@ def test_search_where_args_disallows_query(fresh_db): ) +def test_search_include_rank(fresh_db): + table = fresh_db["t"] + table.insert_all(search_records) + table.enable_fts(["text", "country"], fts_version="FTS5") + results = list(table.search("are", include_rank=True)) + assert results == [ + { + "rowid": 1, + "text": "tanuki are running tricksters", + "country": "Japan", + "not_searchable": "foo", + "rank": ANY, + }, + { + "rowid": 2, + "text": "racoons are biting trash pandas", + "country": "USA", + "not_searchable": "bar", + "rank": ANY, + }, + ] + assert isinstance(results[0]["rank"], float) + assert isinstance(results[1]["rank"], float) + assert results[0]["rank"] < results[1]["rank"] + + def test_enable_fts_table_names_containing_spaces(fresh_db): table = fresh_db["test"] table.insert({"column with spaces": "in its name"})