diff --git a/tableauserverclient/server/pager.py b/tableauserverclient/server/pager.py index fede56012..ca9d83872 100644 --- a/tableauserverclient/server/pager.py +++ b/tableauserverclient/server/pager.py @@ -1,24 +1,23 @@ import copy from functools import partial -from typing import Generic, Iterable, Iterator, List, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable +from typing import Iterable, Iterator, List, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable from tableauserverclient.models.pagination_item import PaginationItem from tableauserverclient.server.request_options import RequestOptions T = TypeVar("T") -ReturnType = Tuple[List[T], PaginationItem] @runtime_checkable -class Endpoint(Protocol): - def get(self, req_options: Optional[RequestOptions], **kwargs) -> ReturnType: +class Endpoint(Protocol[T]): + def get(self, req_options: Optional[RequestOptions]) -> Tuple[List[T], PaginationItem]: ... @runtime_checkable -class CallableEndpoint(Protocol): - def __call__(self, __req_options: Optional[RequestOptions], **kwargs) -> ReturnType: +class CallableEndpoint(Protocol[T]): + def __call__(self, __req_options: Optional[RequestOptions], **kwargs) -> Tuple[List[T], PaginationItem]: ... @@ -33,7 +32,7 @@ class Pager(Iterable[T]): def __init__( self, - endpoint: Union[CallableEndpoint, Endpoint], + endpoint: Union[CallableEndpoint[T], Endpoint[T]], request_opts: Optional[RequestOptions] = None, **kwargs, ) -> None: diff --git a/test/test_pager.py b/test/test_pager.py index 7659f2725..c30352809 100644 --- a/test/test_pager.py +++ b/test/test_pager.py @@ -9,6 +9,7 @@ TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") +GET_VIEW_XML = os.path.join(TEST_ASSET_DIR, "view_get.xml") GET_XML_PAGE1 = os.path.join(TEST_ASSET_DIR, "workbook_get_page_1.xml") GET_XML_PAGE2 = os.path.join(TEST_ASSET_DIR, "workbook_get_page_2.xml") GET_XML_PAGE3 = os.path.join(TEST_ASSET_DIR, "workbook_get_page_3.xml") @@ -35,7 +36,7 @@ def setUp(self): self.baseurl = self.server.workbooks.baseurl - def test_pager_with_no_options(self): + def test_pager_with_no_options(self) -> None: with open(GET_XML_PAGE1, "rb") as f: page_1 = f.read().decode("utf-8") with open(GET_XML_PAGE2, "rb") as f: @@ -61,7 +62,7 @@ def test_pager_with_no_options(self): self.assertEqual(wb2.name, "Page2Workbook") self.assertEqual(wb3.name, "Page3Workbook") - def test_pager_with_options(self): + def test_pager_with_options(self) -> None: with open(GET_XML_PAGE1, "rb") as f: page_1 = f.read().decode("utf-8") with open(GET_XML_PAGE2, "rb") as f: @@ -102,14 +103,22 @@ def test_pager_with_options(self): wb3 = workbooks.pop() self.assertEqual(wb3.name, "Page3Workbook") - def test_pager_with_env_var(self): + def test_pager_with_env_var(self) -> None: with set_env(TSC_PAGE_SIZE="1000"): assert config.PAGE_SIZE == 1000 loop = TSC.Pager(self.server.workbooks) assert loop._options.pagesize == 1000 - def test_queryset_with_env_var(self): + def test_queryset_with_env_var(self) -> None: with set_env(TSC_PAGE_SIZE="1000"): assert config.PAGE_SIZE == 1000 loop = self.server.workbooks.all() assert loop.request_options.pagesize == 1000 + + def test_pager_view(self) -> None: + with open(GET_VIEW_XML, "rb") as f: + view_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + m.get(self.server.views.baseurl, text=view_xml) + for view in TSC.Pager(self.server.views): + assert view.name is not None