diff --git a/.gitignore b/.gitignore index 5f5db36d7..cd6405447 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,10 @@ target/ # pyenv .python-version +# poetry +poetry.lock +pyproject.toml + # celery beat schedule file celerybeat-schedule diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index 7a00157fe..e0086773c 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -1,9 +1,10 @@ -from .endpoint import Endpoint, api, parameter_added_in +from .endpoint import QuerysetEndpoint, api, parameter_added_in from .exceptions import InternalServerError, MissingRequiredFieldError from .permissions_endpoint import _PermissionsEndpoint from .fileuploads_endpoint import Fileuploads from .resource_tagger import _ResourceTagger from .. import RequestFactory, DatasourceItem, PaginationItem, ConnectionItem +from ..query import QuerySet from ...filesys_helpers import to_filename, make_download_path from ...models.job_item import JobItem @@ -21,7 +22,7 @@ logger = logging.getLogger('tableau.endpoint.datasources') -class Datasources(Endpoint): +class Datasources(QuerysetEndpoint): def __init__(self, parent_srv): super(Datasources, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 5e48b5cc2..821fdada6 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -1,7 +1,7 @@ from .exceptions import ServerResponseError, InternalServerError, NonXMLResponseError from functools import wraps from xml.etree.ElementTree import ParseError - +from ..query import QuerySet import logging try: @@ -165,3 +165,25 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) return wrapper return _decorator + + +class QuerysetEndpoint(Endpoint): + @api(version="2.0") + def all(self, *args, **kwargs): + queryset = QuerySet(self) + return queryset + + @api(version="2.0") + def filter(self, *args, **kwargs): + queryset = QuerySet(self).filter(**kwargs) + return queryset + + @api(version="2.0") + def order_by(self, *args, **kwargs): + queryset = QuerySet(self).order_by(*args) + return queryset + + @api(version="2.0") + def paginate(self, **kwargs): + queryset = QuerySet(self).paginate(**kwargs) + return queryset diff --git a/tableauserverclient/server/endpoint/users_endpoint.py b/tableauserverclient/server/endpoint/users_endpoint.py index 3ce1f16ab..71d59445a 100644 --- a/tableauserverclient/server/endpoint/users_endpoint.py +++ b/tableauserverclient/server/endpoint/users_endpoint.py @@ -1,4 +1,4 @@ -from .endpoint import Endpoint, api +from .endpoint import QuerysetEndpoint, api from .exceptions import MissingRequiredFieldError from .. import RequestFactory, UserItem, WorkbookItem, PaginationItem from ..pager import Pager @@ -9,7 +9,7 @@ logger = logging.getLogger('tableau.endpoint.users') -class Users(Endpoint): +class Users(QuerysetEndpoint): @property def baseurl(self): return "{0}/sites/{1}/users".format(self.parent_srv.baseurl, self.parent_srv.site_id) diff --git a/tableauserverclient/server/endpoint/views_endpoint.py b/tableauserverclient/server/endpoint/views_endpoint.py index 7c8a4768e..cd2792f5d 100644 --- a/tableauserverclient/server/endpoint/views_endpoint.py +++ b/tableauserverclient/server/endpoint/views_endpoint.py @@ -1,4 +1,4 @@ -from .endpoint import Endpoint, api +from .endpoint import QuerysetEndpoint, api from .exceptions import MissingRequiredFieldError from .resource_tagger import _ResourceTagger from .permissions_endpoint import _PermissionsEndpoint @@ -10,7 +10,7 @@ logger = logging.getLogger('tableau.endpoint.views') -class Views(Endpoint): +class Views(QuerysetEndpoint): def __init__(self, parent_srv): super(Views, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index 82a5f9cd0..3ab483352 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -1,4 +1,4 @@ -from .endpoint import Endpoint, api, parameter_added_in +from .endpoint import QuerysetEndpoint, api, parameter_added_in from .exceptions import InternalServerError, MissingRequiredFieldError from .permissions_endpoint import _PermissionsEndpoint from .fileuploads_endpoint import Fileuploads @@ -21,7 +21,7 @@ logger = logging.getLogger('tableau.endpoint.workbooks') -class Workbooks(Endpoint): +class Workbooks(QuerysetEndpoint): def __init__(self, parent_srv): super(Workbooks, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) @@ -37,8 +37,10 @@ def get(self, req_options=None): logger.info('Querying all workbooks on site') url = self.baseurl server_response = self.get_request(url, req_options) - pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) - all_workbook_items = WorkbookItem.from_response(server_response.content, self.parent_srv.namespace) + pagination_item = PaginationItem.from_response( + server_response.content, self.parent_srv.namespace) + all_workbook_items = WorkbookItem.from_response( + server_response.content, self.parent_srv.namespace) return all_workbook_items, pagination_item # Get 1 workbook diff --git a/tableauserverclient/server/query.py b/tableauserverclient/server/query.py new file mode 100644 index 000000000..c8ba5e6c6 --- /dev/null +++ b/tableauserverclient/server/query.py @@ -0,0 +1,89 @@ +from .request_options import RequestOptions +from .filter import Filter +from .sort import Sort + + +def to_camel_case(word): + return word.split('_')[0] + ''.join(x.capitalize() or '_' for x in word.split('_')[1:]) + + +class QuerySet: + + def __init__(self, model): + self.model = model + self.request_options = RequestOptions() + self._result_cache = None + self._pagination_item = None + + def __iter__(self): + self._fetch_all() + return iter(self._result_cache) + + def __getitem__(self, k): + return list(self)[k] + + def _fetch_all(self): + """ + Retrieve the data and store result and pagination item in cache + """ + if self._result_cache is None: + self._result_cache, self._pagination_item = self.model.get(self.request_options) + + @property + def total_available(self): + self._fetch_all() + return self._pagination_item.total_available + + @property + def page_number(self): + self._fetch_all() + return self._pagination_item.page_number + + @property + def page_size(self): + self._fetch_all() + return self._pagination_item.page_size + + def filter(self, **kwargs): + for kwarg_key, value in kwargs.items(): + field_name, operator = self._parse_shorthand_filter(kwarg_key) + self.request_options.filter.add(Filter(field_name, operator, value)) + return self + + def order_by(self, *args): + for arg in args: + field_name, direction = self._parse_shorthand_sort(arg) + self.request_options.sort.add(Sort(field_name, direction)) + return self + + def paginate(self, **kwargs): + if "page_number" in kwargs: + self.request_options.pagenumber = kwargs["page_number"] + if "page_size" in kwargs: + self.request_options.pagesize = kwargs["page_size"] + return self + + def _parse_shorthand_filter(self, key): + tokens = key.split("__", 1) + if len(tokens) == 1: + operator = RequestOptions.Operator.Equals + else: + operator = tokens[1] + if operator not in RequestOptions.Operator.__dict__.values(): + raise ValueError("Operator `{}` is not valid.".format(operator)) + + field = to_camel_case(tokens[0]) + if field not in RequestOptions.Field.__dict__.values(): + raise ValueError("Field name `{}` is not valid.".format(field)) + return (field, operator) + + def _parse_shorthand_sort(self, key): + direction = RequestOptions.Direction.Asc + if key.startswith("-"): + direction = RequestOptions.Direction.Desc + key = key[1:] + + key = to_camel_case(key) + if key not in RequestOptions.Field.__dict__.values(): + raise ValueError("Sort key name %s is not valid.", key) + return (key, direction) diff --git a/test/test_request_option.py b/test/test_request_option.py index c5afcc3b2..e738a8eca 100644 --- a/test/test_request_option.py +++ b/test/test_request_option.py @@ -76,6 +76,17 @@ def test_filter_equals(self): self.assertEqual('RESTAPISample', matching_workbooks[0].name) self.assertEqual('RESTAPISample', matching_workbooks[1].name) + def test_filter_equals_shorthand(self): + with open(FILTER_EQUALS, 'rb') as f: + response_xml = f.read().decode('utf-8') + with requests_mock.mock() as m: + m.get(self.baseurl + '/workbooks?filter=name:eq:RESTAPISample', text=response_xml) + matching_workbooks = self.server.workbooks.filter(name='RESTAPISample').order_by("name") + + self.assertEqual(2, matching_workbooks.total_available) + self.assertEqual('RESTAPISample', matching_workbooks[0].name) + self.assertEqual('RESTAPISample', matching_workbooks[1].name) + def test_filter_tags_in(self): with open(FILTER_TAGS_IN, 'rb') as f: response_xml = f.read().decode('utf-8') @@ -91,6 +102,22 @@ def test_filter_tags_in(self): self.assertEqual(set(['safari']), matching_workbooks[1].tags) self.assertEqual(set(['sample']), matching_workbooks[2].tags) + def test_filter_tags_in_shorthand(self): + with open(FILTER_TAGS_IN, 'rb') as f: + response_xml = f.read().decode('utf-8') + with requests_mock.mock() as m: + m.get(self.baseurl + '/workbooks?filter=tags:in:[sample,safari,weather]', text=response_xml) + matching_workbooks = self.server.workbooks.filter(tags__in=['sample', 'safari', 'weather']) + + self.assertEqual(3, matching_workbooks.total_available) + self.assertEqual(set(['weather']), matching_workbooks[0].tags) + self.assertEqual(set(['safari']), matching_workbooks[1].tags) + self.assertEqual(set(['sample']), matching_workbooks[2].tags) + + def test_invalid_shorthand_option(self): + with self.assertRaises(ValueError): + self.server.workbooks.filter(nonexistant__in=['sample', 'safari']) + def test_multiple_filter_options(self): with open(FILTER_MULTIPLE, 'rb') as f: response_xml = f.read().decode('utf-8') @@ -107,3 +134,19 @@ def test_multiple_filter_options(self): for _ in range(100): matching_workbooks, pagination_item = self.server.workbooks.get(req_option) self.assertEqual(3, pagination_item.total_available) + + def test_multiple_filter_options_shorthand(self): + with open(FILTER_MULTIPLE, 'rb') as f: + response_xml = f.read().decode('utf-8') + # To ensure that this is deterministic, run this a few times + with requests_mock.mock() as m: + # Sometimes pep8 requires you to do things you might not otherwise do + url = ''.join((self.baseurl, '/workbooks?pageNumber=1&pageSize=100&', + 'filter=name:eq:foo,tags:in:[sample,safari,weather]')) + m.get(url, text=response_xml) + + for _ in range(100): + matching_workbooks = self.server.workbooks.filter( + tags__in=['sample', 'safari', 'weather'], name='foo' + ) + self.assertEqual(3, matching_workbooks.total_available)