Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""


class CustomViews(QuerysetEndpoint):
class CustomViews(QuerysetEndpoint[CustomViewItem]):
def __init__(self, parent_srv):
super(CustomViews, self).__init__(parent_srv)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Datasources(QuerysetEndpoint):
class Datasources(QuerysetEndpoint[DatasourceItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Datasources, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
27 changes: 21 additions & 6 deletions tableauserverclient/server/endpoint/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from tableauserverclient import datetime_helpers as datetime

import abc
from packaging.version import Version
from functools import wraps
from xml.etree.ElementTree import ParseError
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union

from tableauserverclient.models.pagination_item import PaginationItem
from tableauserverclient.server.request_options import RequestOptions

from .exceptions import (
ServerResponseError,
Expand Down Expand Up @@ -300,25 +304,36 @@ def wrapper(self, *args, **kwargs):
return _decorator


class QuerysetEndpoint(Endpoint):
T = TypeVar("T")


class QuerysetEndpoint(Endpoint, Generic[T]):
@api(version="2.0")
def all(self, *args, **kwargs):
def all(self, *args, **kwargs) -> QuerySet[T]:
if args or kwargs:
raise ValueError(".all method takes no arguments.")
queryset = QuerySet(self)
return queryset

@api(version="2.0")
def filter(self, *_, **kwargs) -> QuerySet:
def filter(self, *_, **kwargs) -> QuerySet[T]:
if _:
raise RuntimeError("Only keyword arguments accepted.")
queryset = QuerySet(self).filter(**kwargs)
return queryset

@api(version="2.0")
def order_by(self, *args, **kwargs):
def order_by(self, *args, **kwargs) -> QuerySet[T]:
if kwargs:
raise ValueError(".order_by does not accept keyword arguments.")
queryset = QuerySet(self).order_by(*args)
return queryset

@api(version="2.0")
def paginate(self, **kwargs):
def paginate(self, **kwargs) -> QuerySet[T]:
queryset = QuerySet(self).paginate(**kwargs)
return queryset

@abc.abstractmethod
def get(self, request_options: RequestOptions) -> Tuple[List[T], PaginationItem]:
raise NotImplementedError(f".get has not been implemented for {self.__class__.__qualname__}")
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/flow_runs_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..request_options import RequestOptions


class FlowRuns(QuerysetEndpoint):
class FlowRuns(QuerysetEndpoint[FlowRunItem]):
def __init__(self, parent_srv: "Server") -> None:
super(FlowRuns, self).__init__(parent_srv)
return None
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/flows_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Flows(QuerysetEndpoint):
class Flows(QuerysetEndpoint[FlowItem]):
def __init__(self, parent_srv):
super(Flows, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/groups_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..request_options import RequestOptions


class Groups(QuerysetEndpoint):
class Groups(QuerysetEndpoint[GroupItem]):
@property
def baseurl(self) -> str:
return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/jobs_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import List, Optional, Tuple, Union


class Jobs(QuerysetEndpoint):
class Jobs(QuerysetEndpoint[JobItem]):
@property
def baseurl(self):
return "{0}/sites/{1}/jobs".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/metrics_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tableauserverclient.helpers.logging import logger


class Metrics(QuerysetEndpoint):
class Metrics(QuerysetEndpoint[MetricItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Metrics, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/projects_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tableauserverclient.helpers.logging import logger


class Projects(QuerysetEndpoint):
class Projects(QuerysetEndpoint[ProjectItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Projects, self).__init__(parent_srv)

Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/users_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tableauserverclient.helpers.logging import logger


class Users(QuerysetEndpoint):
class Users(QuerysetEndpoint[UserItem]):
@property
def baseurl(self) -> str:
return "{0}/sites/{1}/users".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/views_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)


class Views(QuerysetEndpoint):
class Views(QuerysetEndpoint[ViewItem]):
def __init__(self, parent_srv):
super(Views, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/workbooks_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Workbooks(QuerysetEndpoint):
class Workbooks(QuerysetEndpoint[WorkbookItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Workbooks, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
91 changes: 59 additions & 32 deletions tableauserverclient/server/query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from typing import Tuple
from .filter import Filter
from .request_options import RequestOptions
from .sort import Sort
from collections.abc import Sized
from itertools import count
from typing import Iterable, Iterator, List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, overload
from tableauserverclient.models.pagination_item import PaginationItem
from tableauserverclient.server.filter import Filter
from tableauserverclient.server.request_options import RequestOptions
from tableauserverclient.server.sort import Sort
import math

from typing_extensions import Self

if TYPE_CHECKING:
from tableauserverclient.server.endpoint import QuerysetEndpoint

T = TypeVar("T")


class Slice(Protocol):
start: Optional[int]
step: Optional[int]
stop: Optional[int]


def to_camel_case(word: str) -> str:
return word.split("_")[0] + "".join(x.capitalize() or "_" for x in word.split("_")[1:])
Expand All @@ -16,28 +32,35 @@ def to_camel_case(word: str) -> str:
"""


class QuerySet:
def __init__(self, model):
class QuerySet(Iterable[T], Sized):
def __init__(self, model: "QuerysetEndpoint[T]") -> None:
self.model = model
self.request_options = RequestOptions()
self._result_cache = None
self._pagination_item = None
self._result_cache: List[T] = []
self._pagination_item = PaginationItem()

def __iter__(self):
def __iter__(self: Self) -> Iterator[T]:
# Not built to be re-entrant. Starts back at page 1, and empties
# the result cache.
self.request_options.pagenumber = 1
self._result_cache = None
total = self.total_available
size = self.page_size
yield from self._result_cache
# the result cache. Ensure the result_cache is empty to not yield
# items from prior usage.
self._result_cache = []

# Loop through the subsequent pages.
for page in range(1, math.ceil(total / size)):
self.request_options.pagenumber = page + 1
self._result_cache = None
for page in count(1):
self.request_options.pagenumber = page
self._result_cache = []
self._fetch_all()
yield from self._result_cache
# Set result_cache to empty so the fetch will populate
if (page * self.page_size) >= len(self):
return

@overload
def __getitem__(self: Self, k: Slice) -> List[T]:
...

@overload
def __getitem__(self: Self, k: int) -> T:
...

def __getitem__(self, k):
page = self.page_number
Expand Down Expand Up @@ -78,61 +101,65 @@ def __getitem__(self, k):
return self._result_cache[k % size]
elif k in range(self.total_available):
# Otherwise, check if k is even sensible to return
self._result_cache = None
self._result_cache = []
# Add one to k, otherwise it gets stuck at page boundaries, e.g. 100
self.request_options.pagenumber = max(1, math.ceil((k + 1) / size))
return self[k]
else:
# If k is unreasonable, raise an IndexError.
raise IndexError

def _fetch_all(self):
def _fetch_all(self: Self) -> None:
"""
Retrieve the data and store result and pagination item in cache
"""
if self._result_cache is None:
if not self._result_cache:
self._result_cache, self._pagination_item = self.model.get(self.request_options)

def __len__(self) -> int:
def __len__(self: Self) -> int:
return self.total_available

@property
def total_available(self) -> int:
def total_available(self: Self) -> int:
self._fetch_all()
return self._pagination_item.total_available

@property
def page_number(self) -> int:
def page_number(self: Self) -> int:
self._fetch_all()
return self._pagination_item.page_number

@property
def page_size(self) -> int:
def page_size(self: Self) -> int:
self._fetch_all()
return self._pagination_item.page_size

def filter(self, *invalid, **kwargs):
def filter(self: Self, *invalid, **kwargs) -> Self:
if invalid:
raise RuntimeError(f"Only accepts keyword arguments.")
raise RuntimeError("Only accepts keyword arguments.")
for kwarg_key, value in kwargs.items():
field_name, operator = self._parse_shorthand_filter(kwarg_key)
self.request_options.filter.add(Filter(field_name, operator, value))
return self

def order_by(self, *args):
def order_by(self: Self, *args) -> Self:
for arg in args:
field_name, direction = self._parse_shorthand_sort(arg)
self.request_options.sort.add(Sort(field_name, direction))
return self

def paginate(self, **kwargs):
def paginate(self: Self, **kwargs) -> Self:
if "page_number" in kwargs:
self.request_options.pagenumber = kwargs["page_number"]
if "page_size" in kwargs:
self.request_options.pagesize = kwargs["page_size"]
return self

def _parse_shorthand_filter(self, key: str) -> Tuple[str, str]:
def with_page_size(self: Self, value: int) -> Self:
self.request_options.pagesize = value
return self

def _parse_shorthand_filter(self: Self, key: str) -> Tuple[str, str]:
tokens = key.split("__", 1)
if len(tokens) == 1:
operator = RequestOptions.Operator.Equals
Expand All @@ -146,7 +173,7 @@ def _parse_shorthand_filter(self, key: str) -> Tuple[str, str]:
raise ValueError("Field name `{}` is not valid.".format(field))
return (field, operator)

def _parse_shorthand_sort(self, key: str) -> Tuple[str, str]:
def _parse_shorthand_sort(self: Self, key: str) -> Tuple[str, str]:
direction = RequestOptions.Direction.Asc
if key.startswith("-"):
direction = RequestOptions.Direction.Desc
Expand Down
8 changes: 8 additions & 0 deletions test/test_request_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,11 @@ def test_filtering_parameters(self) -> None:
self.assertIn("value2", query_params["name2$"])
self.assertIn("type", query_params)
self.assertIn("tabloid", query_params["type"])

def test_queryset_pagesize(self) -> None:
for page_size in (1, 10, 100, 1000):
with self.subTest(page_size):
with requests_mock.mock() as m:
m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text())
queryset = self.server.views.all().with_page_size(page_size)
_ = list(queryset)