Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tableauserverclient/server/endpoint/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,17 @@ def wrapper(self, *args, **kwargs):

class QuerysetEndpoint(Endpoint, Generic[T]):
@api(version="2.0")
def all(self, *args, **kwargs) -> QuerySet[T]:
def all(self, *args, page_size: Optional[int] = None, **kwargs) -> QuerySet[T]:
if args or kwargs:
raise ValueError(".all method takes no arguments.")
queryset = QuerySet(self)
queryset = QuerySet(self, page_size=page_size)
return queryset

@api(version="2.0")
def filter(self, *_, **kwargs) -> QuerySet[T]:
def filter(self, *_, page_size: Optional[int] = None, **kwargs) -> QuerySet[T]:
if _:
raise RuntimeError("Only keyword arguments accepted.")
queryset = QuerySet(self).filter(**kwargs)
queryset = QuerySet(self, page_size=page_size).filter(**kwargs)
return queryset

@api(version="2.0")
Expand Down
19 changes: 10 additions & 9 deletions tableauserverclient/server/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def to_camel_case(word: str) -> str:


class QuerySet(Iterable[T], Sized):
def __init__(self, model: "QuerysetEndpoint[T]") -> None:
def __init__(self, model: "QuerysetEndpoint[T]", page_size: Optional[int] = None) -> None:
self.model = model
self.request_options = RequestOptions()
self.request_options = RequestOptions(pagesize=page_size or 100)
self._result_cache: List[T] = []
self._pagination_item = PaginationItem()

Expand Down Expand Up @@ -134,12 +134,15 @@ def page_size(self: Self) -> int:
self._fetch_all()
return self._pagination_item.page_size

def filter(self: Self, *invalid, **kwargs) -> Self:
def filter(self: Self, *invalid, page_size: Optional[int] = None, **kwargs) -> Self:
if invalid:
raise RuntimeError("Only accepts keyword arguments.")
for kwarg_key, value in kwargs.items():
field_name, operator = self._parse_shorthand_filter(kwarg_key)
self.request_options.filter.add(Filter(field_name, operator, value))

if page_size:
self.request_options.pagesize = page_size
return self

def order_by(self: Self, *args) -> Self:
Expand All @@ -155,11 +158,8 @@ def paginate(self: Self, **kwargs) -> Self:
self.request_options.pagesize = kwargs["page_size"]
return self

def with_page_size(self: Self, value: int) -> Self:
self.request_options.pagesize = value
return self

def _parse_shorthand_filter(self: Self, key: str) -> Tuple[str, str]:
@staticmethod
def _parse_shorthand_filter(key: str) -> Tuple[str, str]:
tokens = key.split("__", 1)
if len(tokens) == 1:
operator = RequestOptions.Operator.Equals
Expand All @@ -173,7 +173,8 @@ def _parse_shorthand_filter(self: Self, key: str) -> Tuple[str, str]:
raise ValueError("Field name `{}` is not valid.".format(field))
return (field, operator)

def _parse_shorthand_sort(self: Self, key: str) -> Tuple[str, str]:
@staticmethod
def _parse_shorthand_sort(key: str) -> Tuple[str, str]:
direction = RequestOptions.Direction.Asc
if key.startswith("-"):
direction = RequestOptions.Direction.Desc
Expand Down
23 changes: 21 additions & 2 deletions test/test_request_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,29 @@ def test_filtering_parameters(self) -> None:
self.assertIn("type", query_params)
self.assertIn("tabloid", query_params["type"])

def test_queryset_pagesize(self) -> None:
def test_queryset_endpoint_pagesize_all(self) -> None:
for page_size in (1, 10, 100, 1000):
with self.subTest(page_size):
with requests_mock.mock() as m:
m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text())
queryset = self.server.views.all().with_page_size(page_size)
queryset = self.server.views.all(page_size=page_size)
assert queryset.request_options.pagesize == page_size
_ = list(queryset)

def test_queryset_endpoint_pagesize_filter(self) -> None:
for page_size in (1, 10, 100, 1000):
with self.subTest(page_size):
with requests_mock.mock() as m:
m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text())
queryset = self.server.views.filter(page_size=page_size)
assert queryset.request_options.pagesize == page_size
_ = list(queryset)

def test_queryset_pagesize_filter(self) -> None:
for page_size in (1, 10, 100, 1000):
with self.subTest(page_size):
with requests_mock.mock() as m:
m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text())
queryset = self.server.views.all().filter(page_size=page_size)
assert queryset.request_options.pagesize == page_size
_ = list(queryset)