diff --git a/tableauserverclient/__init__.py b/tableauserverclient/__init__.py index d1b8a4e74..d435962b1 100644 --- a/tableauserverclient/__init__.py +++ b/tableauserverclient/__init__.py @@ -1,9 +1,9 @@ from .namespace import NEW_NAMESPACE as DEFAULT_NAMESPACE from .models import ConnectionCredentials, ConnectionItem, DatasourceItem,\ - GroupItem, JobItem, BackgroundJobItem, PaginationItem, ProjectItem, ScheduleItem, \ - SiteItem, TableauAuth, PersonalAccessTokenAuth, UserItem, ViewItem, WorkbookItem, UnpopulatedPropertyError, \ - HourlyInterval, DailyInterval, WeeklyInterval, MonthlyInterval, IntervalItem, TaskItem, \ - SubscriptionItem, Target, PermissionsRule, Permission + GroupItem, JobItem, BackgroundJobItem, PaginationItem, ProjectItem, ScheduleItem,\ + SiteItem, TableauAuth, PersonalAccessTokenAuth, UserItem, ViewItem, WorkbookItem, UnpopulatedPropertyError,\ + HourlyInterval, DailyInterval, WeeklyInterval, MonthlyInterval, IntervalItem, TaskItem,\ + SubscriptionItem, Target, PermissionsRule, Permission, DatabaseItem, TableItem, ColumnItem from .server import RequestOptions, CSVRequestOptions, ImageRequestOptions, PDFRequestOptions, Filter, Sort, \ Server, ServerResponseError, MissingRequiredFieldError, NotSignedInError, Pager from ._version import get_versions diff --git a/tableauserverclient/models/__init__.py b/tableauserverclient/models/__init__.py index f96f78565..4cfdd4846 100644 --- a/tableauserverclient/models/__init__.py +++ b/tableauserverclient/models/__init__.py @@ -1,6 +1,8 @@ from .connection_credentials import ConnectionCredentials from .connection_item import ConnectionItem +from .column_item import ColumnItem from .datasource_item import DatasourceItem +from .database_item import DatabaseItem from .exceptions import UnpopulatedPropertyError from .group_item import GroupItem from .interval_item import IntervalItem, DailyInterval, WeeklyInterval, MonthlyInterval, HourlyInterval @@ -13,6 +15,7 @@ from .tableau_auth import TableauAuth from .personal_access_token_auth import PersonalAccessTokenAuth from .target import Target +from .table_item import TableItem from .task_item import TaskItem from .user_item import UserItem from .view_item import ViewItem diff --git a/tableauserverclient/models/column_item.py b/tableauserverclient/models/column_item.py new file mode 100644 index 000000000..475dd0e2a --- /dev/null +++ b/tableauserverclient/models/column_item.py @@ -0,0 +1,69 @@ +import xml.etree.ElementTree as ET + +from .property_decorators import property_is_enum, property_not_empty +from .exceptions import UnpopulatedPropertyError + + +class ColumnItem(object): + def __init__(self, name, description=None): + self._id = None + self.description = description + self.name = name + + @property + def id(self): + return self._id + + @property + def name(self): + return self._name + + @name.setter + @property_not_empty + def name(self, value): + self._name = value + + @property + def description(self): + return self._description + + @description.setter + def description(self, value): + self._description = value + + @property + def remote_type(self): + return self._remote_type + + def _set_values(self, id, name, description, remote_type): + if id is not None: + self._id = id + if name: + self._name = name + if description: + self.description = description + if remote_type: + self._remote_type = remote_type + + @classmethod + def from_response(cls, resp, ns): + all_column_items = list() + parsed_response = ET.fromstring(resp) + all_column_xml = parsed_response.findall('.//t:column', namespaces=ns) + + for column_xml in all_column_xml: + (id, name, description, remote_type) = cls._parse_element(column_xml, ns) + column_item = cls(name) + column_item._set_values(id, name, description, remote_type) + all_column_items.append(column_item) + + return all_column_items + + @staticmethod + def _parse_element(column_xml, ns): + id = column_xml.get('id', None) + name = column_xml.get('name', None) + description = column_xml.get('description', None) + remote_type = column_xml.get('remoteType', None) + + return id, name, description, remote_type diff --git a/tableauserverclient/models/database_item.py b/tableauserverclient/models/database_item.py new file mode 100644 index 000000000..cb7ebf5f0 --- /dev/null +++ b/tableauserverclient/models/database_item.py @@ -0,0 +1,260 @@ +import xml.etree.ElementTree as ET + +from .permissions_item import Permission + +from .property_decorators import property_is_enum, property_not_empty, property_is_boolean +from .exceptions import UnpopulatedPropertyError + + +class DatabaseItem(object): + class ContentPermissions: + LockedToProject = 'LockedToDatabase' + ManagedByOwner = 'ManagedByOwner' + + def __init__(self, name, description=None, content_permissions=None): + self._id = None + self.name = name + self.description = description + self.content_permissions = content_permissions + self._certified = None + self._certification_note = None + self._contact_id = None + + self._connector_url = None + self._connection_type = None + self._embedded = None + self._file_extension = None + self._file_id = None + self._file_path = None + self._host_name = None + self._metadata_type = None + self._mime_type = None + self._port = None + self._provider = None + self._request_url = None + + self._permissions = None + self._default_table_permissions = None + + self._tables = None # Not implemented yet + + @property + def content_permissions(self): + return self._content_permissions + + @property + def permissions(self): + if self._permissions is None: + error = "Project item must be populated with permissions first." + raise UnpopulatedPropertyError(error) + return self._permissions() + + @property + def default_table_permissions(self): + if self._default_table_permissions is None: + error = "Project item must be populated with permissions first." + raise UnpopulatedPropertyError(error) + return self._default_table_permissions() + + @content_permissions.setter + @property_is_enum(ContentPermissions) + def content_permissions(self, value): + self._content_permissions = value + + @property + def id(self): + return self._id + + @property + def name(self): + return self._name + + @name.setter + @property_not_empty + def name(self, value): + self._name = value + + @property + def description(self): + return self._description + + @description.setter + def description(self, value): + self._description = value + + @property + def embedded(self): + return self._embedded + + @property + def certified(self): + return self._certified + + @certified.setter + @property_is_boolean + def certified(self, value): + self._certified = value + + @property + def certification_note(self): + return self._certification_note + + @certification_note.setter + def certification_note(self, value): + self._certification_note = value + + @property + def metadata_type(self): + return self._metadata_type + + @property + def host_name(self): + return self._host_name + + @property + def port(self): + return self._port + + @property + def file_path(self): + return self._file_path + + @property + def provider(self): + return self._provider + + @property + def mime_type(self): + return self._mime_type + + @property + def connector_url(self): + return self._connector_url + + @property + def connection_type(self): + return self._connection_type + + @property + def request_url(self): + return self._request_url + + @property + def file_extension(self): + return self._file_extension + + @property + def file_id(self): + return self._file_id + + @property + def contact_id(self): + return self._contact_id + + @contact_id.setter + def contact_id(self, value): + self._contact_id = value + + @property + def tables(self): + if self._tables is None: + error = "Database must be populated with tables first." + raise UnpopulatedPropertyError(error) + # Each call to `.tables` should create a new pager, this just runs the callable + return self._tables() + + def _set_values(self, database_values): + # ID & Settable + if 'id' in database_values: + self._id = database_values['id'] + + if 'contact' in database_values: + self._contact_id = database_values['contact']['id'] + + if 'name' in database_values: + self._name = database_values['name'] + + if 'description' in database_values: + self._description = database_values['description'] + + if 'isCertified' in database_values: + self._certified = string_to_bool(database_values['isCertified']) + + if 'certificationNote' in database_values: + self._certification_note = database_values['certificationNote'] + + # Not settable, alphabetical + + if 'connectionType' in database_values: + self._connection_type = database_values['connectionType'] + + if 'connectorUrl' in database_values: + self._connector_url = database_values['connectorUrl'] + + if 'contentPermissions' in database_values: + self._content_permissions = database_values['contentPermissions'] + + if 'embedded' in database_values: + self._embedded = string_to_bool(database_values['embedded']) + + if 'fileExtension' in database_values: + self._file_extension = database_values['fileExtension'] + + if 'fileId' in database_values: + self._file_id = database_values['fileId'] + + if 'filePath' in database_values: + self._file_path = database_values['filePath'] + + if 'hostName' in database_values: + self._host_name = database_values['hostName'] + + if 'mimeType' in database_values: + self._mime_type = database_values['mimeType'] + + if 'port' in database_values: + self._port = int(database_values['port']) + + if 'provider' in database_values: + self._provider = database_values['provider'] + + if 'requestUrl' in database_values: + self._request_url = database_values['requestUrl'] + + if 'type' in database_values: + self._metadata_type = database_values['type'] + + def _set_permissions(self, permissions): + self._permissions = permissions + + def _set_tables(self, tables): + self._tables = tables + + def _set_default_permissions(self, permissions, content_type): + setattr(self, "_default_{content}_permissions".format(content=content_type), permissions) + + @classmethod + def from_response(cls, resp, ns): + all_database_items = list() + parsed_response = ET.fromstring(resp) + all_database_xml = parsed_response.findall('.//t:database', namespaces=ns) + + for database_xml in all_database_xml: + parsed_database = cls._parse_element(database_xml, ns) + database_item = cls(parsed_database['name']) + database_item._set_values(parsed_database) + all_database_items.append(database_item) + return all_database_items + + @staticmethod + def _parse_element(database_xml, ns): + database_values = database_xml.attrib.copy() + contact = database_xml.find('.//t:contact', namespaces=ns) + if contact is not None: + database_values['contact'] = contact.attrib.copy() + return database_values + + +# Used to convert string represented boolean to a boolean type +def string_to_bool(s): + return s.lower() == 'true' diff --git a/tableauserverclient/models/permissions_item.py b/tableauserverclient/models/permissions_item.py index 2c2abdf82..6487b6ca5 100644 --- a/tableauserverclient/models/permissions_item.py +++ b/tableauserverclient/models/permissions_item.py @@ -36,6 +36,8 @@ class Resource: Workbook = 'workbook' Datasource = 'datasource' Flow = 'flow' + Table = 'table' + Database = 'database' class PermissionsRule(object): diff --git a/tableauserverclient/models/table_item.py b/tableauserverclient/models/table_item.py new file mode 100644 index 000000000..c3debc0b7 --- /dev/null +++ b/tableauserverclient/models/table_item.py @@ -0,0 +1,147 @@ +import xml.etree.ElementTree as ET + +from .permissions_item import Permission +from .column_item import ColumnItem + +from .property_decorators import property_is_enum, property_not_empty, property_is_boolean +from .exceptions import UnpopulatedPropertyError + + +class TableItem(object): + def __init__(self, name, description=None): + self._id = None + self.description = description + self.name = name + + self._contact_id = None + self._certified = None + self._certification_note = None + self._permissions = None + self._schema = None + + self._columns = None + + @property + def permissions(self): + if self._permissions is None: + error = "Project item must be populated with permissions first." + raise UnpopulatedPropertyError(error) + return self._permissions() + + @property + def id(self): + return self._id + + @property + def name(self): + return self._name + + @name.setter + @property_not_empty + def name(self, value): + self._name = value + + @property + def description(self): + return self._description + + @description.setter + def description(self, value): + self._description = value + + @property + def certified(self): + return self._certified + + @certified.setter + @property_is_boolean + def certified(self, value): + self._certified = value + + @property + def certification_note(self): + return self._certification_note + + @certification_note.setter + def certification_note(self, value): + self._certification_note = value + + @property + def contact_id(self): + return self._contact_id + + @contact_id.setter + def contact_id(self, value): + self._contact_id = value + + @property + def schema(self): + return self._schema + + @property + def columns(self): + if self._columns is None: + error = "Table must be populated with columns first." + raise UnpopulatedPropertyError(error) + # Each call to `.columns` should create a new pager, this just runs the callable + return self._columns() + + def _set_columns(self, columns): + self._columns = columns + + def _set_values(self, table_values): + if 'id' in table_values: + self._id = table_values['id'] + + if 'name' in table_values: + self._name = table_values['name'] + + if 'description' in table_values: + self._description = table_values['description'] + + if 'isCertified' in table_values: + self._certified = string_to_bool(table_values['isCertified']) + + if 'certificationNote' in table_values: + self._certification_note = table_values['certificationNote'] + + if 'embedded' in table_values: + self._embedded = string_to_bool(table_values['embedded']) + + if 'schema' in table_values: + self._schema = table_values['schema'] + + if 'contact' in table_values: + self._contact_id = table_values['contact']['id'] + + def _set_permissions(self, permissions): + self._permissions = permissions + + @classmethod + def from_response(cls, resp, ns): + all_table_items = list() + parsed_response = ET.fromstring(resp) + all_table_xml = parsed_response.findall('.//t:table', namespaces=ns) + + for table_xml in all_table_xml: + parsed_table = cls._parse_element(table_xml, ns) + table_item = cls(parsed_table["name"]) + table_item._set_values(parsed_table) + all_table_items.append(table_item) + return all_table_items + + @staticmethod + def _parse_element(table_xml, ns): + + table_values = table_xml.attrib.copy() + + contact = table_xml.find('.//t:contact', namespaces=ns) + if contact is not None: + table_values['contact'] = contact.attrib.copy() + + return table_values + + +# Used to convert string represented boolean to a boolean type +def string_to_bool(s): + return s.lower() == 'true' diff --git a/tableauserverclient/server/__init__.py b/tableauserverclient/server/__init__.py index 7fa59ef3c..dcbdc8d13 100644 --- a/tableauserverclient/server/__init__.py +++ b/tableauserverclient/server/__init__.py @@ -2,11 +2,12 @@ from .request_options import CSVRequestOptions, ImageRequestOptions, PDFRequestOptions, RequestOptions from .filter import Filter from .sort import Sort -from .. import ConnectionItem, DatasourceItem, JobItem, BackgroundJobItem, \ +from .. import ConnectionItem, DatasourceItem, DatabaseItem, JobItem, BackgroundJobItem, \ GroupItem, PaginationItem, ProjectItem, ScheduleItem, SiteItem, TableauAuth,\ - UserItem, ViewItem, WorkbookItem, TaskItem, SubscriptionItem, PermissionsRule, Permission + UserItem, ViewItem, WorkbookItem, TableItem, TaskItem, SubscriptionItem, \ + PermissionsRule, Permission, ColumnItem from .endpoint import Auth, Datasources, Endpoint, Groups, Projects, Schedules, \ - Sites, Users, Views, Workbooks, Subscriptions, ServerResponseError, \ + Sites, Tables, Users, Views, Workbooks, Subscriptions, ServerResponseError, \ MissingRequiredFieldError from .server import Server from .pager import Pager diff --git a/tableauserverclient/server/endpoint/__init__.py b/tableauserverclient/server/endpoint/__init__.py index 24881b2e4..99bb37005 100644 --- a/tableauserverclient/server/endpoint/__init__.py +++ b/tableauserverclient/server/endpoint/__init__.py @@ -1,5 +1,6 @@ from .auth_endpoint import Auth from .datasources_endpoint import Datasources +from .databases_endpoint import Databases from .endpoint import Endpoint from .exceptions import ServerResponseError, MissingRequiredFieldError, ServerInfoEndpointNotFoundError from .groups_endpoint import Groups @@ -9,6 +10,7 @@ from .schedules_endpoint import Schedules from .server_info_endpoint import ServerInfo from .sites_endpoint import Sites +from .tables_endpoint import Tables from .tasks_endpoint import Tasks from .users_endpoint import Users from .views_endpoint import Views diff --git a/tableauserverclient/server/endpoint/databases_endpoint.py b/tableauserverclient/server/endpoint/databases_endpoint.py new file mode 100644 index 000000000..c0726abe2 --- /dev/null +++ b/tableauserverclient/server/endpoint/databases_endpoint.py @@ -0,0 +1,108 @@ +from .endpoint import api, Endpoint +from .exceptions import MissingRequiredFieldError +from .permissions_endpoint import _PermissionsEndpoint +from .default_permissions_endpoint import _DefaultPermissionsEndpoint + +from .. import RequestFactory, DatabaseItem, PaginationItem, PermissionsRule, Permission + +import logging + +logger = logging.getLogger('tableau.endpoint.databases') + + +class Databases(Endpoint): + def __init__(self, parent_srv): + super(Databases, self).__init__(parent_srv) + + self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) + self._default_permissions = _DefaultPermissionsEndpoint(parent_srv, lambda: self.baseurl) + + @property + def baseurl(self): + return "{0}/sites/{1}/databases".format(self.parent_srv.baseurl, self.parent_srv.site_id) + + @api(version="3.5") + def get(self, req_options=None): + logger.info('Querying all databases on site') + url = self.baseurl + server_response = self.get_request(url, req_options) + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) + all_database_items = DatabaseItem.from_response(server_response.content, self.parent_srv.namespace) + return all_database_items, pagination_item + + # Get 1 database + @api(version="3.5") + def get_by_id(self, database_id): + if not database_id: + error = "database ID undefined." + raise ValueError(error) + logger.info('Querying single database (ID: {0})'.format(database_id)) + url = "{0}/{1}".format(self.baseurl, database_id) + server_response = self.get_request(url) + return DatabaseItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + @api(version="3.5") + def delete(self, database_id): + if not database_id: + error = "Database ID undefined." + raise ValueError(error) + url = "{0}/{1}".format(self.baseurl, database_id) + self.delete_request(url) + logger.info('Deleted single database (ID: {0})'.format(database_id)) + + @api(version="3.5") + def update(self, database_item): + if not database_item.id: + error = "Database item missing ID." + raise MissingRequiredFieldError(error) + + url = "{0}/{1}".format(self.baseurl, database_item.id) + update_req = RequestFactory.Database.update_req(database_item) + server_response = self.put_request(url, update_req) + logger.info('Updated database item (ID: {0})'.format(database_item.id)) + updated_database = DatabaseItem.from_response(server_response.content, self.parent_srv.namespace)[0] + return updated_database + + # Not Implemented Yet + @api(version="99") + def populate_tables(self, database_item): + if not database_item.id: + error = "database item missing ID. database must be retrieved from server first." + raise MissingRequiredFieldError(error) + + def column_fetcher(): + return self._get_tables_for_database(database_item) + + database_item._set_tables(column_fetcher) + logger.info('Populated tables for database (ID: {0}'.format(database_item.id)) + + def _get_tables_for_database(self, database_item): + url = "{0}/{1}/tables".format(self.baseurl, database_item.id) + server_response = self.get_request(url) + tables = TableItem.from_response(server_response.content, + self.parent_srv.namespace) + return tables + + @api(version='3.5') + def populate_permissions(self, item): + self._permissions.populate(item) + + @api(version='3.5') + def update_permission(self, item, rules): + return self._permissions.update(item, rules) + + @api(version='3.5') + def delete_permission(self, item, rules): + return self._permissions.delete(item, rules) + + @api(version='3.5') + def populate_table_default_permissions(self, item): + self._default_permissions.populate_default_permissions(item, Permission.Resource.Table) + + @api(version='3.5') + def update_table_default_permissions(self, item): + self._default_permissions.update_default_permissions(item, Permission.Resource.Table) + + @api(version='3.5') + def delete_table_default_permissions(self, item): + self._default_permissions.delete_default_permissions(item, Permission.Resource.Table) diff --git a/tableauserverclient/server/endpoint/tables_endpoint.py b/tableauserverclient/server/endpoint/tables_endpoint.py new file mode 100644 index 000000000..b8430a124 --- /dev/null +++ b/tableauserverclient/server/endpoint/tables_endpoint.py @@ -0,0 +1,108 @@ +from .endpoint import api, Endpoint +from .exceptions import MissingRequiredFieldError +from .permissions_endpoint import _PermissionsEndpoint +from .default_permissions_endpoint import _DefaultPermissionsEndpoint +from ..pager import Pager + +from .. import RequestFactory, TableItem, ColumnItem, PaginationItem, PermissionsRule, Permission + +import logging + +logger = logging.getLogger('tableau.endpoint.tables') + + +class Tables(Endpoint): + def __init__(self, parent_srv): + super(Tables, self).__init__(parent_srv) + + self._permissions = _PermissionsEndpoint(parent_srv, lambda: self.baseurl) + + @property + def baseurl(self): + return "{0}/sites/{1}/tables".format(self.parent_srv.baseurl, self.parent_srv.site_id) + + @api(version="3.5") + def get(self, req_options=None): + logger.info('Querying all tables on site') + url = self.baseurl + server_response = self.get_request(url, req_options) + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) + all_table_items = TableItem.from_response(server_response.content, self.parent_srv.namespace) + return all_table_items, pagination_item + + # Get 1 table + @api(version="3.5") + def get_by_id(self, table_id): + if not table_id: + error = "table ID undefined." + raise ValueError(error) + logger.info('Querying single table (ID: {0})'.format(table_id)) + url = "{0}/{1}".format(self.baseurl, table_id) + server_response = self.get_request(url) + return TableItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + @api(version="3.5") + def delete(self, table_id): + if not table_id: + error = "Database ID undefined." + raise ValueError(error) + url = "{0}/{1}".format(self.baseurl, table_id) + self.delete_request(url) + logger.info('Deleted single table (ID: {0})'.format(table_id)) + + @api(version="3.5") + def update(self, table_item): + if not table_item.id: + error = "table item missing ID." + raise MissingRequiredFieldError(error) + + url = "{0}/{1}".format(self.baseurl, table_item.id) + update_req = RequestFactory.Table.update_req(table_item) + server_response = self.put_request(url, update_req) + logger.info('Updated table item (ID: {0})'.format(table_item.id)) + updated_table = TableItem.from_response(server_response.content, self.parent_srv.namespace)[0] + return updated_table + + # Get all columns of the table + @api(version="3.5") + def populate_columns(self, table_item, req_options=None): + if not table_item.id: + error = "Table item missing ID. table must be retrieved from server first." + raise MissingRequiredFieldError(error) + + def column_fetcher(): + return Pager(lambda options: self._get_columns_for_table(table_item, options), req_options) + + table_item._set_columns(column_fetcher) + logger.info('Populated columns for table (ID: {0}'.format(table_item.id)) + + def _get_columns_for_table(self, table_item, req_options=None): + url = "{0}/{1}/columns".format(self.baseurl, table_item.id) + server_response = self.get_request(url, req_options) + columns = ColumnItem.from_response(server_response.content, + self.parent_srv.namespace) + pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) + return columns, pagination_item + + @api(version="3.5") + def update_column(self, table_item, column_item): + url = "{0}/{1}/columns/{2}".format(self.baseurl, table_item.id, column_item.id) + update_req = RequestFactory.Column.update_req(column_item) + server_response = self.put_request(url, update_req) + column = ColumnItem.from_response(server_response.content, self.parent_srv.namespace)[0] + + logger.info('Updated table item (ID: {0} & column item {1}'.format(table_item.id, + column_item.id)) + return column + + @api(version='3.5') + def populate_permissions(self, item): + self._permissions.populate(item) + + @api(version='3.5') + def update_permission(self, item, rules): + return self._permissions.update(item, rules) + + @api(version='3.5') + def delete_permission(self, item, rules): + return self._permissions.delete(item, rules) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index b6739af6b..1f787382e 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -63,6 +63,36 @@ def signin_req(self, auth_item): return ET.tostring(xml_request) +class ColumnRequest(object): + def update_req(self, column_item): + xml_request = ET.Element('tsRequest') + column_element = ET.SubElement(xml_request, 'column') + + if column_item.description: + column_element.attrib['description'] = str(column_item.description) + + return ET.tostring(xml_request) + + +class DatabaseRequest(object): + def update_req(self, database_item): + xml_request = ET.Element('tsRequest') + database_element = ET.SubElement(xml_request, 'database') + if database_item.contact_id: + contact_element = ET.SubElement(database_element, 'contact') + contact_element.attrib['id'] = database_item.contact_id + + database_element.attrib['isCertified'] = str(database_item.certified).lower() + + if database_item.certification_note: + database_element.attrib['certificationNote'] = str(database_item.certification_note) + + if database_item.description: + database_element.attrib['description'] = str(database_item.description) + + return ET.tostring(xml_request) + + class DatasourceRequest(object): def _generate_xml(self, datasource_item, connection_credentials=None, connections=None): xml_request = ET.Element('tsRequest') @@ -317,6 +347,26 @@ def create_req(self, site_item): return ET.tostring(xml_request) +class TableRequest(object): + def update_req(self, table_item): + xml_request = ET.Element('tsRequest') + table_element = ET.SubElement(xml_request, 'table') + + if table_item.contact_id: + contact_element = ET.SubElement(table_element, 'contact') + contact_element.attrib['id'] = table_item.contact_id + + table_element.attrib['isCertified'] = str(table_item.certified).lower() + + if table_item.certification_note: + table_element.attrib['certificationNote'] = str(table_item.certification_note) + + if table_item.description: + table_element.attrib['description'] = str(table_item.description) + + return ET.tostring(xml_request) + + class TagRequest(object): def add_req(self, tag_set): xml_request = ET.Element('tsRequest') @@ -468,7 +518,9 @@ def empty_req(self, xml_request): class RequestFactory(object): Auth = AuthRequest() Connection = Connection() + Column = ColumnRequest() Datasource = DatasourceRequest() + Database = DatabaseRequest() Empty = EmptyRequest() Fileupload = FileuploadRequest() Group = GroupRequest() @@ -476,6 +528,7 @@ class RequestFactory(object): Project = ProjectRequest() Schedule = ScheduleRequest() Site = SiteRequest() + Table = TableRequest() Tag = TagRequest() Task = TaskRequest() User = UserRequest() diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 536b3982a..9ba195d9d 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -3,7 +3,8 @@ from .exceptions import NotSignedInError from ..namespace import Namespace from .endpoint import Sites, Views, Users, Groups, Workbooks, Datasources, Projects, Auth, \ - Schedules, ServerInfo, Tasks, ServerInfoEndpointNotFoundError, Subscriptions, Jobs, Metadata + Schedules, ServerInfo, Tasks, ServerInfoEndpointNotFoundError, Subscriptions, Jobs, Metadata,\ + Databases, Tables from .endpoint.exceptions import EndpointUnavailableError, ServerInfoEndpointNotFoundError import requests @@ -51,6 +52,8 @@ def __init__(self, server_address, use_server_version=False): self.tasks = Tasks(self) self.subscriptions = Subscriptions(self) self.metadata = Metadata(self) + self.databases = Databases(self) + self.tables = Tables(self) self._namespace = Namespace() if use_server_version: diff --git a/test/assets/database_get.xml b/test/assets/database_get.xml new file mode 100644 index 000000000..7d22daf4c --- /dev/null +++ b/test/assets/database_get.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/test/assets/database_populate_permissions.xml b/test/assets/database_populate_permissions.xml new file mode 100644 index 000000000..21f30fea9 --- /dev/null +++ b/test/assets/database_populate_permissions.xml @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/test/assets/database_update.xml b/test/assets/database_update.xml new file mode 100644 index 000000000..b2cbd68c9 --- /dev/null +++ b/test/assets/database_update.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/test/assets/table_get.xml b/test/assets/table_get.xml new file mode 100644 index 000000000..0bd2763d5 --- /dev/null +++ b/test/assets/table_get.xml @@ -0,0 +1,21 @@ + + + + + + + +
+ + + +
+ + +
+ + +
+
+
\ No newline at end of file diff --git a/test/assets/table_update.xml b/test/assets/table_update.xml new file mode 100644 index 000000000..975f0cedb --- /dev/null +++ b/test/assets/table_update.xml @@ -0,0 +1,8 @@ + + + + + +
+
\ No newline at end of file diff --git a/test/test_database.py b/test/test_database.py new file mode 100644 index 000000000..fb9ffbd86 --- /dev/null +++ b/test/test_database.py @@ -0,0 +1,87 @@ +import unittest +import os +import requests_mock +import xml.etree.ElementTree as ET +import tableauserverclient as TSC +from tableauserverclient.datetime_helpers import format_datetime +from tableauserverclient.server.endpoint.exceptions import InternalServerError +from tableauserverclient.server.request_factory import RequestFactory +from ._utils import read_xml_asset, read_xml_assets, asset + +GET_XML = 'database_get.xml' +POPULATE_PERMISSIONS_XML = 'database_populate_permissions.xml' +UPDATE_XML = 'database_update.xml' + + +class DatabaseTests(unittest.TestCase): + def setUp(self): + self.server = TSC.Server('http://test') + + # Fake signin + self.server._site_id = 'dad65087-b08b-4603-af4e-2887b8aafc67' + self.server._auth_token = 'j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM' + self.server.version = "3.5" + + self.baseurl = self.server.databases.baseurl + + def test_get(self): + response_xml = read_xml_asset(GET_XML) + with requests_mock.mock() as m: + m.get(self.baseurl, text=response_xml) + all_databases, pagination_item = self.server.databases.get() + + self.assertEqual(5, pagination_item.total_available) + self.assertEqual('5ea59b45-e497-4827-8809-bfe213236f75', all_databases[0].id) + self.assertEqual('hyper', all_databases[0].connection_type) + self.assertEqual('hyper_0.hyper', all_databases[0].name) + + self.assertEqual('23591f2c-4802-4d6a-9e28-574a8ea9bc4c', all_databases[1].id) + self.assertEqual('sqlserver', all_databases[1].connection_type) + self.assertEqual('testv1', all_databases[1].name) + self.assertEqual('9324cf6b-ba72-4b8e-b895-ac3f28d2f0e0', all_databases[1].contact_id) + self.assertEqual(True, all_databases[1].certified) + + def test_update(self): + response_xml = read_xml_asset(UPDATE_XML) + with requests_mock.mock() as m: + m.put(self.baseurl + '/23591f2c-4802-4d6a-9e28-574a8ea9bc4c', text=response_xml) + single_database = TSC.DatabaseItem('test') + single_database.contact_id = '9324cf6b-ba72-4b8e-b895-ac3f28d2f0e0' + single_database._id = '23591f2c-4802-4d6a-9e28-574a8ea9bc4c' + single_database.certified = True + single_database.certification_note = "Test" + single_database = self.server.databases.update(single_database) + + self.assertEqual('23591f2c-4802-4d6a-9e28-574a8ea9bc4c', single_database.id) + self.assertEqual('9324cf6b-ba72-4b8e-b895-ac3f28d2f0e0', single_database.contact_id) + self.assertEqual(True, single_database.certified) + self.assertEqual("Test", single_database.certification_note) + + def test_populate_permissions(self): + with open(asset(POPULATE_PERMISSIONS_XML), 'rb') as f: + response_xml = f.read().decode('utf-8') + with requests_mock.mock() as m: + m.get(self.baseurl + '/0448d2ed-590d-4fa0-b272-a2a8a24555b5/permissions', text=response_xml) + single_database = TSC.DatabaseItem('test') + single_database._id = '0448d2ed-590d-4fa0-b272-a2a8a24555b5' + + self.server.databases.populate_permissions(single_database) + permissions = single_database.permissions + + self.assertEqual(permissions[0].grantee.tag_name, 'group') + self.assertEqual(permissions[0].grantee.id, '5e5e1978-71fa-11e4-87dd-7382f5c437af') + self.assertDictEqual(permissions[0].capabilities, { + TSC.Permission.Capability.ChangePermissions: TSC.Permission.Mode.Deny, + TSC.Permission.Capability.Read: TSC.Permission.Mode.Allow, + }) + + self.assertEqual(permissions[1].grantee.tag_name, 'user') + self.assertEqual(permissions[1].grantee.id, '7c37ee24-c4b1-42b6-a154-eaeab7ee330a') + self.assertDictEqual(permissions[1].capabilities, { + TSC.Permission.Capability.Write: TSC.Permission.Mode.Allow, + }) + + def test_delete(self): + with requests_mock.mock() as m: + m.delete(self.baseurl + '/0448d2ed-590d-4fa0-b272-a2a8a24555b5', status_code=204) + self.server.databases.delete('0448d2ed-590d-4fa0-b272-a2a8a24555b5') diff --git a/test/test_table.py b/test/test_table.py new file mode 100644 index 000000000..45af43c9a --- /dev/null +++ b/test/test_table.py @@ -0,0 +1,62 @@ +import unittest +import os +import requests_mock +import xml.etree.ElementTree as ET +import tableauserverclient as TSC +from tableauserverclient.datetime_helpers import format_datetime +from tableauserverclient.server.endpoint.exceptions import InternalServerError +from tableauserverclient.server.request_factory import RequestFactory +from ._utils import read_xml_asset, read_xml_assets, asset + +GET_XML = 'table_get.xml' +UPDATE_XML = 'table_update.xml' + + +class TableTests(unittest.TestCase): + def setUp(self): + self.server = TSC.Server('http://test') + + # Fake signin + self.server._site_id = 'dad65087-b08b-4603-af4e-2887b8aafc67' + self.server._auth_token = 'j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM' + self.server.version = "3.5" + + self.baseurl = self.server.tables.baseurl + + def test_get(self): + response_xml = read_xml_asset(GET_XML) + with requests_mock.mock() as m: + m.get(self.baseurl, text=response_xml) + all_tables, pagination_item = self.server.tables.get() + + self.assertEqual(4, pagination_item.total_available) + self.assertEqual('10224773-ecee-42ac-b822-d786b0b8e4d9', all_tables[0].id) + self.assertEqual('dim_Product', all_tables[0].name) + + self.assertEqual('53c77bc1-fb41-4342-a75a-f68ac0656d0d', all_tables[1].id) + self.assertEqual('customer', all_tables[1].name) + self.assertEqual('dbo', all_tables[1].schema) + self.assertEqual('9324cf6b-ba72-4b8e-b895-ac3f28d2f0e0', all_tables[1].contact_id) + self.assertEqual(False, all_tables[1].certified) + + def test_update(self): + response_xml = read_xml_asset(UPDATE_XML) + with requests_mock.mock() as m: + m.put(self.baseurl + '/10224773-ecee-42ac-b822-d786b0b8e4d9', text=response_xml) + single_table = TSC.TableItem('test') + single_table._id = '10224773-ecee-42ac-b822-d786b0b8e4d9' + + single_table.contact_id = '8e1a8235-c9ee-4d61-ae82-2ffacceed8e0' + single_table.certified = True + single_table.certification_note = "Test" + single_table = self.server.tables.update(single_table) + + self.assertEqual('10224773-ecee-42ac-b822-d786b0b8e4d9', single_table.id) + self.assertEqual('8e1a8235-c9ee-4d61-ae82-2ffacceed8e0', single_table.contact_id) + self.assertEqual(True, single_table.certified) + self.assertEqual("Test", single_table.certification_note) + + def test_delete(self): + with requests_mock.mock() as m: + m.delete(self.baseurl + '/0448d2ed-590d-4fa0-b272-a2a8a24555b5', status_code=204) + self.server.tables.delete('0448d2ed-590d-4fa0-b272-a2a8a24555b5')