diff --git a/tableauserverclient/server/endpoint/custom_views_endpoint.py b/tableauserverclient/server/endpoint/custom_views_endpoint.py index 119580609..d1446b1fe 100644 --- a/tableauserverclient/server/endpoint/custom_views_endpoint.py +++ b/tableauserverclient/server/endpoint/custom_views_endpoint.py @@ -17,7 +17,7 @@ """ -class CustomViews(QuerysetEndpoint): +class CustomViews(QuerysetEndpoint[CustomViewItem]): def __init__(self, parent_srv): super(CustomViews, self).__init__(parent_srv) diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index 28226d280..da2ee3def 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -54,7 +54,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Datasources(QuerysetEndpoint): +class Datasources(QuerysetEndpoint[DatasourceItem]): def __init__(self, parent_srv: "Server") -> None: super(Datasources, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 2b7f57069..d9dac47b2 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -1,9 +1,13 @@ from tableauserverclient import datetime_helpers as datetime +import abc from packaging.version import Version from functools import wraps from xml.etree.ElementTree import ParseError -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union + +from tableauserverclient.models.pagination_item import PaginationItem +from tableauserverclient.server.request_options import RequestOptions from .exceptions import ( ServerResponseError, @@ -300,25 +304,36 @@ def wrapper(self, *args, **kwargs): return _decorator -class QuerysetEndpoint(Endpoint): +T = TypeVar("T") + + +class QuerysetEndpoint(Endpoint, Generic[T]): @api(version="2.0") - def all(self, *args, **kwargs): + def all(self, *args, **kwargs) -> QuerySet[T]: + if args or kwargs: + raise ValueError(".all method takes no arguments.") queryset = QuerySet(self) return queryset @api(version="2.0") - def filter(self, *_, **kwargs) -> QuerySet: + def filter(self, *_, **kwargs) -> QuerySet[T]: if _: raise RuntimeError("Only keyword arguments accepted.") queryset = QuerySet(self).filter(**kwargs) return queryset @api(version="2.0") - def order_by(self, *args, **kwargs): + def order_by(self, *args, **kwargs) -> QuerySet[T]: + if kwargs: + raise ValueError(".order_by does not accept keyword arguments.") queryset = QuerySet(self).order_by(*args) return queryset @api(version="2.0") - def paginate(self, **kwargs): + def paginate(self, **kwargs) -> QuerySet[T]: queryset = QuerySet(self).paginate(**kwargs) return queryset + + @abc.abstractmethod + def get(self, request_options: RequestOptions) -> Tuple[List[T], PaginationItem]: + raise NotImplementedError(f".get has not been implemented for {self.__class__.__qualname__}") diff --git a/tableauserverclient/server/endpoint/flow_runs_endpoint.py b/tableauserverclient/server/endpoint/flow_runs_endpoint.py index 63b32e006..ea45ce802 100644 --- a/tableauserverclient/server/endpoint/flow_runs_endpoint.py +++ b/tableauserverclient/server/endpoint/flow_runs_endpoint.py @@ -13,7 +13,7 @@ from ..request_options import RequestOptions -class FlowRuns(QuerysetEndpoint): +class FlowRuns(QuerysetEndpoint[FlowRunItem]): def __init__(self, parent_srv: "Server") -> None: super(FlowRuns, self).__init__(parent_srv) return None diff --git a/tableauserverclient/server/endpoint/flows_endpoint.py b/tableauserverclient/server/endpoint/flows_endpoint.py index 77b01c478..e392d807d 100644 --- a/tableauserverclient/server/endpoint/flows_endpoint.py +++ b/tableauserverclient/server/endpoint/flows_endpoint.py @@ -50,7 +50,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Flows(QuerysetEndpoint): +class Flows(QuerysetEndpoint[FlowItem]): def __init__(self, parent_srv): super(Flows, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/groups_endpoint.py b/tableauserverclient/server/endpoint/groups_endpoint.py index ab5f672d1..caa928f88 100644 --- a/tableauserverclient/server/endpoint/groups_endpoint.py +++ b/tableauserverclient/server/endpoint/groups_endpoint.py @@ -14,7 +14,7 @@ from ..request_options import RequestOptions -class Groups(QuerysetEndpoint): +class Groups(QuerysetEndpoint[GroupItem]): @property def baseurl(self) -> str: return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id) diff --git a/tableauserverclient/server/endpoint/jobs_endpoint.py b/tableauserverclient/server/endpoint/jobs_endpoint.py index d0b865e21..74770e22b 100644 --- a/tableauserverclient/server/endpoint/jobs_endpoint.py +++ b/tableauserverclient/server/endpoint/jobs_endpoint.py @@ -11,7 +11,7 @@ from typing import List, Optional, Tuple, Union -class Jobs(QuerysetEndpoint): +class Jobs(QuerysetEndpoint[JobItem]): @property def baseurl(self): return "{0}/sites/{1}/jobs".format(self.parent_srv.baseurl, self.parent_srv.site_id) diff --git a/tableauserverclient/server/endpoint/metrics_endpoint.py b/tableauserverclient/server/endpoint/metrics_endpoint.py index a0e984475..ab1ec5852 100644 --- a/tableauserverclient/server/endpoint/metrics_endpoint.py +++ b/tableauserverclient/server/endpoint/metrics_endpoint.py @@ -18,7 +18,7 @@ from tableauserverclient.helpers.logging import logger -class Metrics(QuerysetEndpoint): +class Metrics(QuerysetEndpoint[MetricItem]): def __init__(self, parent_srv: "Server") -> None: super(Metrics, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/projects_endpoint.py b/tableauserverclient/server/endpoint/projects_endpoint.py index 99bb2e39b..7645e72eb 100644 --- a/tableauserverclient/server/endpoint/projects_endpoint.py +++ b/tableauserverclient/server/endpoint/projects_endpoint.py @@ -16,7 +16,7 @@ from tableauserverclient.helpers.logging import logger -class Projects(QuerysetEndpoint): +class Projects(QuerysetEndpoint[ProjectItem]): def __init__(self, parent_srv: "Server") -> None: super(Projects, self).__init__(parent_srv) diff --git a/tableauserverclient/server/endpoint/users_endpoint.py b/tableauserverclient/server/endpoint/users_endpoint.py index e8c5cc962..a84ca7399 100644 --- a/tableauserverclient/server/endpoint/users_endpoint.py +++ b/tableauserverclient/server/endpoint/users_endpoint.py @@ -11,7 +11,7 @@ from tableauserverclient.helpers.logging import logger -class Users(QuerysetEndpoint): +class Users(QuerysetEndpoint[UserItem]): @property def baseurl(self) -> str: return "{0}/sites/{1}/users".format(self.parent_srv.baseurl, self.parent_srv.site_id) diff --git a/tableauserverclient/server/endpoint/views_endpoint.py b/tableauserverclient/server/endpoint/views_endpoint.py index 9c4b90657..87a77053f 100644 --- a/tableauserverclient/server/endpoint/views_endpoint.py +++ b/tableauserverclient/server/endpoint/views_endpoint.py @@ -21,7 +21,7 @@ ) -class Views(QuerysetEndpoint): +class Views(QuerysetEndpoint[ViewItem]): def __init__(self, parent_srv): super(Views, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index bc535b2d6..5b4b29969 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -56,7 +56,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Workbooks(QuerysetEndpoint): +class Workbooks(QuerysetEndpoint[WorkbookItem]): def __init__(self, parent_srv: "Server") -> None: super(Workbooks, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/query.py b/tableauserverclient/server/query.py index c5613b2d6..51c34d082 100644 --- a/tableauserverclient/server/query.py +++ b/tableauserverclient/server/query.py @@ -1,9 +1,25 @@ -from typing import Tuple -from .filter import Filter -from .request_options import RequestOptions -from .sort import Sort +from collections.abc import Sized +from itertools import count +from typing import Iterable, Iterator, List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, overload +from tableauserverclient.models.pagination_item import PaginationItem +from tableauserverclient.server.filter import Filter +from tableauserverclient.server.request_options import RequestOptions +from tableauserverclient.server.sort import Sort import math +from typing_extensions import Self + +if TYPE_CHECKING: + from tableauserverclient.server.endpoint import QuerysetEndpoint + +T = TypeVar("T") + + +class Slice(Protocol): + start: Optional[int] + step: Optional[int] + stop: Optional[int] + def to_camel_case(word: str) -> str: return word.split("_")[0] + "".join(x.capitalize() or "_" for x in word.split("_")[1:]) @@ -16,28 +32,35 @@ def to_camel_case(word: str) -> str: """ -class QuerySet: - def __init__(self, model): +class QuerySet(Iterable[T], Sized): + def __init__(self, model: "QuerysetEndpoint[T]") -> None: self.model = model self.request_options = RequestOptions() - self._result_cache = None - self._pagination_item = None + self._result_cache: List[T] = [] + self._pagination_item = PaginationItem() - def __iter__(self): + def __iter__(self: Self) -> Iterator[T]: # Not built to be re-entrant. Starts back at page 1, and empties - # the result cache. - self.request_options.pagenumber = 1 - self._result_cache = None - total = self.total_available - size = self.page_size - yield from self._result_cache + # the result cache. Ensure the result_cache is empty to not yield + # items from prior usage. + self._result_cache = [] - # Loop through the subsequent pages. - for page in range(1, math.ceil(total / size)): - self.request_options.pagenumber = page + 1 - self._result_cache = None + for page in count(1): + self.request_options.pagenumber = page + self._result_cache = [] self._fetch_all() yield from self._result_cache + # Set result_cache to empty so the fetch will populate + if (page * self.page_size) >= len(self): + return + + @overload + def __getitem__(self: Self, k: Slice) -> List[T]: + ... + + @overload + def __getitem__(self: Self, k: int) -> T: + ... def __getitem__(self, k): page = self.page_number @@ -78,7 +101,7 @@ def __getitem__(self, k): return self._result_cache[k % size] elif k in range(self.total_available): # Otherwise, check if k is even sensible to return - self._result_cache = None + self._result_cache = [] # Add one to k, otherwise it gets stuck at page boundaries, e.g. 100 self.request_options.pagenumber = max(1, math.ceil((k + 1) / size)) return self[k] @@ -86,53 +109,57 @@ def __getitem__(self, k): # If k is unreasonable, raise an IndexError. raise IndexError - def _fetch_all(self): + def _fetch_all(self: Self) -> None: """ Retrieve the data and store result and pagination item in cache """ - if self._result_cache is None: + if not self._result_cache: self._result_cache, self._pagination_item = self.model.get(self.request_options) - def __len__(self) -> int: + def __len__(self: Self) -> int: return self.total_available @property - def total_available(self) -> int: + def total_available(self: Self) -> int: self._fetch_all() return self._pagination_item.total_available @property - def page_number(self) -> int: + def page_number(self: Self) -> int: self._fetch_all() return self._pagination_item.page_number @property - def page_size(self) -> int: + def page_size(self: Self) -> int: self._fetch_all() return self._pagination_item.page_size - def filter(self, *invalid, **kwargs): + def filter(self: Self, *invalid, **kwargs) -> Self: if invalid: - raise RuntimeError(f"Only accepts keyword arguments.") + raise RuntimeError("Only accepts keyword arguments.") for kwarg_key, value in kwargs.items(): field_name, operator = self._parse_shorthand_filter(kwarg_key) self.request_options.filter.add(Filter(field_name, operator, value)) return self - def order_by(self, *args): + def order_by(self: Self, *args) -> Self: for arg in args: field_name, direction = self._parse_shorthand_sort(arg) self.request_options.sort.add(Sort(field_name, direction)) return self - def paginate(self, **kwargs): + def paginate(self: Self, **kwargs) -> Self: if "page_number" in kwargs: self.request_options.pagenumber = kwargs["page_number"] if "page_size" in kwargs: self.request_options.pagesize = kwargs["page_size"] return self - def _parse_shorthand_filter(self, key: str) -> Tuple[str, str]: + def with_page_size(self: Self, value: int) -> Self: + self.request_options.pagesize = value + return self + + def _parse_shorthand_filter(self: Self, key: str) -> Tuple[str, str]: tokens = key.split("__", 1) if len(tokens) == 1: operator = RequestOptions.Operator.Equals @@ -146,7 +173,7 @@ def _parse_shorthand_filter(self, key: str) -> Tuple[str, str]: raise ValueError("Field name `{}` is not valid.".format(field)) return (field, operator) - def _parse_shorthand_sort(self, key: str) -> Tuple[str, str]: + def _parse_shorthand_sort(self: Self, key: str) -> Tuple[str, str]: direction = RequestOptions.Direction.Asc if key.startswith("-"): direction = RequestOptions.Direction.Desc diff --git a/test/test_request_option.py b/test/test_request_option.py index 40dd3345a..5ade81ea1 100644 --- a/test/test_request_option.py +++ b/test/test_request_option.py @@ -331,3 +331,11 @@ def test_filtering_parameters(self) -> None: self.assertIn("value2", query_params["name2$"]) self.assertIn("type", query_params) self.assertIn("tabloid", query_params["type"]) + + def test_queryset_pagesize(self) -> None: + for page_size in (1, 10, 100, 1000): + with self.subTest(page_size): + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text()) + queryset = self.server.views.all().with_page_size(page_size) + _ = list(queryset)