diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index 77d898d86..7f3a47075 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -55,7 +55,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Datasources(QuerysetEndpoint[DatasourceItem], TaggingMixin): +class Datasources(QuerysetEndpoint[DatasourceItem], TaggingMixin[DatasourceItem]): def __init__(self, parent_srv: "Server") -> None: super(Datasources, self).__init__(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) @@ -126,7 +126,7 @@ def download( datasource_id: str, filepath: Optional[PathOrFileW] = None, include_extract: bool = True, - ) -> str: + ) -> PathOrFileW: return self.download_revision( datasource_id, None, @@ -405,7 +405,7 @@ def _get_datasource_revisions( def download_revision( self, datasource_id: str, - revision_number: str, + revision_number: Optional[str], filepath: Optional[PathOrFileW] = None, include_extract: bool = True, ) -> PathOrFileW: diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 6b29e736a..0e55d5739 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -1,30 +1,41 @@ +from typing_extensions import Concatenate, ParamSpec from tableauserverclient import datetime_helpers as datetime import abc from packaging.version import Version from functools import wraps from xml.etree.ElementTree import ParseError -from typing import Any, Callable, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + TYPE_CHECKING, + Tuple, + TypeVar, + Union, +) from tableauserverclient.models.pagination_item import PaginationItem from tableauserverclient.server.request_options import RequestOptions -from .exceptions import ( +from tableauserverclient.server.endpoint.exceptions import ( ServerResponseError, InternalServerError, NonXMLResponseError, NotSignedInError, ) -from ..exceptions import EndpointUnavailableError +from tableauserverclient.server.exceptions import EndpointUnavailableError from tableauserverclient.server.query import QuerySet from tableauserverclient import helpers, get_versions from tableauserverclient.helpers.logging import logger -from tableauserverclient.config import DELAY_SLEEP_SECONDS if TYPE_CHECKING: - from ..server import Server + from tableauserverclient.server.server import Server from requests import Response @@ -38,7 +49,7 @@ USER_AGENT_HEADER = "User-Agent" -class Endpoint(object): +class Endpoint: def __init__(self, parent_srv: "Server"): self.parent_srv = parent_srv @@ -232,7 +243,12 @@ def patch_request(self, url, xml_request, content_type=XML_CONTENT_TYPE, paramet ) -def api(version): +E = TypeVar("E", bound="Endpoint") +P = ParamSpec("P") +R = TypeVar("R") + + +def api(version: str) -> Callable[[Callable[Concatenate[E, P], R]], Callable[Concatenate[E, P], R]]: """Annotate the minimum supported version for an endpoint. Checks the version on the server object and compares normalized versions. @@ -251,9 +267,9 @@ def api(version): >>> ... """ - def _decorator(func): + def _decorator(func: Callable[Concatenate[E, P], R]) -> Callable[Concatenate[E, P], R]: @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: E, *args: P.args, **kwargs: P.kwargs) -> R: self.parent_srv.assert_at_least_version(version, self.__class__.__name__) return func(self, *args, **kwargs) @@ -262,7 +278,7 @@ def wrapper(self, *args, **kwargs): return _decorator -def parameter_added_in(**params): +def parameter_added_in(**params: str) -> Callable[[Callable[Concatenate[E, P], R]], Callable[Concatenate[E, P], R]]: """Annotate minimum versions for new parameters or request options on an endpoint. The api decorator documents when an endpoint was added, this decorator annotates @@ -285,9 +301,9 @@ def parameter_added_in(**params): >>> ... """ - def _decorator(func): + def _decorator(func: Callable[Concatenate[E, P], R]) -> Callable[Concatenate[E, P], R]: @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: E, *args: P.args, **kwargs: P.kwargs) -> R: import warnings server_ver = Version(self.parent_srv.version or "0.0") @@ -335,5 +351,5 @@ def paginate(self, **kwargs) -> QuerySet[T]: return queryset @abc.abstractmethod - def get(self, request_options: RequestOptions) -> Tuple[List[T], PaginationItem]: + def get(self, request_options: Optional[RequestOptions] = None) -> Tuple[List[T], PaginationItem]: raise NotImplementedError(f".get has not been implemented for {self.__class__.__qualname__}") diff --git a/tableauserverclient/server/endpoint/flows_endpoint.py b/tableauserverclient/server/endpoint/flows_endpoint.py index 2adbe1f92..53d072f50 100644 --- a/tableauserverclient/server/endpoint/flows_endpoint.py +++ b/tableauserverclient/server/endpoint/flows_endpoint.py @@ -51,7 +51,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Flows(QuerysetEndpoint[FlowItem], TaggingMixin): +class Flows(QuerysetEndpoint[FlowItem], TaggingMixin[FlowItem]): def __init__(self, parent_srv): super(Flows, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/jobs_endpoint.py b/tableauserverclient/server/endpoint/jobs_endpoint.py index a48a3244c..54e699722 100644 --- a/tableauserverclient/server/endpoint/jobs_endpoint.py +++ b/tableauserverclient/server/endpoint/jobs_endpoint.py @@ -1,4 +1,5 @@ import logging +from typing_extensions import Self, overload from tableauserverclient.server.query import QuerySet @@ -13,15 +14,25 @@ from typing import List, Optional, Tuple, Union -class Jobs(QuerysetEndpoint[JobItem]): +class Jobs(QuerysetEndpoint[BackgroundJobItem]): @property def baseurl(self): return "{0}/sites/{1}/jobs".format(self.parent_srv.baseurl, self.parent_srv.site_id) + @overload # type: ignore[override] + def get(self: Self, job_id: str, req_options: Optional[RequestOptionsBase] = None) -> JobItem: # type: ignore[override] + ... + + @overload # type: ignore[override] + def get(self: Self, job_id: RequestOptionsBase, req_options: None) -> Tuple[List[BackgroundJobItem], PaginationItem]: # type: ignore[override] + ... + + @overload # type: ignore[override] + def get(self: Self, job_id: None, req_options: Optional[RequestOptionsBase]) -> Tuple[List[BackgroundJobItem], PaginationItem]: # type: ignore[override] + ... + @api(version="2.6") - def get( - self, job_id: Optional[str] = None, req_options: Optional[RequestOptionsBase] = None - ) -> Tuple[List[BackgroundJobItem], PaginationItem]: + def get(self, job_id=None, req_options=None): # Backwards Compatibility fix until we rev the major version if job_id is not None and isinstance(job_id, str): import warnings @@ -77,7 +88,7 @@ def wait_for_job(self, job_id: Union[str, JobItem], *, timeout: Optional[float] else: raise AssertionError("Unexpected finish_code in job", job) - def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[JobItem]: + def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[BackgroundJobItem]: """ Queries the Tableau Server for items using the specified filters. Page size can be specified to limit the number of items returned in a single diff --git a/tableauserverclient/server/endpoint/resource_tagger.py b/tableauserverclient/server/endpoint/resource_tagger.py index f6b1cab05..1894e3b8a 100644 --- a/tableauserverclient/server/endpoint/resource_tagger.py +++ b/tableauserverclient/server/endpoint/resource_tagger.py @@ -1,6 +1,6 @@ import abc import copy -from typing import Iterable, Optional, Protocol, Set, Union, TYPE_CHECKING, runtime_checkable +from typing import Generic, Iterable, Optional, Protocol, Set, TypeVar, Union, TYPE_CHECKING, runtime_checkable import urllib.parse from tableauserverclient.server.endpoint.endpoint import Endpoint, api @@ -62,27 +62,24 @@ def update_tags(self, baseurl, resource_item): logger.info("Updated tags to {0}".format(resource_item.tags)) -class HasID(Protocol): - @property - def id(self) -> Optional[str]: - pass +class Response(Protocol): + content: bytes @runtime_checkable class Taggable(Protocol): - _initial_tags: Set[str] tags: Set[str] + _initial_tags: Set[str] @property def id(self) -> Optional[str]: pass -class Response(Protocol): - content: bytes +T = TypeVar("T") -class TaggingMixin(abc.ABC): +class TaggingMixin(abc.ABC, Generic[T]): parent_srv: "Server" @property @@ -98,7 +95,7 @@ def put_request(self, url, request) -> Response: def delete_request(self, url) -> None: pass - def add_tags(self, item: Union[HasID, Taggable, str], tags: Union[Iterable[str], str]) -> Set[str]: + def add_tags(self, item: Union[T, str], tags: Union[Iterable[str], str]) -> Set[str]: item_id = getattr(item, "id", item) if not isinstance(item_id, str): @@ -114,7 +111,7 @@ def add_tags(self, item: Union[HasID, Taggable, str], tags: Union[Iterable[str], server_response = self.put_request(url, add_req) return TagItem.from_response(server_response.content, self.parent_srv.namespace) - def delete_tags(self, item: Union[HasID, Taggable, str], tags: Union[Iterable[str], str]) -> None: + def delete_tags(self, item: Union[T, str], tags: Union[Iterable[str], str]) -> None: item_id = getattr(item, "id", item) if not isinstance(item_id, str): @@ -130,17 +127,23 @@ def delete_tags(self, item: Union[HasID, Taggable, str], tags: Union[Iterable[st url = f"{self.baseurl}/{item_id}/tags/{encoded_tag_name}" self.delete_request(url) - def update_tags(self, item: Taggable) -> None: - if item.tags == item._initial_tags: + def update_tags(self, item: T) -> None: + if (initial_tags := getattr(item, "_initial_tags", None)) is None: + raise ValueError(f"{item} does not have initial tags.") + if (tags := getattr(item, "tags", None)) is None: + raise ValueError(f"{item} does not have tags.") + if tags == initial_tags: return - add_set = item.tags - item._initial_tags - remove_set = item._initial_tags - item.tags + add_set = tags - initial_tags + remove_set = initial_tags - tags self.delete_tags(item, remove_set) if add_set: - item.tags = self.add_tags(item, add_set) - item._initial_tags = copy.copy(item.tags) - logger.info(f"Updated tags to {item.tags}") + tags = self.add_tags(item, add_set) + setattr(item, "tags", tags) + + setattr(item, "_initial_tags", copy.copy(tags)) + logger.info(f"Updated tags to {tags}") content = Iterable[Union["ColumnItem", "DatabaseItem", "DatasourceItem", "FlowItem", "TableItem", "WorkbookItem"]] diff --git a/tableauserverclient/server/endpoint/tables_endpoint.py b/tableauserverclient/server/endpoint/tables_endpoint.py index b2e41df8b..36ef78c0a 100644 --- a/tableauserverclient/server/endpoint/tables_endpoint.py +++ b/tableauserverclient/server/endpoint/tables_endpoint.py @@ -13,7 +13,7 @@ from tableauserverclient.helpers.logging import logger -class Tables(Endpoint, TaggingMixin): +class Tables(Endpoint, TaggingMixin[TableItem]): def __init__(self, parent_srv): super(Tables, self).__init__(parent_srv) diff --git a/tableauserverclient/server/endpoint/views_endpoint.py b/tableauserverclient/server/endpoint/views_endpoint.py index 7a8623614..f2ccf658e 100644 --- a/tableauserverclient/server/endpoint/views_endpoint.py +++ b/tableauserverclient/server/endpoint/views_endpoint.py @@ -23,7 +23,7 @@ ) -class Views(QuerysetEndpoint[ViewItem], TaggingMixin): +class Views(QuerysetEndpoint[ViewItem], TaggingMixin[ViewItem]): def __init__(self, parent_srv): super(Views, self).__init__(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index 55f61370f..da6eda3de 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -59,7 +59,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Workbooks(QuerysetEndpoint[WorkbookItem], TaggingMixin): +class Workbooks(QuerysetEndpoint[WorkbookItem], TaggingMixin[WorkbookItem]): def __init__(self, parent_srv: "Server") -> None: super(Workbooks, self).__init__(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) @@ -184,7 +184,7 @@ def download( workbook_id: str, filepath: Optional[PathOrFileW] = None, include_extract: bool = True, - ) -> str: + ) -> PathOrFileW: return self.download_revision( workbook_id, None, diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index aeb355ea6..7fc9c9555 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1,8 +1,11 @@ import xml.etree.ElementTree as ET -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, TYPE_CHECKING +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, TYPE_CHECKING, Union + +from typing_extensions import ParamSpec from requests.packages.urllib3.fields import RequestField from requests.packages.urllib3.filepost import encode_multipart_formdata +from typing_extensions import Concatenate from tableauserverclient.models import * @@ -23,8 +26,12 @@ def _add_multipart(parts: Dict) -> Tuple[Any, str]: return xml_request, content_type -def _tsrequest_wrapped(func): - def wrapper(self, *args, **kwargs) -> bytes: +T = TypeVar("T") +P = ParamSpec("P") + + +def _tsrequest_wrapped(func: Callable[Concatenate[T, ET.Element, P], Any]) -> Callable[Concatenate[T, P], bytes]: + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> bytes: xml_request = ET.Element("tsRequest") func(self, xml_request, *args, **kwargs) return ET.tostring(xml_request) @@ -388,7 +395,7 @@ def add_user_req(self, user_id: str) -> bytes: return ET.tostring(xml_request) @_tsrequest_wrapped - def add_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes: + def add_users_req(self, xml_request: ET.Element, users: Iterable[Union[str, UserItem]]) -> bytes: users_element = ET.SubElement(xml_request, "users") for user in users: user_element = ET.SubElement(users_element, "user") @@ -399,7 +406,7 @@ def add_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> b return ET.tostring(xml_request) @_tsrequest_wrapped - def remove_users_req(self, xml_request, users: Iterable[Union[str, UserItem]]) -> bytes: + def remove_users_req(self, xml_request: ET.Element, users: Iterable[Union[str, UserItem]]) -> bytes: users_element = ET.SubElement(xml_request, "users") for user in users: user_element = ET.SubElement(users_element, "user") @@ -1055,14 +1062,17 @@ def publish_req_chunked( return _add_multipart(parts) @_tsrequest_wrapped - def embedded_extract_req(self, xml_request, include_all=True, datasources=None): + def embedded_extract_req( + self, xml_request: ET.Element, include_all: bool = True, datasources: Optional[Iterable[DatasourceItem]] = None + ) -> None: list_element = ET.SubElement(xml_request, "datasources") if include_all: list_element.attrib["includeAll"] = "true" elif datasources: for datasource_item in datasources: datasource_element = ET.SubElement(list_element, "datasource") - datasource_element.attrib["id"] = datasource_item.id + if (id_ := datasource_item.id) is not None: + datasource_element.attrib["id"] = id_ class Connection(object): @@ -1090,7 +1100,7 @@ def update_req(self, xml_request: ET.Element, connection_item: "ConnectionItem") class TaskRequest(object): @_tsrequest_wrapped - def run_req(self, xml_request, task_item): + def run_req(self, xml_request: ET.Element, task_item: Any) -> None: # Send an empty tsRequest pass @@ -1227,7 +1237,7 @@ def update_req(self, xml_request: ET.Element, subscription_item: "SubscriptionIt class EmptyRequest(object): @_tsrequest_wrapped - def empty_req(self, xml_request): + def empty_req(self, xml_request: ET.Element) -> None: pass diff --git a/test/test_tagging.py b/test/test_tagging.py index d3f23d40e..fc88eea8a 100644 --- a/test/test_tagging.py +++ b/test/test_tagging.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack import re from typing import Iterable import uuid @@ -6,7 +7,6 @@ import pytest import requests_mock import tableauserverclient as TSC -from tableauserverclient.server.endpoint.resource_tagger import content @pytest.fixture @@ -154,6 +154,42 @@ def test_delete_tags(get_server, endpoint_type, item, tags) -> None: assert urls == tag_set +@pytest.mark.parametrize("endpoint_type, item", *sample_taggable_items) +@pytest.mark.parametrize("tags", sample_tags) +def test_update_tags(get_server, endpoint_type, item, tags) -> None: + endpoint = getattr(get_server, endpoint_type) + id_ = getattr(item, "id", item) + tags = set([tags] if isinstance(tags, str) else tags) + with ExitStack() as stack: + if isinstance(item, str): + stack.enter_context(pytest.raises((ValueError, NotImplementedError))) + elif hasattr(item, "_initial_tags"): + initial_tags = set(["x", "y", "z"]) + item._initial_tags = initial_tags + add_tags_xml = add_tag_xml_response_factory(tags - initial_tags) + delete_tags_xml = add_tag_xml_response_factory(initial_tags - tags) + m = stack.enter_context(requests_mock.mock()) + m.put( + f"{endpoint.baseurl}/{id_}/tags", + status_code=200, + text=add_tags_xml, + ) + + tag_paths = "|".join(initial_tags - tags) + tag_paths = f"({tag_paths})" + matcher = re.compile(rf"{endpoint.baseurl}\/{id_}\/tags\/{tag_paths}") + m.delete( + matcher, + status_code=200, + text=delete_tags_xml, + ) + + else: + stack.enter_context(pytest.raises(NotImplementedError)) + + endpoint.update_tags(item) + + def test_tags_batch_add(get_server) -> None: server = get_server content = [make_workbook(), make_view(), make_datasource(), make_table(), make_database()]