diff --git a/tableauserverclient/server/endpoint/__init__.py b/tableauserverclient/server/endpoint/__init__.py index e6b50b27d..7b89339bc 100644 --- a/tableauserverclient/server/endpoint/__init__.py +++ b/tableauserverclient/server/endpoint/__init__.py @@ -22,6 +22,7 @@ from tableauserverclient.server.endpoint.sites_endpoint import Sites from tableauserverclient.server.endpoint.subscriptions_endpoint import Subscriptions from tableauserverclient.server.endpoint.tables_endpoint import Tables +from tableauserverclient.server.endpoint.resource_tagger import Tags from tableauserverclient.server.endpoint.tasks_endpoint import Tasks from tableauserverclient.server.endpoint.users_endpoint import Users from tableauserverclient.server.endpoint.views_endpoint import Views @@ -55,6 +56,7 @@ "Sites", "Subscriptions", "Tables", + "Tags", "Tasks", "Users", "Views", diff --git a/tableauserverclient/server/endpoint/databases_endpoint.py b/tableauserverclient/server/endpoint/databases_endpoint.py index 849072a17..2f8fece07 100644 --- a/tableauserverclient/server/endpoint/databases_endpoint.py +++ b/tableauserverclient/server/endpoint/databases_endpoint.py @@ -1,17 +1,19 @@ import logging - -from .default_permissions_endpoint import _DefaultPermissionsEndpoint -from .dqw_endpoint import _DataQualityWarningEndpoint -from .endpoint import api, Endpoint -from .exceptions import MissingRequiredFieldError -from .permissions_endpoint import _PermissionsEndpoint +from typing import Union, Iterable, Set + +from tableauserverclient.server.endpoint.default_permissions_endpoint import _DefaultPermissionsEndpoint +from tableauserverclient.server.endpoint.dqw_endpoint import _DataQualityWarningEndpoint +from tableauserverclient.server.endpoint.endpoint import api, Endpoint +from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError +from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint +from tableauserverclient.server.endpoint.resource_tagger import TaggingMixin from tableauserverclient.server import RequestFactory from tableauserverclient.models import DatabaseItem, TableItem, PaginationItem, Resource from tableauserverclient.helpers.logging import logger -class Databases(Endpoint): +class Databases(Endpoint, TaggingMixin): def __init__(self, parent_srv): super(Databases, self).__init__(parent_srv) @@ -123,3 +125,15 @@ def add_dqw(self, item, warning): @api(version="3.5") def delete_dqw(self, item): self._data_quality_warnings.clear(item) + + @api(version="3.9") + def add_tags(self, item: Union[DatabaseItem, str], tags: Iterable[str]) -> Set[str]: + return super().add_tags(item, tags) + + @api(version="3.9") + def delete_tags(self, item: Union[DatabaseItem, str], tags: Iterable[str]) -> None: + super().delete_tags(item, tags) + + @api(version="3.9") + def update_tags(self, item: DatabaseItem) -> None: + raise NotImplementedError("Update tags is not supported for databases.") diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index a612adfe0..c01e57047 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -6,7 +6,7 @@ from contextlib import closing from pathlib import Path -from typing import List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Iterable, List, Mapping, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union from tableauserverclient.helpers.headers import fix_filename from tableauserverclient.server.query import QuerySet @@ -20,7 +20,7 @@ from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api, parameter_added_in from tableauserverclient.server.endpoint.exceptions import InternalServerError, MissingRequiredFieldError from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint -from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger +from tableauserverclient.server.endpoint.resource_tagger import TaggingMixin from tableauserverclient.config import ALLOWED_FILE_EXTENSIONS, FILESIZE_LIMIT_MB, BYTES_PER_MB, CHUNK_SIZE_MB from tableauserverclient.filesys_helpers import ( @@ -55,10 +55,9 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Datasources(QuerysetEndpoint[DatasourceItem]): +class Datasources(QuerysetEndpoint[DatasourceItem], TaggingMixin): def __init__(self, parent_srv: "Server") -> None: super(Datasources, self).__init__(parent_srv) - self._resource_tagger = _ResourceTagger(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) self._data_quality_warnings = _DataQualityWarningEndpoint(self.parent_srv, "datasource") @@ -150,7 +149,7 @@ def update(self, datasource_item: DatasourceItem) -> DatasourceItem: ) raise MissingRequiredFieldError(error) - self._resource_tagger.update_tags(self.baseurl, datasource_item) + self.update_tags(datasource_item) # Update the datasource itself url = "{0}/{1}".format(self.baseurl, datasource_item.id) @@ -461,6 +460,18 @@ def schedule_extract_refresh( ) -> List["AddResponse"]: # actually should return a task return self.parent_srv.schedules.add_to_schedule(schedule_id, datasource=item) + @api(version="1.0") + def add_tags(self, item: Union[DatasourceItem, str], tags: Union[Iterable[str], str]) -> Set[str]: + return super().add_tags(item, tags) + + @api(version="1.0") + def delete_tags(self, item: Union[DatasourceItem, str], tags: Union[Iterable[str], str]) -> None: + return super().delete_tags(item, tags) + + @api(version="1.0") + def update_tags(self, item: DatasourceItem) -> None: + return super().update_tags(item) + def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[DatasourceItem]: """ Queries the Tableau Server for items using the specified filters. Page diff --git a/tableauserverclient/server/endpoint/flows_endpoint.py b/tableauserverclient/server/endpoint/flows_endpoint.py index a2458ad87..2adbe1f92 100644 --- a/tableauserverclient/server/endpoint/flows_endpoint.py +++ b/tableauserverclient/server/endpoint/flows_endpoint.py @@ -13,7 +13,7 @@ from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api from tableauserverclient.server.endpoint.exceptions import InternalServerError, MissingRequiredFieldError from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint -from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger +from tableauserverclient.server.endpoint.resource_tagger import _ResourceTagger, TaggingMixin from tableauserverclient.models import FlowItem, PaginationItem, ConnectionItem, JobItem from tableauserverclient.server import RequestFactory from tableauserverclient.filesys_helpers import ( @@ -51,7 +51,7 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Flows(QuerysetEndpoint[FlowItem]): +class Flows(QuerysetEndpoint[FlowItem], TaggingMixin): def __init__(self, parent_srv): super(Flows, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) diff --git a/tableauserverclient/server/endpoint/resource_tagger.py b/tableauserverclient/server/endpoint/resource_tagger.py index 8177bd733..f6b1cab05 100644 --- a/tableauserverclient/server/endpoint/resource_tagger.py +++ b/tableauserverclient/server/endpoint/resource_tagger.py @@ -1,14 +1,25 @@ +import abc import copy +from typing import Iterable, Optional, Protocol, Set, Union, TYPE_CHECKING, runtime_checkable import urllib.parse -from .endpoint import Endpoint -from .exceptions import ServerResponseError -from ..exceptions import EndpointUnavailableError +from tableauserverclient.server.endpoint.endpoint import Endpoint, api +from tableauserverclient.server.endpoint.exceptions import ServerResponseError +from tableauserverclient.server.exceptions import EndpointUnavailableError from tableauserverclient.server import RequestFactory from tableauserverclient.models import TagItem from tableauserverclient.helpers.logging import logger +if TYPE_CHECKING: + from tableauserverclient.models.column_item import ColumnItem + from tableauserverclient.models.database_item import DatabaseItem + from tableauserverclient.models.datasource_item import DatasourceItem + from tableauserverclient.models.flow_item import FlowItem + from tableauserverclient.models.table_item import TableItem + from tableauserverclient.models.workbook_item import WorkbookItem + from tableauserverclient.server.server import Server + class _ResourceTagger(Endpoint): # Add new tags to resource @@ -49,3 +60,121 @@ def update_tags(self, baseurl, resource_item): resource_item.tags = self._add_tags(baseurl, resource_item.id, add_set) resource_item._initial_tags = copy.copy(resource_item.tags) logger.info("Updated tags to {0}".format(resource_item.tags)) + + +class HasID(Protocol): + @property + def id(self) -> Optional[str]: + pass + + +@runtime_checkable +class Taggable(Protocol): + _initial_tags: Set[str] + tags: Set[str] + + @property + def id(self) -> Optional[str]: + pass + + +class Response(Protocol): + content: bytes + + +class TaggingMixin(abc.ABC): + parent_srv: "Server" + + @property + @abc.abstractmethod + def baseurl(self) -> str: + pass + + @abc.abstractmethod + def put_request(self, url, request) -> Response: + pass + + @abc.abstractmethod + def delete_request(self, url) -> None: + pass + + def add_tags(self, item: Union[HasID, Taggable, str], tags: Union[Iterable[str], str]) -> Set[str]: + item_id = getattr(item, "id", item) + + if not isinstance(item_id, str): + raise ValueError("ID not found.") + + if isinstance(tags, str): + tag_set = set([tags]) + else: + tag_set = set(tags) + + url = f"{self.baseurl}/{item_id}/tags" + add_req = RequestFactory.Tag.add_req(tag_set) + server_response = self.put_request(url, add_req) + return TagItem.from_response(server_response.content, self.parent_srv.namespace) + + def delete_tags(self, item: Union[HasID, Taggable, str], tags: Union[Iterable[str], str]) -> None: + item_id = getattr(item, "id", item) + + if not isinstance(item_id, str): + raise ValueError("ID not found.") + + if isinstance(tags, str): + tag_set = set([tags]) + else: + tag_set = set(tags) + + for tag in tag_set: + encoded_tag_name = urllib.parse.quote(tag) + url = f"{self.baseurl}/{item_id}/tags/{encoded_tag_name}" + self.delete_request(url) + + def update_tags(self, item: Taggable) -> None: + if item.tags == item._initial_tags: + return + + add_set = item.tags - item._initial_tags + remove_set = item._initial_tags - item.tags + self.delete_tags(item, remove_set) + if add_set: + item.tags = self.add_tags(item, add_set) + item._initial_tags = copy.copy(item.tags) + logger.info(f"Updated tags to {item.tags}") + + +content = Iterable[Union["ColumnItem", "DatabaseItem", "DatasourceItem", "FlowItem", "TableItem", "WorkbookItem"]] + + +class Tags(Endpoint): + def __init__(self, parent_srv: "Server"): + super().__init__(parent_srv) + + @property + def baseurl(self): + return f"{self.parent_srv.baseurl}/tags" + + @api(version="3.9") + def batch_add(self, tags: Union[Iterable[str], str], content: content) -> Set[str]: + if isinstance(tags, str): + tag_set = set([tags]) + else: + tag_set = set(tags) + + url = f"{self.baseurl}:batchCreate" + batch_create_req = RequestFactory.Tag.batch_create(tag_set, content) + server_response = self.put_request(url, batch_create_req) + return TagItem.from_response(server_response.content, self.parent_srv.namespace) + + @api(version="3.9") + def batch_delete(self, tags: Union[Iterable[str], str], content: content) -> Set[str]: + if isinstance(tags, str): + tag_set = set([tags]) + else: + tag_set = set(tags) + + url = f"{self.baseurl}:batchDelete" + # The batch delete XML is the same as the batch create XML. + batch_delete_req = RequestFactory.Tag.batch_create(tag_set, content) + server_response = self.put_request(url, batch_delete_req) + return TagItem.from_response(server_response.content, self.parent_srv.namespace) diff --git a/tableauserverclient/server/endpoint/tables_endpoint.py b/tableauserverclient/server/endpoint/tables_endpoint.py index b4c5181e9..b2e41df8b 100644 --- a/tableauserverclient/server/endpoint/tables_endpoint.py +++ b/tableauserverclient/server/endpoint/tables_endpoint.py @@ -1,17 +1,19 @@ import logging +from typing import Iterable, Set, Union -from .dqw_endpoint import _DataQualityWarningEndpoint -from .endpoint import api, Endpoint -from .exceptions import MissingRequiredFieldError -from .permissions_endpoint import _PermissionsEndpoint +from tableauserverclient.server.endpoint.dqw_endpoint import _DataQualityWarningEndpoint +from tableauserverclient.server.endpoint.endpoint import api, Endpoint +from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError +from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint +from tableauserverclient.server.endpoint.resource_tagger import TaggingMixin from tableauserverclient.server import RequestFactory from tableauserverclient.models import TableItem, ColumnItem, PaginationItem -from ..pager import Pager +from tableauserverclient.server.pager import Pager from tableauserverclient.helpers.logging import logger -class Tables(Endpoint): +class Tables(Endpoint, TaggingMixin): def __init__(self, parent_srv): super(Tables, self).__init__(parent_srv) @@ -124,3 +126,14 @@ def add_dqw(self, item, warning): @api(version="3.5") def delete_dqw(self, item): self._data_quality_warnings.clear(item) + + @api(version="3.9") + def add_tags(self, item: Union[TableItem, str], tags: Union[Iterable[str], str]) -> Set[str]: + return super().add_tags(item, tags) + + @api(version="3.9") + def delete_tags(self, item: Union[TableItem, str], tags: Union[Iterable[str], str]) -> None: + return super().delete_tags(item, tags) + + def update_tags(self, item: TableItem) -> None: # type: ignore + raise NotImplementedError("Update tags is not implemented for TableItem") diff --git a/tableauserverclient/server/endpoint/views_endpoint.py b/tableauserverclient/server/endpoint/views_endpoint.py index f8c50caaf..7a8623614 100644 --- a/tableauserverclient/server/endpoint/views_endpoint.py +++ b/tableauserverclient/server/endpoint/views_endpoint.py @@ -1,20 +1,20 @@ import logging from contextlib import closing +from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api +from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError +from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint +from tableauserverclient.server.endpoint.resource_tagger import TaggingMixin from tableauserverclient.server.query import QuerySet -from .endpoint import QuerysetEndpoint, api -from .exceptions import MissingRequiredFieldError -from .permissions_endpoint import _PermissionsEndpoint -from .resource_tagger import _ResourceTagger from tableauserverclient.models import ViewItem, PaginationItem from tableauserverclient.helpers.logging import logger -from typing import Iterator, List, Optional, Tuple, TYPE_CHECKING +from typing import Iterable, Iterator, List, Optional, Set, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: - from ..request_options import ( + from tableauserverclient.server.request_options import ( RequestOptions, CSVRequestOptions, PDFRequestOptions, @@ -23,10 +23,9 @@ ) -class Views(QuerysetEndpoint[ViewItem]): +class Views(QuerysetEndpoint[ViewItem], TaggingMixin): def __init__(self, parent_srv): super(Views, self).__init__(parent_srv) - self._resource_tagger = _ResourceTagger(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) # Used because populate_preview_image functionaliy requires workbook endpoint @@ -171,11 +170,23 @@ def update(self, view_item: ViewItem) -> ViewItem: error = "View item missing ID. View must be retrieved from server first." raise MissingRequiredFieldError(error) - self._resource_tagger.update_tags(self.baseurl, view_item) + self.update_tags(view_item) # Returning view item to stay consistent with datasource/view update functions return view_item + @api(version="1.0") + def add_tags(self, item: Union[ViewItem, str], tags: Union[Iterable[str], str]) -> Set[str]: + return super().add_tags(item, tags) + + @api(version="1.0") + def delete_tags(self, item: Union[ViewItem, str], tags: Union[Iterable[str], str]) -> None: + return super().delete_tags(item, tags) + + @api(version="1.0") + def update_tags(self, item: ViewItem) -> None: + return super().update_tags(item) + def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[ViewItem]: """ Queries the Tableau Server for items using the specified filters. Page diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index e80fa2daf..55f61370f 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -9,10 +9,10 @@ from tableauserverclient.helpers.headers import fix_filename from tableauserverclient.server.query import QuerySet -from .endpoint import QuerysetEndpoint, api, parameter_added_in -from .exceptions import InternalServerError, MissingRequiredFieldError -from .permissions_endpoint import _PermissionsEndpoint -from .resource_tagger import _ResourceTagger +from tableauserverclient.server.endpoint.endpoint import QuerysetEndpoint, api, parameter_added_in +from tableauserverclient.server.endpoint.exceptions import InternalServerError, MissingRequiredFieldError +from tableauserverclient.server.endpoint.permissions_endpoint import _PermissionsEndpoint +from tableauserverclient.server.endpoint.resource_tagger import TaggingMixin from tableauserverclient.filesys_helpers import ( to_filename, @@ -25,9 +25,11 @@ from tableauserverclient.server import RequestFactory from typing import ( + Iterable, List, Optional, Sequence, + Set, Tuple, TYPE_CHECKING, Union, @@ -36,8 +38,8 @@ if TYPE_CHECKING: from tableauserverclient.server import Server from tableauserverclient.server.request_options import RequestOptions - from tableauserverclient.models import DatasourceItem, ConnectionCredentials - from .schedules_endpoint import AddResponse + from tableauserverclient.models import DatasourceItem + from tableauserverclient.server.endpoint.schedules_endpoint import AddResponse io_types_r = (io.BytesIO, io.BufferedReader) io_types_w = (io.BytesIO, io.BufferedWriter) @@ -57,10 +59,9 @@ PathOrFileW = Union[FilePath, FileObjectW] -class Workbooks(QuerysetEndpoint[WorkbookItem]): +class Workbooks(QuerysetEndpoint[WorkbookItem], TaggingMixin): def __init__(self, parent_srv: "Server") -> None: super(Workbooks, self).__init__(parent_srv) - self._resource_tagger = _ResourceTagger(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) return None @@ -148,7 +149,7 @@ def update( error = "Workbook item missing ID. Workbook must be retrieved from server first." raise MissingRequiredFieldError(error) - self._resource_tagger.update_tags(self.baseurl, workbook_item) + self.update_tags(workbook_item) # Update the workbook itself url = "{0}/{1}".format(self.baseurl, workbook_item.id) @@ -500,6 +501,18 @@ def schedule_extract_refresh( ) -> List["AddResponse"]: # actually should return a task return self.parent_srv.schedules.add_to_schedule(schedule_id, workbook=item) + @api(version="1.0") + def add_tags(self, item: Union[WorkbookItem, str], tags: Union[Iterable[str], str]) -> Set[str]: + return super().add_tags(item, tags) + + @api(version="1.0") + def delete_tags(self, item: Union[WorkbookItem, str], tags: Union[Iterable[str], str]) -> None: + return super().delete_tags(item, tags) + + @api(version="1.0") + def update_tags(self, item: WorkbookItem) -> None: + return super().update_tags(item) + def filter(self, *invalid, page_size: Optional[int] = None, **kwargs) -> QuerySet[WorkbookItem]: """ Queries the Tableau Server for items using the specified filters. Page diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index b0c8b37b0..aeb355ea6 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1,6 +1,5 @@ import xml.etree.ElementTree as ET - -from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, TYPE_CHECKING from requests.packages.urllib3.fields import RequestField from requests.packages.urllib3.filepost import encode_multipart_formdata @@ -862,6 +861,9 @@ def update_req(self, table_item): return ET.tostring(xml_request) +content_types = Iterable[Union["ColumnItem", "DatabaseItem", "DatasourceItem", "FlowItem", "TableItem", "WorkbookItem"]] + + class TagRequest(object): def add_req(self, tag_set): xml_request = ET.Element("tsRequest") @@ -871,6 +873,22 @@ def add_req(self, tag_set): tag_element.attrib["label"] = tag return ET.tostring(xml_request) + @_tsrequest_wrapped + def batch_create(self, element: ET.Element, tags: Set[str], content: content_types) -> bytes: + tag_batch = ET.SubElement(element, "tagBatch") + tags_element = ET.SubElement(tag_batch, "tags") + for tag in tags: + tag_element = ET.SubElement(tags_element, "tag") + tag_element.attrib["label"] = tag + contents_element = ET.SubElement(tag_batch, "contents") + for item in content: + content_element = ET.SubElement(contents_element, "content") + if item.id is None: + raise ValueError(f"Item {item} must have an ID to be tagged.") + content_element.attrib["id"] = item.id + + return ET.tostring(element) + class UserRequest(object): def update_req(self, user_item: UserItem, password: Optional[str]) -> bytes: diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 18d67fa07..1de865ba8 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -34,6 +34,7 @@ Endpoint, CustomViews, GroupSets, + Tags, ) from tableauserverclient.server.exceptions import ( ServerInfoEndpointNotFoundError, @@ -101,6 +102,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self.metrics = Metrics(self) self.custom_views = CustomViews(self) self.group_sets = GroupSets(self) + self.tags = Tags(self) self._session = self._session_factory() self._http_options = dict() # must set this before making a server call diff --git a/test/test_tagging.py b/test/test_tagging.py new file mode 100644 index 000000000..d3f23d40e --- /dev/null +++ b/test/test_tagging.py @@ -0,0 +1,186 @@ +import re +from typing import Iterable +import uuid +from xml.etree import ElementTree as ET + +import pytest +import requests_mock +import tableauserverclient as TSC +from tableauserverclient.server.endpoint.resource_tagger import content + + +@pytest.fixture +def get_server() -> TSC.Server: + server = TSC.Server("http://test", False) + + # Fake sign in + server._site_id = "dad65087-b08b-4603-af4e-2887b8aafc67" + server._auth_token = "j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM" + server.version = "3.28" + return server + + +def add_tag_xml_response_factory(tags: Iterable[str]) -> str: + root = ET.Element("tsResponse") + tags_element = ET.SubElement(root, "tags") + for tag in tags: + tag_element = ET.SubElement(tags_element, "tag") + tag_element.attrib["label"] = tag + root.attrib["xmlns"] = "http://tableau.com/api" + return ET.tostring(root, encoding="utf-8").decode("utf-8") + + +def batch_add_tags_xml_response_factory(tags, content): + root = ET.Element("tsResponse") + tag_batch = ET.SubElement(root, "tagBatch") + tags_element = ET.SubElement(tag_batch, "tags") + for tag in tags: + tag_element = ET.SubElement(tags_element, "tag") + tag_element.attrib["label"] = tag + contents_element = ET.SubElement(tag_batch, "contents") + for item in content: + content_elem = ET.SubElement(contents_element, "content") + content_elem.attrib["id"] = item.id or "some_id" + t = item.__class__.__name__.replace("Item", "") or "" + content_elem.attrib["contentType"] = t + root.attrib["xmlns"] = "http://tableau.com/api" + return ET.tostring(root, encoding="utf-8").decode("utf-8") + + +def make_workbook() -> TSC.WorkbookItem: + workbook = TSC.WorkbookItem("project", "test") + workbook._id = str(uuid.uuid4()) + return workbook + + +def make_view() -> TSC.ViewItem: + view = TSC.ViewItem() + view._id = str(uuid.uuid4()) + return view + + +def make_datasource() -> TSC.DatasourceItem: + datasource = TSC.DatasourceItem("project", "test") + datasource._id = str(uuid.uuid4()) + return datasource + + +def make_table() -> TSC.TableItem: + table = TSC.TableItem("project", "test") + table._id = str(uuid.uuid4()) + return table + + +def make_database() -> TSC.DatabaseItem: + database = TSC.DatabaseItem("project", "test") + database._id = str(uuid.uuid4()) + return database + + +def make_flow() -> TSC.FlowItem: + flow = TSC.FlowItem("project", "test") + flow._id = str(uuid.uuid4()) + return flow + + +sample_taggable_items = ( + [ + ("workbooks", make_workbook()), + ("workbooks", "some_id"), + ("views", make_view()), + ("views", "some_id"), + ("datasources", make_datasource()), + ("datasources", "some_id"), + ("tables", make_table()), + ("tables", "some_id"), + ("databases", make_database()), + ("databases", "some_id"), + ("flows", make_flow()), + ("flows", "some_id"), + ], +) + +sample_tags = [ + "a", + ["a", "b"], + ["a", "b", "c", "c"], +] + + +@pytest.mark.parametrize("endpoint_type, item", *sample_taggable_items) +@pytest.mark.parametrize("tags", sample_tags) +def test_add_tags(get_server, endpoint_type, item, tags) -> None: + add_tags_xml = add_tag_xml_response_factory(tags) + endpoint = getattr(get_server, endpoint_type) + id_ = getattr(item, "id", item) + + with requests_mock.mock() as m: + m.put( + f"{endpoint.baseurl}/{id_}/tags", + status_code=200, + text=add_tags_xml, + ) + tag_result = endpoint.add_tags(item, tags) + + if isinstance(tags, str): + tags = [tags] + assert set(tag_result) == set(tags) + + +@pytest.mark.parametrize("endpoint_type, item", *sample_taggable_items) +@pytest.mark.parametrize("tags", sample_tags) +def test_delete_tags(get_server, endpoint_type, item, tags) -> None: + add_tags_xml = add_tag_xml_response_factory(tags) + endpoint = getattr(get_server, endpoint_type) + id_ = getattr(item, "id", item) + + if isinstance(tags, str): + tags = [tags] + tag_paths = "|".join(tags) + tag_paths = f"({tag_paths})" + matcher = re.compile(rf"{endpoint.baseurl}\/{id_}\/tags\/{tag_paths}") + with requests_mock.mock() as m: + m.delete( + matcher, + status_code=200, + text=add_tags_xml, + ) + endpoint.delete_tags(item, tags) + history = m.request_history + + tag_set = set(tags) + assert len(history) == len(tag_set) + urls = {r.url.split("/")[-1] for r in history} + assert urls == tag_set + + +def test_tags_batch_add(get_server) -> None: + server = get_server + content = [make_workbook(), make_view(), make_datasource(), make_table(), make_database()] + tags = ["a", "b"] + add_tags_xml = batch_add_tags_xml_response_factory(tags, content) + with requests_mock.mock() as m: + m.put( + f"{server.tags.baseurl}:batchCreate", + status_code=200, + text=add_tags_xml, + ) + tag_result = server.tags.batch_add(tags, content) + + assert set(tag_result) == set(tags) + + +def test_tags_batch_delete(get_server) -> None: + server = get_server + content = [make_workbook(), make_view(), make_datasource(), make_table(), make_database()] + tags = ["a", "b"] + add_tags_xml = batch_add_tags_xml_response_factory(tags, content) + with requests_mock.mock() as m: + m.put( + f"{server.tags.baseurl}:batchDelete", + status_code=200, + text=add_tags_xml, + ) + tag_result = server.tags.batch_delete(tags, content) + + assert set(tag_result) == set(tags)