diff --git a/tableauserverclient/server/query.py b/tableauserverclient/server/query.py index feebc1a7e..801ad4a13 100644 --- a/tableauserverclient/server/query.py +++ b/tableauserverclient/server/query.py @@ -77,6 +77,7 @@ def __iter__(self: Self) -> Iterator[T]: for page in count(1): self.request_options.pagenumber = page self._result_cache = [] + self._pagination_item._page_number = None try: self._fetch_all() except ServerResponseError as e: @@ -85,6 +86,8 @@ def __iter__(self: Self) -> Iterator[T]: # up overrunning the total number of pages. Catch the # error and break out of the loop. raise StopIteration + if len(self._result_cache) == 0: + return yield from self._result_cache # If the length of the QuerySet is unknown, continue fetching until # the result cache is empty. @@ -139,6 +142,7 @@ def __getitem__(self, k): elif k in range(self.total_available): # Otherwise, check if k is even sensible to return self._result_cache = [] + self._pagination_item._page_number = None # Add one to k, otherwise it gets stuck at page boundaries, e.g. 100 self.request_options.pagenumber = max(1, math.ceil((k + 1) / size)) return self[k] @@ -150,7 +154,7 @@ def _fetch_all(self: Self) -> None: """ Retrieve the data and store result and pagination item in cache """ - if not self._result_cache: + if not self._result_cache and self._pagination_item._page_number is None: response = self.model.get(self.request_options) if isinstance(response, tuple): self._result_cache, self._pagination_item = response @@ -159,7 +163,7 @@ def _fetch_all(self: Self) -> None: self._pagination_item = PaginationItem() def __len__(self: Self) -> int: - return self.total_available or sys.maxsize + return sys.maxsize if self.total_available is None else self.total_available @property def total_available(self: Self) -> int: diff --git a/test/test_pager.py b/test/test_pager.py index c30352809..1836095bb 100644 --- a/test/test_pager.py +++ b/test/test_pager.py @@ -1,6 +1,7 @@ import contextlib import os import unittest +import xml.etree.ElementTree as ET import requests_mock @@ -122,3 +123,14 @@ def test_pager_view(self) -> None: m.get(self.server.views.baseurl, text=view_xml) for view in TSC.Pager(self.server.views): assert view.name is not None + + def test_queryset_no_matches(self) -> None: + elem = ET.Element("tsResponse", xmlns="http://tableau.com/api") + ET.SubElement(elem, "pagination", totalAvailable="0") + ET.SubElement(elem, "groups") + xml = ET.tostring(elem).decode("utf-8") + with requests_mock.mock() as m: + m.get(self.server.groups.baseurl, text=xml) + all_groups = self.server.groups.all() + groups = list(all_groups) + assert len(groups) == 0