diff --git a/tableauserverclient/server/endpoint/auth_endpoint.py b/tableauserverclient/server/endpoint/auth_endpoint.py index 10f4cb4db..f74b88b21 100644 --- a/tableauserverclient/server/endpoint/auth_endpoint.py +++ b/tableauserverclient/server/endpoint/auth_endpoint.py @@ -1,5 +1,5 @@ from ..request_factory import RequestFactory - +from .exceptions import ServerResponseError from .endpoint import Endpoint, api import xml.etree.ElementTree as ET import logging @@ -52,3 +52,24 @@ def sign_out(self): self.post_request(url, '') self.parent_srv._clear_auth() logger.info('Signed out') + + @api(version="2.6") + def switch_site(self, site_item): + url = "{0}/{1}".format(self.baseurl, 'switchSite') + switch_req = RequestFactory.Auth.switch_req(site_item.content_url) + try: + server_response = self.post_request(url, switch_req) + except ServerResponseError as e: + if e.code == "403070": + return Auth.contextmgr(self.sign_out) + else: + raise e + self.parent_srv._namespace.detect(server_response.content) + self._check_status(server_response) + parsed_response = ET.fromstring(server_response.content) + site_id = parsed_response.find('.//t:site', namespaces=self.parent_srv.namespace).get('id', None) + user_id = parsed_response.find('.//t:user', namespaces=self.parent_srv.namespace).get('id', None) + auth_token = parsed_response.find('t:credentials', namespaces=self.parent_srv.namespace).get('token', None) + self.parent_srv._set_auth(site_id, user_id, auth_token) + logger.info('Signed into {0} as user with id {1}'.format(self.parent_srv.server_address, user_id)) + return Auth.contextmgr(self.sign_out) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 9c869c686..c1a54760a 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -67,6 +67,13 @@ def signin_req(self, auth_item): user_element.attrib['id'] = auth_item.user_id_to_impersonate return ET.tostring(xml_request) + def switch_req(self, site_content_url): + xml_request = ET.Element('tsRequest') + + site_element = ET.SubElement(xml_request, 'site') + site_element.attrib['contentUrl'] = site_content_url + return ET.tostring(xml_request) + class ColumnRequest(object): def update_req(self, column_item): diff --git a/test/test_auth.py b/test/test_auth.py index 28e241335..b879ab121 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -90,3 +90,19 @@ def test_sign_out(self): self.assertIsNone(self.server._auth_token) self.assertIsNone(self.server._site_id) self.assertIsNone(self.server._user_id) + + def test_switch_site(self): + self.server.version = '2.6' + baseurl = self.server.auth.baseurl + site_id, user_id, auth_token = list('123') + self.server._set_auth(site_id, user_id, auth_token) + with open(SIGN_IN_XML, 'rb') as f: + response_xml = f.read().decode('utf-8') + with requests_mock.mock() as m: + m.post(baseurl + '/switchSite', text=response_xml) + site = TSC.SiteItem('Samples', 'Samples') + self.server.auth.switch_site(site) + + self.assertEqual('eIX6mvFsqyansa4KqEI1UwOpS8ggRs2l', self.server.auth_token) + self.assertEqual('6b7179ba-b82b-4f0f-91ed-812074ac5da6', self.server.site_id) + self.assertEqual('1a96d216-e9b8-497b-a82a-0b899a965e01', self.server.user_id)