diff --git a/tableauserverclient/models/data_alert_item.py b/tableauserverclient/models/data_alert_item.py index a4d11ca5e..62796fd6a 100644 --- a/tableauserverclient/models/data_alert_item.py +++ b/tableauserverclient/models/data_alert_item.py @@ -9,6 +9,12 @@ from .view_item import ViewItem +from typing import List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from datetime import datetime + + class DataAlertItem(object): class Frequency: Once = "Once" @@ -18,35 +24,35 @@ class Frequency: Weekly = "Weekly" def __init__(self): - self._id = None - self._subject = None - self._creatorId = None - self._createdAt = None - self._updatedAt = None - self._frequency = None - self._public = None - self._owner_id = None - self._owner_name = None - self._view_id = None - self._view_name = None - self._workbook_id = None - self._workbook_name = None - self._project_id = None - self._project_name = None - self._recipients = None - - def __repr__(self): + self._id: Optional[str] = None + self._subject: Optional[str] = None + self._creatorId: Optional[str] = None + self._createdAt: Optional["datetime"] = None + self._updatedAt: Optional["datetime"] = None + self._frequency: Optional[str] = None + self._public: Optional[bool] = None + self._owner_id: Optional[str] = None + self._owner_name: Optional[str] = None + self._view_id: Optional[str] = None + self._view_name: Optional[str] = None + self._workbook_id: Optional[str] = None + self._workbook_name: Optional[str] = None + self._project_id: Optional[str] = None + self._project_name: Optional[str] = None + self._recipients: Optional[List[str]] = None + + def __repr__(self) -> str: return "".format( **self.__dict__ ) @property - def id(self): + def id(self) -> Optional[str]: return self._id @property - def subject(self): + def subject(self) -> Optional[str]: return self._subject @subject.setter @@ -55,69 +61,69 @@ def subject(self, value): self._subject = value @property - def frequency(self): + def frequency(self) -> Optional[str]: return self._frequency @frequency.setter @property_is_enum(Frequency) - def frequency(self, value): + def frequency(self, value: str) -> None: self._frequency = value @property - def public(self): + def public(self) -> Optional[bool]: return self._public @public.setter @property_is_boolean - def public(self, value): + def public(self, value: bool) -> None: self._public = value @property - def creatorId(self): + def creatorId(self) -> Optional[str]: return self._creatorId @property - def recipients(self): + def recipients(self) -> List[str]: return self._recipients or list() @property - def createdAt(self): + def createdAt(self) -> Optional["datetime"]: return self._createdAt @property - def updatedAt(self): + def updatedAt(self) -> Optional["datetime"]: return self._updatedAt @property - def owner_id(self): + def owner_id(self) -> Optional[str]: return self._owner_id @property - def owner_name(self): + def owner_name(self) -> Optional[str]: return self._owner_name @property - def view_id(self): + def view_id(self) -> Optional[str]: return self._view_id @property - def view_name(self): + def view_name(self) -> Optional[str]: return self._view_name @property - def workbook_id(self): + def workbook_id(self) -> Optional[str]: return self._workbook_id @property - def workbook_name(self): + def workbook_name(self) -> Optional[str]: return self._workbook_name @property - def project_id(self): + def project_id(self) -> Optional[str]: return self._project_id @property - def project_name(self): + def project_name(self) -> Optional[str]: return self._project_name def _set_values( @@ -173,7 +179,7 @@ def _set_values( self._recipients = recipients @classmethod - def from_response(cls, resp, ns): + def from_response(cls, resp, ns) -> List["DataAlertItem"]: all_alert_items = list() parsed_response = ET.fromstring(resp) all_alert_xml = parsed_response.findall(".//t:dataAlert", namespaces=ns) diff --git a/tableauserverclient/models/workbook_item.py b/tableauserverclient/models/workbook_item.py index 8e342686c..f7ba75c73 100644 --- a/tableauserverclient/models/workbook_item.py +++ b/tableauserverclient/models/workbook_item.py @@ -12,7 +12,7 @@ import copy import uuid -from typing import Dict, List, Optional, Set, TYPE_CHECKING +from typing import Dict, List, Optional, Set, TYPE_CHECKING, Union if TYPE_CHECKING: from .connection_item import ConnectionItem diff --git a/tableauserverclient/server/endpoint/data_alert_endpoint.py b/tableauserverclient/server/endpoint/data_alert_endpoint.py index d2e5d55a7..78d28a434 100644 --- a/tableauserverclient/server/endpoint/data_alert_endpoint.py +++ b/tableauserverclient/server/endpoint/data_alert_endpoint.py @@ -1,7 +1,5 @@ from .endpoint import api, Endpoint from .exceptions import MissingRequiredFieldError -from .permissions_endpoint import _PermissionsEndpoint -from .default_permissions_endpoint import _DefaultPermissionsEndpoint from .. import RequestFactory, DataAlertItem, PaginationItem, UserItem @@ -9,17 +7,24 @@ logger = logging.getLogger("tableau.endpoint.dataAlerts") +from typing import List, Optional, TYPE_CHECKING, Tuple, Union + + +if TYPE_CHECKING: + from ..server import Server + from ..request_options import RequestOptions + class DataAlerts(Endpoint): - def __init__(self, parent_srv): + def __init__(self, parent_srv: "Server") -> None: super(DataAlerts, self).__init__(parent_srv) @property - def baseurl(self): + def baseurl(self) -> str: return "{0}/sites/{1}/dataAlerts".format(self.parent_srv.baseurl, self.parent_srv.site_id) @api(version="3.2") - def get(self, req_options=None): + def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[DataAlertItem], PaginationItem]: logger.info("Querying all dataAlerts on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -29,7 +34,7 @@ def get(self, req_options=None): # Get 1 dataAlert @api(version="3.2") - def get_by_id(self, dataAlert_id): + def get_by_id(self, dataAlert_id: str) -> DataAlertItem: if not dataAlert_id: error = "dataAlert ID undefined." raise ValueError(error) @@ -39,8 +44,13 @@ def get_by_id(self, dataAlert_id): return DataAlertItem.from_response(server_response.content, self.parent_srv.namespace)[0] @api(version="3.2") - def delete(self, dataAlert): - dataAlert_id = getattr(dataAlert, "id", dataAlert) + def delete(self, dataAlert: Union[DataAlertItem, str]) -> None: + if isinstance(dataAlert, DataAlertItem): + dataAlert_id = dataAlert.id + elif isinstance(dataAlert, str): + dataAlert_id = dataAlert + else: + raise TypeError("dataAlert should be a DataAlertItem or a string of an id.") if not dataAlert_id: error = "Dataalert ID undefined." raise ValueError(error) @@ -50,9 +60,19 @@ def delete(self, dataAlert): logger.info("Deleted single dataAlert (ID: {0})".format(dataAlert_id)) @api(version="3.2") - def delete_user_from_alert(self, dataAlert, user): - dataAlert_id = getattr(dataAlert, "id", dataAlert) - user_id = getattr(user, "id", user) + def delete_user_from_alert(self, dataAlert: Union[DataAlertItem, str], user: Union[UserItem, str]) -> None: + if isinstance(dataAlert, DataAlertItem): + dataAlert_id = dataAlert.id + elif isinstance(dataAlert, str): + dataAlert_id = dataAlert + else: + raise TypeError("dataAlert should be a DataAlertItem or a string of an id.") + if isinstance(user, UserItem): + user_id = user.id + elif isinstance(user, str): + user_id = user + else: + raise TypeError("user should be a UserItem or a string of an id.") if not dataAlert_id: error = "Dataalert ID undefined." raise ValueError(error) @@ -65,11 +85,16 @@ def delete_user_from_alert(self, dataAlert, user): logger.info("Deleted User (ID {0}) from dataAlert (ID: {1})".format(user_id, dataAlert_id)) @api(version="3.2") - def add_user_to_alert(self, dataAlert_item, user): + def add_user_to_alert(self, dataAlert_item: DataAlertItem, user: Union[UserItem, str]) -> UserItem: + if isinstance(user, UserItem): + user_id = user.id + elif isinstance(user, str): + user_id = user + else: + raise TypeError("user should be a UserItem or a string of an id.") if not dataAlert_item.id: error = "Dataalert item missing ID." raise MissingRequiredFieldError(error) - user_id = getattr(user, "id", user) if not user_id: error = "User ID undefined." raise ValueError(error) @@ -77,11 +102,11 @@ def add_user_to_alert(self, dataAlert_item, user): update_req = RequestFactory.DataAlert.add_user_to_alert(dataAlert_item, user_id) server_response = self.post_request(url, update_req) logger.info("Added user (ID {0}) to dataAlert item (ID: {1})".format(user_id, dataAlert_item.id)) - user = UserItem.from_response(server_response.content, self.parent_srv.namespace)[0] - return user + added_user = UserItem.from_response(server_response.content, self.parent_srv.namespace)[0] + return added_user @api(version="3.2") - def update(self, dataAlert_item): + def update(self, dataAlert_item: DataAlertItem) -> DataAlertItem: if not dataAlert_item.id: error = "Dataalert item missing ID." raise MissingRequiredFieldError(error) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 1841cd06e..16a11a018 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -5,6 +5,11 @@ from ..models import TaskItem, UserItem, GroupItem, PermissionsRule, FavoriteItem +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..models import DataAlertItem + def _add_multipart(parts): mime_multipart_parts = list() @@ -87,22 +92,26 @@ def update_req(self, column_item): class DataAlertRequest(object): - def add_user_to_alert(self, alert_item, user_id): + def add_user_to_alert(self, alert_item: "DataAlertItem", user_id: str) -> bytes: xml_request = ET.Element("tsRequest") user_element = ET.SubElement(xml_request, "user") user_element.attrib["id"] = user_id return ET.tostring(xml_request) - def update_req(self, alert_item): + def update_req(self, alert_item: "DataAlertItem") -> bytes: xml_request = ET.Element("tsRequest") dataAlert_element = ET.SubElement(xml_request, "dataAlert") - dataAlert_element.attrib["subject"] = alert_item.subject - dataAlert_element.attrib["frequency"] = alert_item.frequency.lower() - dataAlert_element.attrib["public"] = alert_item.public + if alert_item.subject is not None: + dataAlert_element.attrib["subject"] = alert_item.subject + if alert_item.frequency is not None: + dataAlert_element.attrib["frequency"] = alert_item.frequency.lower() + if alert_item.public is not None: + dataAlert_element.attrib["public"] = str(alert_item.public).lower() owner = ET.SubElement(dataAlert_element, "owner") - owner.attrib["id"] = alert_item.owner_id + if alert_item.owner_id is not None: + owner.attrib["id"] = alert_item.owner_id return ET.tostring(xml_request) diff --git a/test/test_dataalert.py b/test/test_dataalert.py index a2810ce45..059046123 100644 --- a/test/test_dataalert.py +++ b/test/test_dataalert.py @@ -15,7 +15,7 @@ class DataAlertTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.server = TSC.Server("http://test") # Fake signin @@ -25,7 +25,7 @@ def setUp(self): self.baseurl = self.server.data_alerts.baseurl - def test_get(self): + def test_get(self) -> None: response_xml = read_xml_asset(GET_XML) with requests_mock.mock() as m: m.get(self.baseurl, text=response_xml) @@ -48,7 +48,7 @@ def test_get(self): self.assertEqual("5241e88d-d384-4fd7-9c2f-648b5247efc5", all_alerts[0].project_id) self.assertEqual("Default", all_alerts[0].project_name) - def test_get_by_id(self): + def test_get_by_id(self) -> None: response_xml = read_xml_asset(GET_BY_ID_XML) with requests_mock.mock() as m: m.get(self.baseurl + "/5ea59b45-e497-5673-8809-bfe213236f75", text=response_xml) @@ -58,7 +58,7 @@ def test_get_by_id(self): self.assertEqual(len(alert.recipients), 1) self.assertEqual(alert.recipients[0], "dd2239f6-ddf1-4107-981a-4cf94e415794") - def test_update(self): + def test_update(self) -> None: response_xml = read_xml_asset(UPDATE_XML) with requests_mock.mock() as m: m.put(self.baseurl + "/5ea59b45-e497-5673-8809-bfe213236f75", text=response_xml) @@ -66,7 +66,7 @@ def test_update(self): single_alert._id = "5ea59b45-e497-5673-8809-bfe213236f75" single_alert._subject = "Data Alert test" single_alert._frequency = "Daily" - single_alert._public = "true" + single_alert._public = True single_alert._owner_id = "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7" single_alert = self.server.data_alerts.update(single_alert) @@ -86,7 +86,7 @@ def test_update(self): self.assertEqual("5241e88d-d384-4fd7-9c2f-648b5247efc5", single_alert.project_id) self.assertEqual("Default", single_alert.project_name) - def test_add_user_to_alert(self): + def test_add_user_to_alert(self) -> None: response_xml = read_xml_asset(ADD_USER_TO_ALERT) single_alert = TSC.DataAlertItem() single_alert._id = "0448d2ed-590d-4fa0-b272-a2a8a24555b5" @@ -102,12 +102,12 @@ def test_add_user_to_alert(self): self.assertEqual(out_user.name, in_user.name) self.assertEqual(out_user.site_role, in_user.site_role) - def test_delete(self): + def test_delete(self) -> None: with requests_mock.mock() as m: m.delete(self.baseurl + "/0448d2ed-590d-4fa0-b272-a2a8a24555b5", status_code=204) self.server.data_alerts.delete("0448d2ed-590d-4fa0-b272-a2a8a24555b5") - def test_delete_user_from_alert(self): + def test_delete_user_from_alert(self) -> None: alert_id = "5ea59b45-e497-5673-8809-bfe213236f75" user_id = "5de011f8-5aa9-4d5b-b991-f462c8dd6bb7" with requests_mock.mock() as m: