diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index 61476132f..2b4039fec 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -1,6 +1,6 @@
 name: Python tests

-on: [push]
+on: [push, pull_request]

 jobs:
   build:
@@ -24,13 +24,11 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install -e .[test]
-        pip install mypy
     - name: Test with pytest
       run: |
         pytest test
-    - name: Run Mypy but allow failures
+    - name: Run Mypy tests
       run: |
-        mypy --show-error-codes --disable-error-code misc tableauserverclient
-      continue-on-error: true
+        mypy --show-error-codes --disable-error-code misc --disable-error-code import tableauserverclient test
diff --git a/setup.cfg b/setup.cfg
index 6136b814a..1debabe18 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,3 +26,6 @@ smoke=pytest
 [tool:pytest]
 testpaths = test smoke
 addopts = --junitxml=./test.junit.xml
+
+[mypy]
+ignore_missing_imports = True
diff --git a/setup.py b/setup.py
index 8b374f0ce..429e7c09d 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
 # This makes work easier for offline installs or low bandwidth machines
 needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)
 pytest_runner = ['pytest-runner'] if needs_pytest else []
-test_requirements = ['mock', 'pycodestyle', 'pytest', 'requests-mock>=1.0,<2.0']
+test_requirements = ['mock', 'pycodestyle', 'pytest', 'requests-mock>=1.0,<2.0', 'mypy==0.910']

 setup(
     name='tableauserverclient',
diff --git a/tableauserverclient/_version.py b/tableauserverclient/_version.py
index 1737a980a..62c3ce236 100644
--- a/tableauserverclient/_version.py
+++ b/tableauserverclient/_version.py
@@ -51,7 +51,7 @@ class NotThisMethod(Exception):
     """Exception raised if a method is not valid for the current scenario."""


-LONG_VERSION_PY = {}
+LONG_VERSION_PY = {} # type: ignore

 HANDLERS = {}
diff --git a/tableauserverclient/exponential_backoff.py b/tableauserverclient/exponential_backoff.py
index 2b3ded109..69ffdc96b 100644
--- a/tableauserverclient/exponential_backoff.py
+++ b/tableauserverclient/exponential_backoff.py
@@ -1,12 +1,12 @@
 import time

 # Polling for server-side events (such as job completion) uses exponential backoff for the sleep intervals between polls
-ASYNC_POLL_MIN_INTERVAL=0.5
-ASYNC_POLL_MAX_INTERVAL=30
-ASYNC_POLL_BACKOFF_FACTOR=1.4
+ASYNC_POLL_MIN_INTERVAL = 0.5
+ASYNC_POLL_MAX_INTERVAL = 30
+ASYNC_POLL_BACKOFF_FACTOR = 1.4


-class ExponentialBackoffTimer():
+class ExponentialBackoffTimer:
     def __init__(self, *, timeout=None):
         self.start_time = time.time()
         self.timeout = timeout
@@ -15,7 +15,7 @@ def __init__(self, *, timeout=None):
     def sleep(self):
         max_sleep_time = ASYNC_POLL_MAX_INTERVAL
         if self.timeout is not None:
-            elapsed = (time.time() - self.start_time)
+            elapsed = time.time() - self.start_time
             if elapsed >= self.timeout:
                 raise TimeoutError(f"Timeout after {elapsed} seconds waiting for asynchronous event")
             remaining_time = self.timeout - elapsed
@@ -27,4 +27,4 @@ def sleep(self):
         max_sleep_time = max(max_sleep_time, ASYNC_POLL_MIN_INTERVAL)

         time.sleep(min(self.current_sleep_interval, max_sleep_time))
-        self.current_sleep_interval *= ASYNC_POLL_BACKOFF_FACTOR
\ No newline at end of file
+        self.current_sleep_interval *= ASYNC_POLL_BACKOFF_FACTOR
diff --git a/tableauserverclient/models/database_item.py b/tableauserverclient/models/database_item.py
index 4934af81b..514bf92bc 100644
--- a/tableauserverclient/models/database_item.py
+++ b/tableauserverclient/models/database_item.py
@@ -53,6 +53,11 @@ def dqws(self):
     def
content_permissions(self): return self._content_permissions + @content_permissions.setter + @property_is_enum(ContentPermissions) + def content_permissions(self, value): + self._content_permissions = value + @property def permissions(self): if self._permissions is None: @@ -67,11 +72,6 @@ def default_table_permissions(self): raise UnpopulatedPropertyError(error) return self._default_table_permissions() - @content_permissions.setter - @property_is_enum(ContentPermissions) - def content_permissions(self, value): - self._content_permissions = value - @property def id(self): return self._id diff --git a/tableauserverclient/models/datasource_item.py b/tableauserverclient/models/datasource_item.py index 5b23341d0..665be9db1 100644 --- a/tableauserverclient/models/datasource_item.py +++ b/tableauserverclient/models/datasource_item.py @@ -9,6 +9,13 @@ from ..datetime_helpers import parse_datetime import copy +from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from .permissions_item import PermissionsRule + from .connection_item import ConnectionItem + import datetime + class DatasourceItem(object): class AskDataEnablement: @@ -16,79 +23,81 @@ class AskDataEnablement: Disabled = "Disabled" SiteDefault = "SiteDefault" - def __init__(self, project_id, name=None): + def __init__(self, project_id: str, name: str = None) -> None: self._ask_data_enablement = None self._certified = None self._certification_note = None self._connections = None - self._content_url = None + self._content_url: Optional[str] = None self._created_at = None self._datasource_type = None self._description = None self._encrypt_extracts = None self._has_extracts = None - self._id = None - self._initial_tags = set() - self._project_name = None + self._id: Optional[str] = None + self._initial_tags: Set = set() + self._project_name: Optional[str] = None self._updated_at = None self._use_remote_query_agent = None self._webpage_url = None self.description = None self.name = name - self.owner_id = None + self.owner_id: Optional[str] = None self.project_id = project_id - self.tags = set() + self.tags: Set[str] = set() self._permissions = None self._data_quality_warnings = None + return None + @property - def ask_data_enablement(self): + def ask_data_enablement(self) -> Optional["DatasourceItem.AskDataEnablement"]: return self._ask_data_enablement @ask_data_enablement.setter @property_is_enum(AskDataEnablement) - def ask_data_enablement(self, value): + def ask_data_enablement(self, value: Optional["DatasourceItem.AskDataEnablement"]): self._ask_data_enablement = value @property - def connections(self): + def connections(self) -> Optional[List["ConnectionItem"]]: if self._connections is None: error = "Datasource item must be populated with connections first." raise UnpopulatedPropertyError(error) return self._connections() @property - def permissions(self): + def permissions(self) -> Optional[List["PermissionsRule"]]: if self._permissions is None: error = "Project item must be populated with permissions first." 
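The model annotations above rely on typing-only imports: names such as ConnectionItem, PermissionsRule and datetime are imported under `if TYPE_CHECKING:` and referenced as quoted strings, so the hints cost nothing at runtime and cannot introduce import cycles. A minimal standalone sketch of that pattern; the Item class is illustrative, not part of the diff:

    from typing import Optional, TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by the type checker; skipped at runtime, so there is
        # no import cost and no risk of a circular import.
        from datetime import datetime


    class Item:
        """Toy stand-in for a model class such as DatasourceItem."""

        def __init__(self) -> None:
            self._created_at: Optional["datetime"] = None

        @property
        def created_at(self) -> Optional["datetime"]:
            # The quoted annotation is only resolved by the checker, so it works
            # even though datetime is never imported at runtime.
            return self._created_at

The same idea is what lets the endpoint modules later in this diff refer to "Server" and "RequestOptions" without importing them eagerly.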
raise UnpopulatedPropertyError(error) return self._permissions() @property - def content_url(self): + def content_url(self) -> Optional[str]: return self._content_url @property - def created_at(self): + def created_at(self) -> Optional["datetime.datetime"]: return self._created_at @property - def certified(self): + def certified(self) -> Optional[bool]: return self._certified @certified.setter @property_not_nullable @property_is_boolean - def certified(self, value): + def certified(self, value: Optional[bool]): self._certified = value @property - def certification_note(self): + def certification_note(self) -> Optional[str]: return self._certification_note @certification_note.setter - def certification_note(self, value): + def certification_note(self, value: Optional[str]): self._certification_note = value @property @@ -97,7 +106,7 @@ def encrypt_extracts(self): @encrypt_extracts.setter @property_is_boolean - def encrypt_extracts(self, value): + def encrypt_extracts(self, value: Optional[bool]): self._encrypt_extracts = value @property @@ -108,53 +117,53 @@ def dqws(self): return self._data_quality_warnings() @property - def has_extracts(self): + def has_extracts(self) -> Optional[bool]: return self._has_extracts @property - def id(self): + def id(self) -> Optional[str]: return self._id @property - def project_id(self): + def project_id(self) -> str: return self._project_id @project_id.setter @property_not_nullable - def project_id(self, value): + def project_id(self, value: str): self._project_id = value @property - def project_name(self): + def project_name(self) -> Optional[str]: return self._project_name @property - def datasource_type(self): + def datasource_type(self) -> Optional[str]: return self._datasource_type @property - def description(self): + def description(self) -> Optional[str]: return self._description @description.setter - def description(self, value): + def description(self, value: str): self._description = value @property - def updated_at(self): + def updated_at(self) -> Optional["datetime.datetime"]: return self._updated_at @property - def use_remote_query_agent(self): + def use_remote_query_agent(self) -> Optional[bool]: return self._use_remote_query_agent @use_remote_query_agent.setter @property_is_boolean - def use_remote_query_agent(self, value): + def use_remote_query_agent(self, value: bool): self._use_remote_query_agent = value @property - def webpage_url(self): + def webpage_url(self) -> Optional[str]: return self._webpage_url def _set_connections(self, connections): @@ -271,7 +280,7 @@ def _set_values( self._webpage_url = webpage_url @classmethod - def from_response(cls, resp, ns): + def from_response(cls, resp: str, ns: Dict) -> List["DatasourceItem"]: all_datasource_items = list() parsed_response = ET.fromstring(resp) all_datasource_xml = parsed_response.findall(".//t:datasource", namespaces=ns) @@ -322,16 +331,16 @@ def from_response(cls, resp, ns): return all_datasource_items @staticmethod - def _parse_element(datasource_xml, ns): - id_ = datasource_xml.get('id', None) - name = datasource_xml.get('name', None) - datasource_type = datasource_xml.get('type', None) - description = datasource_xml.get('description', None) - content_url = datasource_xml.get('contentUrl', None) - created_at = parse_datetime(datasource_xml.get('createdAt', None)) - updated_at = parse_datetime(datasource_xml.get('updatedAt', None)) - certification_note = datasource_xml.get('certificationNote', None) - certified = str(datasource_xml.get('isCertified', None)).lower() == 'true' + def 
_parse_element(datasource_xml: ET.Element, ns: Dict) -> Tuple: + id_ = datasource_xml.get("id", None) + name = datasource_xml.get("name", None) + datasource_type = datasource_xml.get("type", None) + description = datasource_xml.get("description", None) + content_url = datasource_xml.get("contentUrl", None) + created_at = parse_datetime(datasource_xml.get("createdAt", None)) + updated_at = parse_datetime(datasource_xml.get("updatedAt", None)) + certification_note = datasource_xml.get("certificationNote", None) + certified = str(datasource_xml.get("isCertified", None)).lower() == "true" certification_note = datasource_xml.get("certificationNote", None) certified = str(datasource_xml.get("isCertified", None)).lower() == "true" content_url = datasource_xml.get("contentUrl", None) diff --git a/tableauserverclient/models/dqw_item.py b/tableauserverclient/models/dqw_item.py index a7f8ec9cb..3285e3022 100644 --- a/tableauserverclient/models/dqw_item.py +++ b/tableauserverclient/models/dqw_item.py @@ -80,14 +80,6 @@ def severe(self): def severe(self, value): self._severe = value - @property - def active(self): - return self._active - - @active.setter - def active(self, value): - self._active = value - @property def created_at(self): return self._created_at diff --git a/tableauserverclient/models/job_item.py b/tableauserverclient/models/job_item.py index 2a8b6b509..f8c00b555 100644 --- a/tableauserverclient/models/job_item.py +++ b/tableauserverclient/models/job_item.py @@ -8,11 +8,11 @@ class FinishCode: Status codes as documented on https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref_jobs_tasks_and_schedules.htm#query_job """ + Success = 0 Failed = 1 Cancelled = 2 - def __init__( self, id_, diff --git a/tableauserverclient/models/project_item.py b/tableauserverclient/models/project_item.py index 3a7d01143..c6525a2ca 100644 --- a/tableauserverclient/models/project_item.py +++ b/tableauserverclient/models/project_item.py @@ -29,6 +29,11 @@ def __init__(self, name, description=None, content_permissions=None, parent_id=N def content_permissions(self): return self._content_permissions + @content_permissions.setter + @property_is_enum(ContentPermissions) + def content_permissions(self, value): + self._content_permissions = value + @property def permissions(self): if self._permissions is None: @@ -57,11 +62,6 @@ def default_flow_permissions(self): raise UnpopulatedPropertyError(error) return self._default_flow_permissions() - @content_permissions.setter - @property_is_enum(ContentPermissions) - def content_permissions(self, value): - self._content_permissions = value - @property def id(self): return self._id diff --git a/tableauserverclient/models/workbook_item.py b/tableauserverclient/models/workbook_item.py index 9c7e2022e..ef0dc6f6f 100644 --- a/tableauserverclient/models/workbook_item.py +++ b/tableauserverclient/models/workbook_item.py @@ -12,15 +12,29 @@ import copy import uuid +from typing import ( + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Union +) + +if TYPE_CHECKING: + from .connection_item import ConnectionItem + from .permissions_item import PermissionsRule + import datetime + class WorkbookItem(object): - def __init__(self, project_id, name=None, show_tabs=False): + def __init__(self, project_id: str, name: str = None, show_tabs: bool = False) -> None: self._connections = None self._content_url = None self._webpage_url = None self._created_at = None - self._id = None - self._initial_tags = set() + self._id: Optional[str] = None + self._initial_tags: set = set() 
self._pdf = None self._preview_image = None self._project_name = None @@ -29,10 +43,10 @@ def __init__(self, project_id, name=None, show_tabs=False): self._views = None self.name = name self._description = None - self.owner_id = None + self.owner_id: Optional[str] = None self.project_id = project_id self.show_tabs = show_tabs - self.tags = set() + self.tags: Set[str] = set() self.data_acceleration_config = { "acceleration_enabled": None, "accelerate_now": None, @@ -41,38 +55,40 @@ def __init__(self, project_id, name=None, show_tabs=False): } self._permissions = None + return None + @property - def connections(self): + def connections(self) -> List["ConnectionItem"]: if self._connections is None: error = "Workbook item must be populated with connections first." raise UnpopulatedPropertyError(error) return self._connections() @property - def permissions(self): + def permissions(self) -> List["PermissionsRule"]: if self._permissions is None: error = "Workbook item must be populated with permissions first." raise UnpopulatedPropertyError(error) return self._permissions() @property - def content_url(self): + def content_url(self) -> Optional[str]: return self._content_url @property - def webpage_url(self): + def webpage_url(self) -> Optional[str]: return self._webpage_url @property - def created_at(self): + def created_at(self) -> Optional["datetime.datetime"]: return self._created_at @property - def description(self): + def description(self) -> Optional[str]: return self._description @property - def id(self): + def id(self) -> Optional[str]: return self._id @property @@ -90,25 +106,25 @@ def preview_image(self): return self._preview_image() @property - def project_id(self): + def project_id(self) -> Optional[str]: return self._project_id @project_id.setter @property_not_nullable - def project_id(self, value): + def project_id(self, value: str): self._project_id = value @property - def project_name(self): + def project_name(self) -> Optional[str]: return self._project_name @property - def show_tabs(self): + def show_tabs(self) -> bool: return self._show_tabs @show_tabs.setter @property_is_boolean - def show_tabs(self, value): + def show_tabs(self, value: bool): self._show_tabs = value @property @@ -116,11 +132,11 @@ def size(self): return self._size @property - def updated_at(self): + def updated_at(self) -> Optional["datetime.datetime"]: return self._updated_at @property - def views(self): + def views(self) -> List[ViewItem]: # Views can be set in an initial workbook response OR by a call # to Server. 
Without getting too fancy, I think we can rely on # returning a list from the response, until they call @@ -253,7 +269,7 @@ def _set_values( self.data_acceleration_config = data_acceleration_config @classmethod - def from_response(cls, resp, ns): + def from_response(cls, resp: str, ns: Dict[str, str]) -> List["WorkbookItem"]: all_workbook_items = list() parsed_response = ET.fromstring(resp) all_workbook_xml = parsed_response.findall(".//t:workbook", namespaces=ns) @@ -394,5 +410,5 @@ def parse_data_acceleration_config(data_acceleration_elem): # Used to convert string represented boolean to a boolean type -def string_to_bool(s): +def string_to_bool(s: str) -> bool: return s.lower() == "true" diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index c031004e0..f7a2a4405 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -3,7 +3,7 @@ from .permissions_endpoint import _PermissionsEndpoint from .dqw_endpoint import _DataQualityWarningEndpoint from .resource_tagger import _ResourceTagger -from .. import RequestFactory, DatasourceItem, PaginationItem, ConnectionItem +from .. import RequestFactory, DatasourceItem, PaginationItem, ConnectionItem, RequestOptions from ..query import QuerySet from ...filesys_helpers import ( to_filename, @@ -12,7 +12,9 @@ get_file_object_size, ) from ...models.job_item import JobItem +from ...models import ConnectionCredentials +import io import os import logging import copy @@ -20,6 +22,19 @@ from contextlib import closing import json +from pathlib import Path +from typing import ( + List, + Mapping, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +io_types = (io.BytesIO, io.BufferedReader) + # The maximum size of a file that can be published in a single request is 64MB FILESIZE_LIMIT = 1024 * 1024 * 64 # 64MB @@ -27,21 +42,31 @@ logger = logging.getLogger("tableau.endpoint.datasources") +if TYPE_CHECKING: + from ..server import Server + from ...models import PermissionsRule + +FilePath = Union[str, os.PathLike] +FileObject = Union[io.BufferedReader, io.BytesIO] +PathOrFile = Union[FilePath, FileObject] + class Datasources(QuerysetEndpoint): - def __init__(self, parent_srv): + def __init__(self, parent_srv: "Server") -> None: super(Datasources, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) self._data_quality_warnings = _DataQualityWarningEndpoint(self.parent_srv, "datasource") + return None + @property - def baseurl(self): + def baseurl(self) -> str: return "{0}/sites/{1}/datasources".format(self.parent_srv.baseurl, self.parent_srv.site_id) # Get all datasources @api(version="2.0") - def get(self, req_options=None): + def get(self, req_options: RequestOptions = None) -> Tuple[List[DatasourceItem], PaginationItem]: logger.info("Querying all datasources on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -51,7 +76,7 @@ def get(self, req_options=None): # Get 1 datasource by id @api(version="2.0") - def get_by_id(self, datasource_id): + def get_by_id(self, datasource_id: str) -> DatasourceItem: if not datasource_id: error = "Datasource ID undefined." 
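The FilePath/FileObject/PathOrFile aliases and the io_types tuple introduced above are what let publish(), further down in this file's diff, replace its old try/except TypeError flow with explicit isinstance checks. A rough sketch of that dispatch under the same aliases; read_contents is a hypothetical helper, not code from the diff:

    import io
    import os
    from pathlib import Path
    from typing import Union

    FilePath = Union[str, os.PathLike]
    FileObject = Union[io.BufferedReader, io.BytesIO]
    PathOrFile = Union[FilePath, FileObject]


    def read_contents(file: PathOrFile) -> bytes:
        """Hypothetical helper mirroring the path-vs-file-object dispatch in publish()."""
        if isinstance(file, (str, os.PathLike)):
            # A path on disk: open and read it.
            return Path(file).read_bytes()
        elif isinstance(file, (io.BufferedReader, io.BytesIO)):
            # An already-open binary file object: read it directly.
            return file.read()
        raise TypeError("file should be a filepath or file object.")

Checking the type up front also gives mypy enough information to narrow `file` inside each branch, which the exception-driven version could not do.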
raise ValueError(error) @@ -62,7 +87,7 @@ def get_by_id(self, datasource_id): # Populate datasource item's connections @api(version="2.0") - def populate_connections(self, datasource_item): + def populate_connections(self, datasource_item: DatasourceItem) -> None: if not datasource_item.id: error = "Datasource item missing ID. Datasource must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -81,7 +106,7 @@ def _get_datasource_connections(self, datasource_item, req_options=None): # Delete 1 datasource by id @api(version="2.0") - def delete(self, datasource_id): + def delete(self, datasource_id: str) -> None: if not datasource_id: error = "Datasource ID undefined." raise ValueError(error) @@ -93,7 +118,13 @@ def delete(self, datasource_id): @api(version="2.0") @parameter_added_in(no_extract="2.5") @parameter_added_in(include_extract="2.5") - def download(self, datasource_id, filepath=None, include_extract=True, no_extract=None): + def download( + self, + datasource_id: str, + filepath: FilePath = None, + include_extract: bool = True, + no_extract: Optional[bool] = None, + ) -> str: if not datasource_id: error = "Datasource ID undefined." raise ValueError(error) @@ -126,7 +157,7 @@ def download(self, datasource_id, filepath=None, include_extract=True, no_extrac # Update datasource @api(version="2.0") - def update(self, datasource_item): + def update(self, datasource_item: DatasourceItem) -> DatasourceItem: if not datasource_item.id: error = "Datasource item missing ID. Datasource must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -143,7 +174,7 @@ def update(self, datasource_item): # Update datasource connections @api(version="2.3") - def update_connection(self, datasource_item, connection_item): + def update_connection(self, datasource_item: DatasourceItem, connection_item: ConnectionItem) -> ConnectionItem: url = "{0}/{1}/connections/{2}".format(self.baseurl, datasource_item.id, connection_item.id) update_req = RequestFactory.Connection.update_req(connection_item) @@ -156,7 +187,7 @@ def update_connection(self, datasource_item, connection_item): return connection @api(version="2.8") - def refresh(self, datasource_item): + def refresh(self, datasource_item: DatasourceItem) -> JobItem: id_ = getattr(datasource_item, "id", datasource_item) url = "{0}/{1}/refresh".format(self.baseurl, id_) empty_req = RequestFactory.Empty.empty_req() @@ -165,7 +196,7 @@ def refresh(self, datasource_item): return new_job @api(version="3.5") - def create_extract(self, datasource_item, encrypt=False): + def create_extract(self, datasource_item: DatasourceItem, encrypt: bool = False) -> JobItem: id_ = getattr(datasource_item, "id", datasource_item) url = "{0}/{1}/createExtract?encrypt={2}".format(self.baseurl, id_, encrypt) empty_req = RequestFactory.Empty.empty_req() @@ -174,7 +205,7 @@ def create_extract(self, datasource_item, encrypt=False): return new_job @api(version="3.5") - def delete_extract(self, datasource_item): + def delete_extract(self, datasource_item: DatasourceItem) -> None: id_ = getattr(datasource_item, "id", datasource_item) url = "{0}/{1}/deleteExtract".format(self.baseurl, id_) empty_req = RequestFactory.Empty.empty_req() @@ -186,16 +217,15 @@ def delete_extract(self, datasource_item): @parameter_added_in(as_job="3.0") def publish( self, - datasource_item, - file, - mode, - connection_credentials=None, - connections=None, - as_job=False, - ): - - try: - + datasource_item: DatasourceItem, + file: PathOrFile, + mode: str, + 
connection_credentials: ConnectionCredentials = None, + connections: Sequence[ConnectionItem] = None, + as_job: bool = False, + ) -> Union[DatasourceItem, JobItem]: + + if isinstance(file, (os.PathLike, str)): if not os.path.isfile(file): error = "File path does not lead to an existing file." raise IOError(error) @@ -211,7 +241,7 @@ def publish( error = "Only {} files can be published as datasources.".format(", ".join(ALLOWED_FILE_EXTENSIONS)) raise ValueError(error) - except TypeError: + elif isinstance(file, io_types): if not datasource_item.name: error = "Datasource item must have a name when passing a file object" @@ -229,6 +259,9 @@ def publish( filename = "{}.{}".format(datasource_item.name, file_extension) file_size = get_file_object_size(file) + else: + raise TypeError("file should be a filepath or file object.") + if not mode or not hasattr(self.parent_srv.PublishMode, mode): error = "Invalid mode defined." raise ValueError(error) @@ -252,11 +285,13 @@ def publish( else: logger.info("Publishing {0} to server".format(filename)) - try: + if isinstance(file, (Path, str)): with open(file, "rb") as f: file_contents = f.read() - except TypeError: + elif isinstance(file, io_types): file_contents = file.read() + else: + raise TypeError("file should be a filepath or file object.") xml_request, content_type = RequestFactory.Datasource.publish_req( datasource_item, @@ -284,7 +319,14 @@ def publish( return new_datasource @api(version="3.13") - def update_hyper_data(self, datasource_or_connection_item, *, request_id, actions, payload = None): + def update_hyper_data( + self, + datasource_or_connection_item: Union[DatasourceItem, ConnectionItem, str], + *, + request_id: str, + actions: Sequence[Mapping], + payload: Optional[FilePath] = None + ) -> JobItem: if isinstance(datasource_or_connection_item, DatasourceItem): datasource_id = datasource_or_connection_item.id url = "{0}/{1}/data".format(self.baseurl, datasource_id) @@ -312,7 +354,7 @@ def update_hyper_data(self, datasource_or_connection_item, *, request_id, action return new_job @api(version="2.0") - def populate_permissions(self, item): + def populate_permissions(self, item: DatasourceItem) -> None: self._permissions.populate(item) @api(version="2.0") @@ -327,11 +369,11 @@ def update_permission(self, item, permission_item): self._permissions.update(item, permission_item) @api(version="2.0") - def update_permissions(self, item, permission_item): + def update_permissions(self, item: DatasourceItem, permission_item: List["PermissionsRule"]) -> None: self._permissions.update(item, permission_item) @api(version="2.0") - def delete_permission(self, item, capability_item): + def delete_permission(self, item: DatasourceItem, capability_item: "PermissionsRule") -> None: self._permissions.delete(item, capability_item) @api(version="3.5") diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 31291abc9..9cc0a6050 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -9,10 +9,7 @@ from ..query import QuerySet import logging -try: - from distutils2.version import NormalizedVersion as Version -except ImportError: - from distutils.version import LooseVersion as Version +from distutils.version import LooseVersion as Version logger = logging.getLogger("tableau.endpoint") diff --git a/tableauserverclient/server/endpoint/exceptions.py b/tableauserverclient/server/endpoint/exceptions.py index 693817ddc..3d39e7102 100644 --- 
a/tableauserverclient/server/endpoint/exceptions.py +++ b/tableauserverclient/server/endpoint/exceptions.py @@ -70,7 +70,7 @@ class JobFailedException(Exception): def __init__(self, job): self.notes = job.notes self.job = job - + def __str__(self): return f"Job {self.job.id} failed with notes {self.notes}" diff --git a/tableauserverclient/server/endpoint/jobs_endpoint.py b/tableauserverclient/server/endpoint/jobs_endpoint.py index 4c975c523..ab0002ac0 100644 --- a/tableauserverclient/server/endpoint/jobs_endpoint.py +++ b/tableauserverclient/server/endpoint/jobs_endpoint.py @@ -8,6 +8,7 @@ logger = logging.getLogger("tableau.endpoint.jobs") + class Jobs(Endpoint): @property def baseurl(self): diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index a3f14c291..a631ae170 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -12,32 +12,54 @@ ) import os +from pathlib import Path +import io import logging import copy import cgi from contextlib import closing +from typing import ( + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) + +if TYPE_CHECKING: + from ..server import Server + from ..request_options import RequestOptions + from .. import DatasourceItem + from ...models.connection_credentials import ConnectionCredentials + # The maximum size of a file that can be published in a single request is 64MB FILESIZE_LIMIT = 1024 * 1024 * 64 # 64MB ALLOWED_FILE_EXTENSIONS = ["twb", "twbx"] logger = logging.getLogger("tableau.endpoint.workbooks") +FilePath = Union[str, os.PathLike] +FileObject = Union[io.BufferedReader, io.BytesIO] +PathOrFile = Union[FilePath, FileObject] class Workbooks(QuerysetEndpoint): - def __init__(self, parent_srv): + def __init__(self, parent_srv: "Server") -> None: super(Workbooks, self).__init__(parent_srv) self._resource_tagger = _ResourceTagger(parent_srv) self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) + return None + @property - def baseurl(self): + def baseurl(self) -> str: return "{0}/sites/{1}/workbooks".format(self.parent_srv.baseurl, self.parent_srv.site_id) # Get all workbooks on site @api(version="2.0") - def get(self, req_options=None): + def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[WorkbookItem], PaginationItem]: logger.info("Querying all workbooks on site") url = self.baseurl server_response = self.get_request(url, req_options) @@ -47,7 +69,7 @@ def get(self, req_options=None): # Get 1 workbook @api(version="2.0") - def get_by_id(self, workbook_id): + def get_by_id(self, workbook_id: str) -> WorkbookItem: if not workbook_id: error = "Workbook ID undefined." 
raise ValueError(error) @@ -57,7 +79,7 @@ def get_by_id(self, workbook_id): return WorkbookItem.from_response(server_response.content, self.parent_srv.namespace)[0] @api(version="2.8") - def refresh(self, workbook_id): + def refresh(self, workbook_id: str) -> JobItem: id_ = getattr(workbook_id, "id", workbook_id) url = "{0}/{1}/refresh".format(self.baseurl, id_) empty_req = RequestFactory.Empty.empty_req() @@ -67,7 +89,13 @@ def refresh(self, workbook_id): # create one or more extracts on 1 workbook, optionally encrypted @api(version="3.5") - def create_extract(self, workbook_item, encrypt=False, includeAll=True, datasources=None): + def create_extract( + self, + workbook_item: WorkbookItem, + encrypt: bool = False, + includeAll: bool = True, + datasources: Optional[List["DatasourceItem"]] = None, + ) -> JobItem: id_ = getattr(workbook_item, "id", workbook_item) url = "{0}/{1}/createExtract?encrypt={2}".format(self.baseurl, id_, encrypt) @@ -78,7 +106,7 @@ def create_extract(self, workbook_item, encrypt=False, includeAll=True, datasour # delete all the extracts on 1 workbook @api(version="3.5") - def delete_extract(self, workbook_item): + def delete_extract(self, workbook_item: WorkbookItem) -> None: id_ = getattr(workbook_item, "id", workbook_item) url = "{0}/{1}/deleteExtract".format(self.baseurl, id_) empty_req = RequestFactory.Empty.empty_req() @@ -86,7 +114,7 @@ def delete_extract(self, workbook_item): # Delete 1 workbook by id @api(version="2.0") - def delete(self, workbook_id): + def delete(self, workbook_id: str) -> None: if not workbook_id: error = "Workbook ID undefined." raise ValueError(error) @@ -96,7 +124,7 @@ def delete(self, workbook_id): # Update workbook @api(version="2.0") - def update(self, workbook_item): + def update(self, workbook_item: WorkbookItem) -> WorkbookItem: if not workbook_item.id: error = "Workbook item missing ID. Workbook must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -120,7 +148,7 @@ def update_conn(self, *args, **kwargs): # Update workbook_connection @api(version="2.3") - def update_connection(self, workbook_item, connection_item): + def update_connection(self, workbook_item: WorkbookItem, connection_item: ConnectionItem) -> ConnectionItem: url = "{0}/{1}/connections/{2}".format(self.baseurl, workbook_item.id, connection_item.id) update_req = RequestFactory.Connection.update_req(connection_item) server_response = self.put_request(url, update_req) @@ -135,7 +163,13 @@ def update_connection(self, workbook_item, connection_item): @api(version="2.0") @parameter_added_in(no_extract="2.5") @parameter_added_in(include_extract="2.5") - def download(self, workbook_id, filepath=None, include_extract=True, no_extract=None): + def download( + self, + workbook_id: str, + filepath: FilePath = None, + include_extract: bool = True, + no_extract: Optional[bool] = None, + ) -> str: if not workbook_id: error = "Workbook ID undefined." raise ValueError(error) @@ -167,7 +201,7 @@ def download(self, workbook_id, filepath=None, include_extract=True, no_extract= # Get all views of workbook @api(version="2.0") - def populate_views(self, workbook_item, usage=False): + def populate_views(self, workbook_item: WorkbookItem, usage: bool = False) -> None: if not workbook_item.id: error = "Workbook item missing ID. Workbook must be retrieved from server first." 
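With the signature above, download() now advertises that the target path is optional and that the return value is always the path of the file it wrote, so call sites can be checked by mypy. A hedged usage sketch; the server address, credentials and workbook id are placeholders:

    import tableauserverclient as TSC

    server = TSC.Server("http://localhost")  # placeholder server address
    tableau_auth = TSC.TableauAuth("username", "password", site_id="")  # placeholder credentials

    with server.auth.sign_in(tableau_auth):
        # filepath defaults to None (save to the current directory); the annotated
        # return type is str: the path of the downloaded workbook file.
        saved_path: str = server.workbooks.download("1f951daf-4061-451a-9df1-69a8062664f2")
        print(saved_path)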
raise MissingRequiredFieldError(error) @@ -178,7 +212,7 @@ def view_fetcher(): workbook_item._set_views(view_fetcher) logger.info("Populated views for workbook (ID: {0})".format(workbook_item.id)) - def _get_views_for_workbook(self, workbook_item, usage): + def _get_views_for_workbook(self, workbook_item: WorkbookItem, usage: bool) -> List[ViewItem]: url = "{0}/{1}/views".format(self.baseurl, workbook_item.id) if usage: url += "?includeUsageStatistics=true" @@ -192,7 +226,7 @@ def _get_views_for_workbook(self, workbook_item, usage): # Get all connections of workbook @api(version="2.0") - def populate_connections(self, workbook_item): + def populate_connections(self, workbook_item: WorkbookItem) -> None: if not workbook_item.id: error = "Workbook item missing ID. Workbook must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -203,7 +237,9 @@ def connection_fetcher(): workbook_item._set_connections(connection_fetcher) logger.info("Populated connections for workbook (ID: {0})".format(workbook_item.id)) - def _get_workbook_connections(self, workbook_item, req_options=None): + def _get_workbook_connections( + self, workbook_item: WorkbookItem, req_options: "RequestOptions" = None + ) -> List[ConnectionItem]: url = "{0}/{1}/connections".format(self.baseurl, workbook_item.id) server_response = self.get_request(url, req_options) connections = ConnectionItem.from_response(server_response.content, self.parent_srv.namespace) @@ -211,7 +247,7 @@ def _get_workbook_connections(self, workbook_item, req_options=None): # Get the pdf of the entire workbook if its tabs are enabled, pdf of the default view if its tabs are disabled @api(version="3.4") - def populate_pdf(self, workbook_item, req_options=None): + def populate_pdf(self, workbook_item: WorkbookItem, req_options: "RequestOptions" = None) -> None: if not workbook_item.id: error = "Workbook item missing ID." raise MissingRequiredFieldError(error) @@ -230,7 +266,7 @@ def _get_wb_pdf(self, workbook_item, req_options): # Get preview image of workbook @api(version="2.0") - def populate_preview_image(self, workbook_item): + def populate_preview_image(self, workbook_item: WorkbookItem) -> None: if not workbook_item.id: error = "Workbook item missing ID. Workbook must be retrieved from server first." raise MissingRequiredFieldError(error) @@ -248,7 +284,7 @@ def _get_wb_preview_image(self, workbook_item): return preview_image @api(version="2.0") - def populate_permissions(self, item): + def populate_permissions(self, item: WorkbookItem) -> None: self._permissions.populate(item) @api(version="2.0") @@ -259,20 +295,19 @@ def update_permissions(self, resource, rules): def delete_permission(self, item, capability_item): return self._permissions.delete(item, capability_item) - # Publishes workbook. 
Chunking method if file over 64MB @api(version="2.0") @parameter_added_in(as_job="3.0") @parameter_added_in(connections="2.8") def publish( self, - workbook_item, - file, - mode, - connection_credentials=None, - connections=None, - as_job=False, - hidden_views=None, - skip_connection_check=False, + workbook_item: WorkbookItem, + file: PathOrFile, + mode: str, + connection_credentials: Optional["ConnectionCredentials"] = None, + connections: Optional[Sequence[ConnectionItem]] = None, + as_job: bool = False, + hidden_views: Optional[Sequence[str]] = None, + skip_connection_check: bool = False, ): if connection_credentials is not None: @@ -283,7 +318,7 @@ def publish( DeprecationWarning, ) - try: + if isinstance(file, (str, os.PathLike)): # Expect file to be a filepath if not os.path.isfile(file): error = "File path does not lead to an existing file." @@ -300,7 +335,7 @@ def publish( error = "Only {} files can be published as workbooks.".format(", ".join(ALLOWED_FILE_EXTENSIONS)) raise ValueError(error) - except TypeError: + elif isinstance(file, (io.BytesIO, io.BufferedReader)): # Expect file to be a file object file_size = get_file_object_size(file) @@ -322,6 +357,9 @@ def publish( # This is needed when publishing the workbook in a single request filename = "{}.{}".format(workbook_item.name, file_extension) + else: + raise TypeError("file should be a filepath or file object.") + if not hasattr(self.parent_srv.PublishMode, mode): error = "Invalid mode defined." raise ValueError(error) @@ -355,13 +393,16 @@ def publish( else: logger.info("Publishing {0} to server".format(filename)) - try: + if isinstance(file, (str, Path)): with open(file, "rb") as f: file_contents = f.read() - except TypeError: + elif isinstance(file, (io.BytesIO, io.BufferedReader)): file_contents = file.read() + else: + raise TypeError("file should be a filepath or file object.") + conn_creds = connection_credentials xml_request, content_type = RequestFactory.Workbook.publish_req( workbook_item, diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 4cbea1443..21db9c484 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -232,36 +232,6 @@ def update_req(self, database_item): return ET.tostring(xml_request) -class DQWRequest(object): - def add_req(self, dqw_item): - xml_request = ET.Element("tsRequest") - dqw_element = ET.SubElement(xml_request, "dataQualityWarning") - - dqw_element.attrib["isActive"] = str(dqw_item.active).lower() - dqw_element.attrib["isSevere"] = str(dqw_item.severe).lower() - - dqw_element.attrib["type"] = dqw_item.warning_type - - if dqw_item.message: - dqw_element.attrib["message"] = str(dqw_item.message) - - return ET.tostring(xml_request) - - def update_req(self, database_item): - xml_request = ET.Element("tsRequest") - dqw_element = ET.SubElement(xml_request, "dataQualityWarning") - - dqw_element.attrib["isActive"] = str(dqw_item.active).lower() - dqw_element.attrib["isSevere"] = str(dqw_item.severe).lower() - - dqw_element.attrib["type"] = dqw_item.warning_type - - if dqw_item.message: - dqw_element.attrib["message"] = str(dqw_item.message) - - return ET.tostring(xml_request) - - class FavoriteRequest(object): def _add_to_req(self, id_, target_type, label): """ diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index a20694a92..4d289d5e5 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -33,10 
+33,7 @@ import requests -try: - from distutils2.version import NormalizedVersion as Version -except ImportError: - from distutils.version import LooseVersion as Version +from distutils.version import LooseVersion as Version _PRODUCT_TO_REST_VERSION = { "10.0": "2.3", diff --git a/test/test_datasource.py b/test/test_datasource.py index 52a5eabe3..d00c05080 100644 --- a/test/test_datasource.py +++ b/test/test_datasource.py @@ -28,7 +28,7 @@ class DatasourceTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.server = TSC.Server('http://test') # Fake signin @@ -37,7 +37,7 @@ def setUp(self): self.baseurl = self.server.datasources.baseurl - def test_get(self): + def test_get(self) -> None: response_xml = read_xml_asset(GET_XML) with requests_mock.mock() as m: m.get(self.baseurl, text=response_xml) @@ -75,11 +75,11 @@ def test_get(self): self.assertFalse(all_datasources[1].has_extracts) self.assertTrue(all_datasources[1].use_remote_query_agent) - def test_get_before_signin(self): + def test_get_before_signin(self) -> None: self.server._auth_token = None self.assertRaises(TSC.NotSignedInError, self.server.datasources.get) - def test_get_empty(self): + def test_get_empty(self) -> None: response_xml = read_xml_asset(GET_EMPTY_XML) with requests_mock.mock() as m: m.get(self.baseurl, text=response_xml) @@ -88,7 +88,7 @@ def test_get_empty(self): self.assertEqual(0, pagination_item.total_available) self.assertEqual([], all_datasources) - def test_get_by_id(self): + def test_get_by_id(self) -> None: response_xml = read_xml_asset(GET_BY_ID_XML) with requests_mock.mock() as m: m.get(self.baseurl + '/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb', text=response_xml) @@ -107,7 +107,7 @@ def test_get_by_id(self): self.assertEqual(set(['world', 'indicators', 'sample']), single_datasource.tags) self.assertEqual(TSC.DatasourceItem.AskDataEnablement.SiteDefault, single_datasource.ask_data_enablement) - def test_update(self): + def test_update(self) -> None: response_xml = read_xml_asset(UPDATE_XML) with requests_mock.mock() as m: m.put(self.baseurl + '/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb', text=response_xml) @@ -127,7 +127,7 @@ def test_update(self): self.assertEqual(updated_datasource.certified, single_datasource.certified) self.assertEqual(updated_datasource.certification_note, single_datasource.certification_note) - def test_update_copy_fields(self): + def test_update_copy_fields(self) -> None: with open(asset(UPDATE_XML), 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -140,7 +140,7 @@ def test_update_copy_fields(self): self.assertEqual(single_datasource.tags, updated_datasource.tags) self.assertEqual(single_datasource._project_name, updated_datasource._project_name) - def test_update_tags(self): + def test_update_tags(self) -> None: add_tags_xml, update_xml = read_xml_assets(ADD_TAGS_XML, UPDATE_XML) with requests_mock.mock() as m: m.put(self.baseurl + '/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/tags', text=add_tags_xml) @@ -156,7 +156,7 @@ def test_update_tags(self): self.assertEqual(single_datasource.tags, updated_datasource.tags) self.assertEqual(single_datasource._initial_tags, updated_datasource._initial_tags) - def test_populate_connections(self): + def test_populate_connections(self) -> None: response_xml = read_xml_asset(POPULATE_CONNECTIONS_XML) with requests_mock.mock() as m: m.get(self.baseurl + '/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/connections', text=response_xml) @@ -180,7 +180,7 @@ def test_populate_connections(self): 
self.assertEqual('heero', ds2.username) self.assertEqual(False, ds2.embed_password) - def test_update_connection(self): + def test_update_connection(self) -> None: populate_xml, response_xml = read_xml_assets(POPULATE_CONNECTIONS_XML, UPDATE_CONNECTION_XML) with requests_mock.mock() as m: @@ -193,7 +193,7 @@ def test_update_connection(self): single_datasource._id = '9dbd2263-16b5-46e1-9c43-a76bb8ab65fb' self.server.datasources.populate_connections(single_datasource) - connection = single_datasource.connections[0] + connection = single_datasource.connections[0] # type: ignore[index] connection.server_address = 'bar' connection.server_port = '9876' connection.username = 'foo' @@ -204,7 +204,7 @@ def test_update_connection(self): self.assertEqual('9876', new_connection.server_port) self.assertEqual('foo', new_connection.username) - def test_populate_permissions(self): + def test_populate_permissions(self) -> None: with open(asset(POPULATE_PERMISSIONS_XML), 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -215,22 +215,22 @@ def test_populate_permissions(self): self.server.datasources.populate_permissions(single_datasource) permissions = single_datasource.permissions - self.assertEqual(permissions[0].grantee.tag_name, 'group') - self.assertEqual(permissions[0].grantee.id, '5e5e1978-71fa-11e4-87dd-7382f5c437af') - self.assertDictEqual(permissions[0].capabilities, { + self.assertEqual(permissions[0].grantee.tag_name, 'group') # type: ignore[index] + self.assertEqual(permissions[0].grantee.id, '5e5e1978-71fa-11e4-87dd-7382f5c437af') # type: ignore[index] + self.assertDictEqual(permissions[0].capabilities, { # type: ignore[index] TSC.Permission.Capability.Delete: TSC.Permission.Mode.Deny, TSC.Permission.Capability.ChangePermissions: TSC.Permission.Mode.Deny, TSC.Permission.Capability.Connect: TSC.Permission.Mode.Allow, TSC.Permission.Capability.Read: TSC.Permission.Mode.Allow, }) - self.assertEqual(permissions[1].grantee.tag_name, 'user') - self.assertEqual(permissions[1].grantee.id, '7c37ee24-c4b1-42b6-a154-eaeab7ee330a') - self.assertDictEqual(permissions[1].capabilities, { + self.assertEqual(permissions[1].grantee.tag_name, 'user') # type: ignore[index] + self.assertEqual(permissions[1].grantee.id, '7c37ee24-c4b1-42b6-a154-eaeab7ee330a') # type: ignore[index] + self.assertDictEqual(permissions[1].capabilities, { # type: ignore[index] TSC.Permission.Capability.Write: TSC.Permission.Mode.Allow, }) - def test_publish(self): + def test_publish(self) -> None: response_xml = read_xml_asset(PUBLISH_XML) with requests_mock.mock() as m: m.post(self.baseurl, text=response_xml) @@ -251,7 +251,7 @@ def test_publish(self): self.assertEqual('default', new_datasource.project_name) self.assertEqual('5de011f8-5aa9-4d5b-b991-f462c8dd6bb7', new_datasource.owner_id) - def test_publish_a_non_packaged_file_object(self): + def test_publish_a_non_packaged_file_object(self) -> None: response_xml = read_xml_asset(PUBLISH_XML) with requests_mock.mock() as m: m.post(self.baseurl, text=response_xml) @@ -273,7 +273,7 @@ def test_publish_a_non_packaged_file_object(self): self.assertEqual('default', new_datasource.project_name) self.assertEqual('5de011f8-5aa9-4d5b-b991-f462c8dd6bb7', new_datasource.owner_id) - def test_publish_a_packaged_file_object(self): + def test_publish_a_packaged_file_object(self) -> None: response_xml = read_xml_asset(PUBLISH_XML) with requests_mock.mock() as m: m.post(self.baseurl, text=response_xml) @@ -301,7 +301,7 @@ def 
test_publish_a_packaged_file_object(self): self.assertEqual('default', new_datasource.project_name) self.assertEqual('5de011f8-5aa9-4d5b-b991-f462c8dd6bb7', new_datasource.owner_id) - def test_publish_async(self): + def test_publish_async(self) -> None: self.server.version = "3.0" baseurl = self.server.datasources.baseurl response_xml = read_xml_asset(PUBLISH_XML_ASYNC) @@ -321,7 +321,7 @@ def test_publish_async(self): self.assertEqual('2018-06-30T00:54:54Z', format_datetime(new_job.created_at)) self.assertEqual(1, new_job.finish_code) - def test_publish_unnamed_file_object(self): + def test_publish_unnamed_file_object(self) -> None: new_datasource = TSC.DatasourceItem('test') publish_mode = self.server.PublishMode.CreateNew @@ -330,7 +330,7 @@ def test_publish_unnamed_file_object(self): new_datasource, file_object, publish_mode ) - def test_refresh_id(self): + def test_refresh_id(self) -> None: self.server.version = '2.8' self.baseurl = self.server.datasources.baseurl response_xml = read_xml_asset(REFRESH_XML) @@ -345,7 +345,7 @@ def test_refresh_id(self): self.assertEqual('2020-03-05T22:05:32Z', format_datetime(new_job.created_at)) self.assertEqual(-1, new_job.finish_code) - def test_refresh_object(self): + def test_refresh_object(self) -> None: self.server.version = '2.8' self.baseurl = self.server.datasources.baseurl datasource = TSC.DatasourceItem('') @@ -359,7 +359,7 @@ def test_refresh_object(self): # We only check the `id`; remaining fields are already tested in `test_refresh_id` self.assertEqual('7c3d599e-949f-44c3-94a1-f30ba85757e4', new_job.id) - def test_update_hyper_data_datasource_object(self): + def test_update_hyper_data_datasource_object(self) -> None: """Calling `update_hyper_data` with a `DatasourceItem` should update that datasource""" self.server.version = "3.13" self.baseurl = self.server.datasources.baseurl @@ -378,7 +378,7 @@ def test_update_hyper_data_datasource_object(self): self.assertEqual('2021-09-18T09:40:12Z', format_datetime(new_job.created_at)) self.assertEqual(-1, new_job.finish_code) - def test_update_hyper_data_connection_object(self): + def test_update_hyper_data_connection_object(self) -> None: """Calling `update_hyper_data` with a `ConnectionItem` should update that connection""" self.server.version = "3.13" self.baseurl = self.server.datasources.baseurl @@ -395,7 +395,7 @@ def test_update_hyper_data_connection_object(self): # We only check the `id`; remaining fields are already tested in `test_update_hyper_data_datasource_object` self.assertEqual('5c0ba560-c959-424e-b08a-f32ef0bfb737', new_job.id) - def test_update_hyper_data_datasource_string(self): + def test_update_hyper_data_datasource_string(self) -> None: """For convenience, calling `update_hyper_data` with a `str` should update the datasource with the corresponding UUID""" self.server.version = "3.13" self.baseurl = self.server.datasources.baseurl @@ -410,7 +410,7 @@ def test_update_hyper_data_datasource_string(self): # We only check the `id`; remaining fields are already tested in `test_update_hyper_data_datasource_object` self.assertEqual('5c0ba560-c959-424e-b08a-f32ef0bfb737', new_job.id) - def test_update_hyper_data_datasource_payload_file(self): + def test_update_hyper_data_datasource_payload_file(self) -> None: """If `payload` is present, we upload it and associate the job with it""" self.server.version = "3.13" self.baseurl = self.server.datasources.baseurl @@ -428,7 +428,7 @@ def test_update_hyper_data_datasource_payload_file(self): # We only check the `id`; remaining fields are 
already tested in `test_update_hyper_data_datasource_object` self.assertEqual('5c0ba560-c959-424e-b08a-f32ef0bfb737', new_job.id) - def test_update_hyper_data_datasource_invalid_payload_file(self): + def test_update_hyper_data_datasource_invalid_payload_file(self) -> None: """If `payload` points to a non-existing file, we report an error""" self.server.version = "3.13" self.baseurl = self.server.datasources.baseurl @@ -439,12 +439,12 @@ def test_update_hyper_data_datasource_invalid_payload_file(self): exception = cm.exception self.assertEqual(str(exception), "File path does not lead to an existing file.") - def test_delete(self): + def test_delete(self) -> None: with requests_mock.mock() as m: m.delete(self.baseurl + '/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb', status_code=204) self.server.datasources.delete('9dbd2263-16b5-46e1-9c43-a76bb8ab65fb') - def test_download(self): + def test_download(self) -> None: with requests_mock.mock() as m: m.get(self.baseurl + '/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/content', headers={'Content-Disposition': 'name="tableau_datasource"; filename="Sample datasource.tds"'}) @@ -452,7 +452,7 @@ def test_download(self): self.assertTrue(os.path.exists(file_path)) os.remove(file_path) - def test_download_sanitizes_name(self): + def test_download_sanitizes_name(self) -> None: filename = "Name,With,Commas.tds" disposition = 'name="tableau_workbook"; filename="{}"'.format(filename) with requests_mock.mock() as m: @@ -463,7 +463,7 @@ def test_download_sanitizes_name(self): self.assertTrue(os.path.exists(file_path)) os.remove(file_path) - def test_download_extract_only(self): + def test_download_extract_only(self) -> None: # Pretend we're 2.5 for 'extract_only' self.server.version = "2.5" self.baseurl = self.server.datasources.baseurl @@ -476,40 +476,40 @@ def test_download_extract_only(self): self.assertTrue(os.path.exists(file_path)) os.remove(file_path) - def test_update_missing_id(self): + def test_update_missing_id(self) -> None: single_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') self.assertRaises(TSC.MissingRequiredFieldError, self.server.datasources.update, single_datasource) - def test_publish_missing_path(self): + def test_publish_missing_path(self) -> None: new_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') self.assertRaises(IOError, self.server.datasources.publish, new_datasource, '', self.server.PublishMode.CreateNew) - def test_publish_missing_mode(self): + def test_publish_missing_mode(self) -> None: new_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') self.assertRaises(ValueError, self.server.datasources.publish, new_datasource, asset('SampleDS.tds'), None) - def test_publish_invalid_file_type(self): + def test_publish_invalid_file_type(self) -> None: new_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') self.assertRaises(ValueError, self.server.datasources.publish, new_datasource, asset('SampleWB.twbx'), self.server.PublishMode.Append) - def test_publish_hyper_file_object_raises_exception(self): + def test_publish_hyper_file_object_raises_exception(self) -> None: new_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') - with open(asset('World Indicators.hyper')) as file_object: + with open(asset('World Indicators.hyper'), 'rb') as file_object: self.assertRaises(ValueError, self.server.datasources.publish, new_datasource, file_object, self.server.PublishMode.Append) - def 
test_publish_tde_file_object_raises_exception(self): + def test_publish_tde_file_object_raises_exception(self) -> None: new_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') tds_asset = asset(os.path.join('Data', 'Tableau Samples', 'World Indicators.tde')) - with open(tds_asset) as file_object: + with open(tds_asset, 'rb') as file_object: self.assertRaises(ValueError, self.server.datasources.publish, new_datasource, file_object, self.server.PublishMode.Append) - def test_publish_file_object_of_unknown_type_raises_exception(self): + def test_publish_file_object_of_unknown_type_raises_exception(self) -> None: new_datasource = TSC.DatasourceItem('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', 'test') with BytesIO() as file_object: @@ -518,7 +518,7 @@ def test_publish_file_object_of_unknown_type_raises_exception(self): self.assertRaises(ValueError, self.server.datasources.publish, new_datasource, file_object, self.server.PublishMode.Append) - def test_publish_multi_connection(self): + def test_publish_multi_connection(self) -> None: new_datasource = TSC.DatasourceItem(name='Sample', project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') connection1 = TSC.ConnectionItem() connection1.server_address = 'mysql.test.com' @@ -532,11 +532,11 @@ def test_publish_multi_connection(self): connection_results = ET.fromstring(response).findall('.//connection') self.assertEqual(connection_results[0].get('serverAddress', None), 'mysql.test.com') - self.assertEqual(connection_results[0].find('connectionCredentials').get('name', None), 'test') + self.assertEqual(connection_results[0].find('connectionCredentials').get('name', None), 'test') # type: ignore[union-attr] self.assertEqual(connection_results[1].get('serverAddress', None), 'pgsql.test.com') - self.assertEqual(connection_results[1].find('connectionCredentials').get('password', None), 'secret') + self.assertEqual(connection_results[1].find('connectionCredentials').get('password', None), 'secret') # type: ignore[union-attr] - def test_publish_single_connection(self): + def test_publish_single_connection(self) -> None: new_datasource = TSC.DatasourceItem(name='Sample', project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') connection_creds = TSC.ConnectionCredentials('test', 'secret', True) @@ -549,7 +549,7 @@ def test_publish_single_connection(self): self.assertEqual(credentials[0].get('password', None), 'secret') self.assertEqual(credentials[0].get('embed', None), 'true') - def test_credentials_and_multi_connect_raises_exception(self): + def test_credentials_and_multi_connect_raises_exception(self) -> None: new_datasource = TSC.DatasourceItem(name='Sample', project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') connection_creds = TSC.ConnectionCredentials('test', 'secret', True) @@ -563,7 +563,7 @@ def test_credentials_and_multi_connect_raises_exception(self): connection_credentials=connection_creds, connections=[connection1]) - def test_synchronous_publish_timeout_error(self): + def test_synchronous_publish_timeout_error(self) -> None: with requests_mock.mock() as m: m.register_uri('POST', self.baseurl, status_code=504) @@ -574,14 +574,14 @@ def test_synchronous_publish_timeout_error(self): self.server.datasources.publish, new_datasource, asset('SampleDS.tds'), publish_mode) - def test_delete_extracts(self): + def test_delete_extracts(self) -> None: self.server.version = "3.10" self.baseurl = self.server.datasources.baseurl with requests_mock.mock() as m: m.post(self.baseurl + '/3cc6cd06-89ce-4fdc-b935-5294135d6d42/deleteExtract', 
status_code=200) self.server.datasources.delete_extract('3cc6cd06-89ce-4fdc-b935-5294135d6d42') - def test_create_extracts(self): + def test_create_extracts(self) -> None: self.server.version = "3.10" self.baseurl = self.server.datasources.baseurl @@ -591,7 +591,7 @@ def test_create_extracts(self): status_code=200, text=response_xml) self.server.datasources.create_extract('3cc6cd06-89ce-4fdc-b935-5294135d6d42') - def test_create_extracts_encrypted(self): + def test_create_extracts_encrypted(self) -> None: self.server.version = "3.10" self.baseurl = self.server.datasources.baseurl diff --git a/test/test_regression_tests.py b/test/test_regression_tests.py index 281f3fbca..52ea03b92 100644 --- a/test/test_regression_tests.py +++ b/test/test_regression_tests.py @@ -3,7 +3,7 @@ try: from unittest import mock except ImportError: - import mock + import mock # type: ignore[no-redef] import tableauserverclient.server.request_factory as factory from tableauserverclient.server.endpoint import Endpoint diff --git a/test/test_workbook.py b/test/test_workbook.py index 459b1f905..5ace90b2f 100644 --- a/test/test_workbook.py +++ b/test/test_workbook.py @@ -5,7 +5,7 @@ import requests_mock import tableauserverclient as TSC import xml.etree.ElementTree as ET - +from pathlib import Path from tableauserverclient.datetime_helpers import format_datetime from tableauserverclient.server.endpoint.exceptions import InternalServerError @@ -38,7 +38,7 @@ class WorkbookTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.server = TSC.Server('http://test') # Fake sign in @@ -47,7 +47,7 @@ def setUp(self): self.baseurl = self.server.workbooks.baseurl - def test_get(self): + def test_get(self) -> None: with open(GET_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -82,7 +82,7 @@ def test_get(self): self.assertEqual('5de011f8-5aa9-4d5b-b991-f462c8dd6bb7', all_workbooks[1].owner_id) self.assertEqual(set(['Safari', 'Sample']), all_workbooks[1].tags) - def test_get_ignore_invalid_date(self): + def test_get_ignore_invalid_date(self) -> None: with open(GET_INVALID_DATE_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -91,11 +91,11 @@ def test_get_ignore_invalid_date(self): self.assertEqual(None, format_datetime(all_workbooks[0].created_at)) self.assertEqual('2016-08-04T17:56:41Z', format_datetime(all_workbooks[0].updated_at)) - def test_get_before_signin(self): + def test_get_before_signin(self) -> None: self.server._auth_token = None self.assertRaises(TSC.NotSignedInError, self.server.workbooks.get) - def test_get_empty(self): + def test_get_empty(self) -> None: with open(GET_EMPTY_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -105,7 +105,7 @@ def test_get_empty(self): self.assertEqual(0, pagination_item.total_available) self.assertEqual([], all_workbooks) - def test_get_by_id(self): + def test_get_by_id(self) -> None: with open(GET_BY_ID_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -129,7 +129,7 @@ def test_get_by_id(self): self.assertEqual('ENDANGERED SAFARI', single_workbook.views[0].name) self.assertEqual('SafariSample/sheets/ENDANGEREDSAFARI', single_workbook.views[0].content_url) - def test_get_by_id_personal(self): + def test_get_by_id_personal(self) -> None: # workbooks in personal space don't have project_id or project_name with open(GET_BY_ID_XML_PERSONAL, 'rb') as f: response_xml = f.read().decode('utf-8') @@ 
-154,10 +154,10 @@ def test_get_by_id_personal(self): self.assertEqual('ENDANGERED SAFARI', single_workbook.views[0].name) self.assertEqual('SafariSample/sheets/ENDANGEREDSAFARI', single_workbook.views[0].content_url) - def test_get_by_id_missing_id(self): + def test_get_by_id_missing_id(self) -> None: self.assertRaises(ValueError, self.server.workbooks.get_by_id, '') - def test_refresh_id(self): + def test_refresh_id(self) -> None: self.server.version = '2.8' self.baseurl = self.server.workbooks.baseurl with open(REFRESH_XML, 'rb') as f: @@ -167,7 +167,7 @@ def test_refresh_id(self): status_code=202, text=response_xml) self.server.workbooks.refresh('3cc6cd06-89ce-4fdc-b935-5294135d6d42') - def test_refresh_object(self): + def test_refresh_object(self) -> None: self.server.version = '2.8' self.baseurl = self.server.workbooks.baseurl workbook = TSC.WorkbookItem('') @@ -179,15 +179,15 @@ def test_refresh_object(self): status_code=202, text=response_xml) self.server.workbooks.refresh(workbook) - def test_delete(self): + def test_delete(self) -> None: with requests_mock.mock() as m: m.delete(self.baseurl + '/3cc6cd06-89ce-4fdc-b935-5294135d6d42', status_code=204) self.server.workbooks.delete('3cc6cd06-89ce-4fdc-b935-5294135d6d42') - def test_delete_missing_id(self): + def test_delete_missing_id(self) -> None: self.assertRaises(ValueError, self.server.workbooks.delete, '') - def test_update(self): + def test_update(self) -> None: with open(UPDATE_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -210,11 +210,11 @@ def test_update(self): self.assertEqual(True, single_workbook.data_acceleration_config['acceleration_enabled']) self.assertEqual(False, single_workbook.data_acceleration_config['accelerate_now']) - def test_update_missing_id(self): + def test_update_missing_id(self) -> None: single_workbook = TSC.WorkbookItem('test') self.assertRaises(TSC.MissingRequiredFieldError, self.server.workbooks.update, single_workbook) - def test_update_copy_fields(self): + def test_update_copy_fields(self) -> None: with open(POPULATE_CONNECTIONS_XML, 'rb') as f: connection_xml = f.read().decode('utf-8') with open(UPDATE_XML, 'rb') as f: @@ -233,7 +233,7 @@ def test_update_copy_fields(self): self.assertEqual(single_workbook._initial_tags, updated_workbook._initial_tags) self.assertEqual(single_workbook._preview_image, updated_workbook._preview_image) - def test_update_tags(self): + def test_update_tags(self) -> None: with open(ADD_TAGS_XML, 'rb') as f: add_tags_xml = f.read().decode('utf-8') with open(UPDATE_XML, 'rb') as f: @@ -252,7 +252,7 @@ def test_update_tags(self): self.assertEqual(single_workbook.tags, updated_workbook.tags) self.assertEqual(single_workbook._initial_tags, updated_workbook._initial_tags) - def test_download(self): + def test_download(self) -> None: with requests_mock.mock() as m: m.get(self.baseurl + '/1f951daf-4061-451a-9df1-69a8062664f2/content', headers={'Content-Disposition': 'name="tableau_workbook"; filename="RESTAPISample.twbx"'}) @@ -260,7 +260,7 @@ def test_download(self): self.assertTrue(os.path.exists(file_path)) os.remove(file_path) - def test_download_sanitizes_name(self): + def test_download_sanitizes_name(self) -> None: filename = "Name,With,Commas.twbx" disposition = 'name="tableau_workbook"; filename="{}"'.format(filename) with requests_mock.mock() as m: @@ -271,7 +271,7 @@ def test_download_sanitizes_name(self): self.assertTrue(os.path.exists(file_path)) os.remove(file_path) - def test_download_extract_only(self): + def 
test_download_extract_only(self) -> None: # Pretend we're 2.5 for 'extract_only' self.server.version = "2.5" self.baseurl = self.server.workbooks.baseurl @@ -285,10 +285,10 @@ def test_download_extract_only(self): self.assertTrue(os.path.exists(file_path)) os.remove(file_path) - def test_download_missing_id(self): + def test_download_missing_id(self) -> None: self.assertRaises(ValueError, self.server.workbooks.download, '') - def test_populate_views(self): + def test_populate_views(self) -> None: with open(POPULATE_VIEWS_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -310,7 +310,7 @@ def test_populate_views(self): self.assertEqual('Interest rates', views_list[2].name) self.assertEqual('RESTAPISample/sheets/Interestrates', views_list[2].content_url) - def test_populate_views_with_usage(self): + def test_populate_views_with_usage(self) -> None: with open(POPULATE_VIEWS_USAGE_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -328,11 +328,11 @@ def test_populate_views_with_usage(self): self.assertEqual('0599c28c-6d82-457e-a453-e52c1bdb00f5', views_list[2].id) self.assertEqual(0, views_list[2].total_views) - def test_populate_views_missing_id(self): + def test_populate_views_missing_id(self) -> None: single_workbook = TSC.WorkbookItem('test') self.assertRaises(TSC.MissingRequiredFieldError, self.server.workbooks.populate_views, single_workbook) - def test_populate_connections(self): + def test_populate_connections(self) -> None: with open(POPULATE_CONNECTIONS_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -346,7 +346,7 @@ def test_populate_connections(self): self.assertEqual('4506225a-0d32-4ab1-82d3-c24e85f7afba', single_workbook.connections[0].datasource_id) self.assertEqual('World Indicators', single_workbook.connections[0].datasource_name) - def test_populate_permissions(self): + def test_populate_permissions(self) -> None: with open(POPULATE_PERMISSIONS_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -375,7 +375,7 @@ def test_populate_permissions(self): TSC.Permission.Capability.ViewComments: TSC.Permission.Mode.Deny }) - def test_add_permissions(self): + def test_add_permissions(self) -> None: with open(UPDATE_PERMISSIONS, 'rb') as f: response_xml = f.read().decode('utf-8') @@ -406,13 +406,13 @@ def test_add_permissions(self): TSC.Permission.Capability.Write: TSC.Permission.Mode.Allow }) - def test_populate_connections_missing_id(self): + def test_populate_connections_missing_id(self) -> None: single_workbook = TSC.WorkbookItem('test') self.assertRaises(TSC.MissingRequiredFieldError, self.server.workbooks.populate_connections, single_workbook) - def test_populate_pdf(self): + def test_populate_pdf(self) -> None: self.server.version = "3.4" self.baseurl = self.server.workbooks.baseurl with open(POPULATE_PDF, "rb") as f: @@ -430,7 +430,7 @@ def test_populate_pdf(self): self.server.workbooks.populate_pdf(single_workbook, req_option) self.assertEqual(response, single_workbook.pdf) - def test_populate_preview_image(self): + def test_populate_preview_image(self) -> None: with open(POPULATE_PREVIEW_IMAGE, 'rb') as f: response = f.read() with requests_mock.mock() as m: @@ -441,13 +441,13 @@ def test_populate_preview_image(self): self.assertEqual(response, single_workbook.preview_image) - def test_populate_preview_image_missing_id(self): + def test_populate_preview_image_missing_id(self) -> None: single_workbook = 
TSC.WorkbookItem('test') self.assertRaises(TSC.MissingRequiredFieldError, self.server.workbooks.populate_preview_image, single_workbook) - def test_publish(self): + def test_publish(self) -> None: with open(PUBLISH_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -478,7 +478,7 @@ def test_publish(self): self.assertEqual('GDP per capita', new_workbook.views[0].name) self.assertEqual('RESTAPISample_0/sheets/GDPpercapita', new_workbook.views[0].content_url) - def test_publish_a_packaged_file_object(self): + def test_publish_a_packaged_file_object(self) -> None: with open(PUBLISH_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -512,7 +512,7 @@ def test_publish_a_packaged_file_object(self): self.assertEqual('GDP per capita', new_workbook.views[0].name) self.assertEqual('RESTAPISample_0/sheets/GDPpercapita', new_workbook.views[0].content_url) - def test_publish_non_packeged_file_object(self): + def test_publish_non_packeged_file_object(self) -> None: with open(PUBLISH_XML, 'rb') as f: response_xml = f.read().decode('utf-8') @@ -547,7 +547,38 @@ def test_publish_non_packeged_file_object(self): self.assertEqual('GDP per capita', new_workbook.views[0].name) self.assertEqual('RESTAPISample_0/sheets/GDPpercapita', new_workbook.views[0].content_url) - def test_publish_with_hidden_view(self): + def test_publish_path_object(self) -> None: + with open(PUBLISH_XML, 'rb') as f: + response_xml = f.read().decode('utf-8') + with requests_mock.mock() as m: + m.post(self.baseurl, text=response_xml) + + new_workbook = TSC.WorkbookItem(name='Sample', + show_tabs=False, + project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') + + sample_workbook = Path(TEST_ASSET_DIR) / 'SampleWB.twbx' + publish_mode = self.server.PublishMode.CreateNew + + new_workbook = self.server.workbooks.publish(new_workbook, + sample_workbook, + publish_mode) + + self.assertEqual('a8076ca1-e9d8-495e-bae6-c684dbb55836', new_workbook.id) + self.assertEqual('RESTAPISample', new_workbook.name) + self.assertEqual('RESTAPISample_0', new_workbook.content_url) + self.assertEqual(False, new_workbook.show_tabs) + self.assertEqual(1, new_workbook.size) + self.assertEqual('2016-08-18T18:33:24Z', format_datetime(new_workbook.created_at)) + self.assertEqual('2016-08-18T20:31:34Z', format_datetime(new_workbook.updated_at)) + self.assertEqual('ee8c6e70-43b6-11e6-af4f-f7b0d8e20760', new_workbook.project_id) + self.assertEqual('default', new_workbook.project_name) + self.assertEqual('5de011f8-5aa9-4d5b-b991-f462c8dd6bb7', new_workbook.owner_id) + self.assertEqual('fe0b4e89-73f4-435e-952d-3a263fbfa56c', new_workbook.views[0].id) + self.assertEqual('GDP per capita', new_workbook.views[0].name) + self.assertEqual('RESTAPISample_0/sheets/GDPpercapita', new_workbook.views[0].content_url) + + def test_publish_with_hidden_view(self) -> None: with open(PUBLISH_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -570,7 +601,7 @@ def test_publish_with_hidden_view(self): self.assertTrue(re.search(rb'<\/views>', request_body)) self.assertTrue(re.search(rb'<\/views>', request_body)) - def test_publish_with_query_params(self): + def test_publish_with_query_params(self) -> None: with open(PUBLISH_ASYNC_XML, 'rb') as f: response_xml = f.read().decode('utf-8') with requests_mock.mock() as m: @@ -592,7 +623,7 @@ def test_publish_with_query_params(self): self.assertTrue('skipconnectioncheck' in request_query_params) 
self.assertTrue(request_query_params['skipconnectioncheck']) - def test_publish_async(self): + def test_publish_async(self) -> None: self.server.version = '3.0' baseurl = self.server.workbooks.baseurl with open(PUBLISH_ASYNC_XML, 'rb') as f: @@ -618,27 +649,36 @@ def test_publish_async(self): self.assertEqual('2018-06-29T23:22:32Z', format_datetime(new_job.created_at)) self.assertEqual(1, new_job.finish_code) - def test_publish_invalid_file(self): + def test_publish_invalid_file(self) -> None: new_workbook = TSC.WorkbookItem('test', 'ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') self.assertRaises(IOError, self.server.workbooks.publish, new_workbook, '.', self.server.PublishMode.CreateNew) - def test_publish_invalid_file_type(self): + def test_publish_invalid_file_type(self) -> None: new_workbook = TSC.WorkbookItem('test', 'ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') self.assertRaises(ValueError, self.server.workbooks.publish, new_workbook, os.path.join(TEST_ASSET_DIR, 'SampleDS.tds'), self.server.PublishMode.CreateNew) - def test_publish_unnamed_file_object(self): + def test_publish_unnamed_file_object(self) -> None: new_workbook = TSC.WorkbookItem('test') - with open(os.path.join(TEST_ASSET_DIR, 'SampleWB.twbx')) as f: + with open(os.path.join(TEST_ASSET_DIR, 'SampleWB.twbx'), 'rb') as f: self.assertRaises(ValueError, self.server.workbooks.publish, new_workbook, f, self.server.PublishMode.CreateNew ) - def test_publish_file_object_of_unknown_type_raises_exception(self): + def test_publish_non_bytes_file_object(self) -> None: + new_workbook = TSC.WorkbookItem('test') + + with open(os.path.join(TEST_ASSET_DIR, 'SampleWB.twbx')) as f: + + self.assertRaises(TypeError, self.server.workbooks.publish, + new_workbook, f, self.server.PublishMode.CreateNew + ) + + def test_publish_file_object_of_unknown_type_raises_exception(self) -> None: new_workbook = TSC.WorkbookItem('test') with BytesIO() as file_object: file_object.write(bytes.fromhex('89504E470D0A1A0A')) @@ -646,7 +686,7 @@ def test_publish_file_object_of_unknown_type_raises_exception(self): self.assertRaises(ValueError, self.server.workbooks.publish, new_workbook, file_object, self.server.PublishMode.CreateNew) - def test_publish_multi_connection(self): + def test_publish_multi_connection(self) -> None: new_workbook = TSC.WorkbookItem(name='Sample', show_tabs=False, project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') connection1 = TSC.ConnectionItem() @@ -661,11 +701,11 @@ def test_publish_multi_connection(self): connection_results = ET.fromstring(response).findall('.//connection') self.assertEqual(connection_results[0].get('serverAddress', None), 'mysql.test.com') - self.assertEqual(connection_results[0].find('connectionCredentials').get('name', None), 'test') + self.assertEqual(connection_results[0].find('connectionCredentials').get('name', None), 'test') # type: ignore[union-attr] self.assertEqual(connection_results[1].get('serverAddress', None), 'pgsql.test.com') - self.assertEqual(connection_results[1].find('connectionCredentials').get('password', None), 'secret') + self.assertEqual(connection_results[1].find('connectionCredentials').get('password', None), 'secret') # type: ignore[union-attr] - def test_publish_single_connection(self): + def test_publish_single_connection(self) -> None: new_workbook = TSC.WorkbookItem(name='Sample', show_tabs=False, project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') connection_creds = TSC.ConnectionCredentials('test', 'secret', True) @@ -678,7 +718,7 @@ def test_publish_single_connection(self): 
self.assertEqual(credentials[0].get('password', None), 'secret') self.assertEqual(credentials[0].get('embed', None), 'true') - def test_credentials_and_multi_connect_raises_exception(self): + def test_credentials_and_multi_connect_raises_exception(self) -> None: new_workbook = TSC.WorkbookItem(name='Sample', show_tabs=False, project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760') @@ -693,7 +733,7 @@ def test_credentials_and_multi_connect_raises_exception(self): connection_credentials=connection_creds, connections=[connection1]) - def test_synchronous_publish_timeout_error(self): + def test_synchronous_publish_timeout_error(self) -> None: with requests_mock.mock() as m: m.register_uri('POST', self.baseurl, status_code=504) @@ -703,14 +743,14 @@ def test_synchronous_publish_timeout_error(self): self.assertRaisesRegex(InternalServerError, 'Please use asynchronous publishing to avoid timeouts', self.server.workbooks.publish, new_workbook, asset('SampleWB.twbx'), publish_mode) - def test_delete_extracts_all(self): + def test_delete_extracts_all(self) -> None: self.server.version = "3.10" self.baseurl = self.server.workbooks.baseurl with requests_mock.mock() as m: m.post(self.baseurl + '/3cc6cd06-89ce-4fdc-b935-5294135d6d42/deleteExtract', status_code=200) self.server.workbooks.delete_extract('3cc6cd06-89ce-4fdc-b935-5294135d6d42') - def test_create_extracts_all(self): + def test_create_extracts_all(self) -> None: self.server.version = "3.10" self.baseurl = self.server.workbooks.baseurl @@ -721,7 +761,7 @@ def test_create_extracts_all(self): status_code=200, text=response_xml) self.server.workbooks.create_extract('3cc6cd06-89ce-4fdc-b935-5294135d6d42') - def test_create_extracts_one(self): + def test_create_extracts_one(self) -> None: self.server.version = "3.10" self.baseurl = self.server.workbooks.baseurl
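
For reference, a minimal usage sketch of the behaviour the updated workbook tests above exercise: workbooks.publish accepting a pathlib.Path (test_publish_path_object) or a file object opened in binary mode, with a text-mode handle expected to raise TypeError (test_publish_non_bytes_file_object). This is only an illustration against a hypothetical server; the URL, credentials, and project id below are placeholders and are not taken from this change set, and it assumes a client build that includes the Path support added here.

import pathlib

import tableauserverclient as TSC

# Placeholder server and credentials -- substitute real values.
server = TSC.Server('https://tableau.example.com', use_server_version=True)
tableau_auth = TSC.TableauAuth('some_user', 'some_password', site_id='')

with server.auth.sign_in(tableau_auth):
    new_workbook = TSC.WorkbookItem(name='Sample',
                                    project_id='ee8c6e70-43b6-11e6-af4f-f7b0d8e20760')

    # A pathlib.Path is accepted directly, as covered by test_publish_path_object.
    workbook_path = pathlib.Path('assets') / 'SampleWB.twbx'
    published = server.workbooks.publish(new_workbook, workbook_path,
                                         TSC.Server.PublishMode.CreateNew)
    print(published.id, published.content_url)

    # A file object also works, but it must be opened in binary mode;
    # a text-mode handle is expected to raise TypeError
    # (test_publish_non_bytes_file_object).
    with open(workbook_path, 'rb') as f:
        server.workbooks.publish(new_workbook, f, TSC.Server.PublishMode.Overwrite)

The same binary-mode requirement is why the datasource and workbook test fixtures in this patch now open their .tde and .twbx assets with open(..., 'rb') before passing them to publish.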