diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 3db5542ae1..061e69c235 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -103,6 +103,7 @@ def __init__(self, query="*"): self._query = query self._aggregateplan = [] self._loadfields = [] + self._loadall = False self._limit = Limit() self._max = 0 self._with_schema = False @@ -116,9 +117,13 @@ def load(self, *fields): ### Parameters - - **fields**: One or more fields in the format of `@field` + - **fields**: If fields not specified, all the fields will be loaded. + Otherwise, fields should be given in the format of `@field`. """ - self._loadfields.extend(fields) + if fields: + self._loadfields.extend(fields) + else: + self._loadall = True return self def group_by(self, fields, *reducers): @@ -308,7 +313,10 @@ def build_args(self): if self._cursor: ret += self._cursor - if self._loadfields: + if self._loadall: + ret.append("LOAD") + ret.append("*") + elif self._loadfields: ret.append("LOAD") ret.append(str(len(self._loadfields))) ret.extend(self._loadfields) diff --git a/tests/test_search.py b/tests/test_search.py index 5b6a66009a..1a22b665a8 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1054,6 +1054,11 @@ def test_aggregations_load(client): res = client.ft().aggregate(req) assert res.rows[0] == ["t2", "world"] + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello", "t2", "world"] + @pytest.mark.redismod def test_aggregations_apply(client):