diff --git a/tableauserverclient/models/task_item.py b/tableauserverclient/models/task_item.py index 159869b07..0ffc3bfab 100644 --- a/tableauserverclient/models/task_item.py +++ b/tableauserverclient/models/task_item.py @@ -1,8 +1,11 @@ +from datetime import datetime +from typing import List, Optional + from defusedxml.ElementTree import fromstring from tableauserverclient.datetime_helpers import parse_datetime -from .schedule_item import ScheduleItem -from .target import Target +from tableauserverclient.models.schedule_item import ScheduleItem +from tableauserverclient.models.target import Target class TaskItem(object): @@ -19,14 +22,14 @@ class Type: def __init__( self, - id_, - task_type, - priority, - consecutive_failed_count=0, - schedule_id=None, - schedule_item=None, - last_run_at=None, - target=None, + id_: str, + task_type: str, + priority: int, + consecutive_failed_count: int = 0, + schedule_id: Optional[str] = None, + schedule_item: Optional[ScheduleItem] = None, + last_run_at: Optional[datetime] = None, + target: Optional[Target] = None, ): self.id = id_ self.task_type = task_type @@ -37,14 +40,14 @@ def __init__( self.last_run_at = last_run_at self.target = target - def __repr__(self): + def __repr__(self) -> str: return ( "".format(**self.__dict__) ) @classmethod - def from_response(cls, xml, ns, task_type=Type.ExtractRefresh): + def from_response(cls, xml, ns, task_type=Type.ExtractRefresh) -> List["TaskItem"]: parsed_response = fromstring(xml) all_tasks_xml = parsed_response.findall(".//t:task/t:{}".format(task_type), namespaces=ns) @@ -62,8 +65,7 @@ def _parse_element(cls, element, ns): last_run_at_element = element.find(".//t:lastRunAt", namespaces=ns) schedule_item_list = ScheduleItem.from_element(element, ns) - if len(schedule_item_list) >= 1: - schedule_item = schedule_item_list[0] + schedule_item = next(iter(schedule_item_list), None) # according to the Tableau Server REST API documentation, # there should be only one of workbook or datasource @@ -87,14 +89,14 @@ def _parse_element(cls, element, ns): task_type, priority, consecutive_failed_count, - schedule_item.id, + schedule_item.id if schedule_item is not None else None, schedule_item, last_run_at, target, ) @staticmethod - def _translate_task_type(task_type): + def _translate_task_type(task_type: str) -> str: if task_type in TaskItem._TASK_TYPE_MAPPING: return TaskItem._TASK_TYPE_MAPPING[task_type] else: diff --git a/tableauserverclient/server/endpoint/tasks_endpoint.py b/tableauserverclient/server/endpoint/tasks_endpoint.py index 092597388..a727a515f 100644 --- a/tableauserverclient/server/endpoint/tasks_endpoint.py +++ b/tableauserverclient/server/endpoint/tasks_endpoint.py @@ -1,19 +1,23 @@ import logging +from typing import List, Optional, Tuple, TYPE_CHECKING -from .endpoint import Endpoint, api -from .exceptions import MissingRequiredFieldError +from tableauserverclient.server.endpoint.endpoint import Endpoint, api +from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError from tableauserverclient.models import TaskItem, PaginationItem from tableauserverclient.server import RequestFactory from tableauserverclient.helpers.logging import logger +if TYPE_CHECKING: + from tableauserverclient.server.request_options import RequestOptions + class Tasks(Endpoint): @property - def baseurl(self): + def baseurl(self) -> str: return "{0}/sites/{1}/tasks".format(self.parent_srv.baseurl, self.parent_srv.site_id) - def __normalize_task_type(self, task_type): + def __normalize_task_type(self, task_type: str) -> str: """ The word for extract refresh used in API URL is "extractRefreshes". It is different than the tag "extractRefresh" used in the request body. @@ -24,11 +28,13 @@ def __normalize_task_type(self, task_type): return task_type @api(version="2.6") - def get(self, req_options=None, task_type=TaskItem.Type.ExtractRefresh): + def get( + self, req_options: Optional["RequestOptions"] = None, task_type: str = TaskItem.Type.ExtractRefresh + ) -> Tuple[List[TaskItem], PaginationItem]: if task_type == TaskItem.Type.DataAcceleration: self.parent_srv.assert_at_least_version("3.8", "Data Acceleration Tasks") - logger.info("Querying all {} tasks for the site".format(task_type)) + logger.info("Querying all %s tasks for the site", task_type) url = "{0}/{1}".format(self.baseurl, self.__normalize_task_type(task_type)) server_response = self.get_request(url, req_options) @@ -38,11 +44,11 @@ def get(self, req_options=None, task_type=TaskItem.Type.ExtractRefresh): return all_tasks, pagination_item @api(version="2.6") - def get_by_id(self, task_id): + def get_by_id(self, task_id: str) -> TaskItem: if not task_id: error = "No Task ID provided" raise ValueError(error) - logger.info("Querying a single task by id ({})".format(task_id)) + logger.info("Querying a single task by id %s", task_id) url = "{}/{}/{}".format( self.baseurl, self.__normalize_task_type(TaskItem.Type.ExtractRefresh), @@ -56,14 +62,14 @@ def create(self, extract_item: TaskItem) -> TaskItem: if not extract_item: error = "No extract refresh provided" raise ValueError(error) - logger.info("Creating an extract refresh ({})".format(extract_item)) + logger.info("Creating an extract refresh %s", extract_item) url = "{0}/{1}".format(self.baseurl, self.__normalize_task_type(TaskItem.Type.ExtractRefresh)) create_req = RequestFactory.Task.create_extract_req(extract_item) server_response = self.post_request(url, create_req) return server_response.content @api(version="2.6") - def run(self, task_item): + def run(self, task_item: TaskItem) -> bytes: if not task_item.id: error = "Task item missing ID." raise MissingRequiredFieldError(error) @@ -79,7 +85,7 @@ def run(self, task_item): # Delete 1 task by id @api(version="3.6") - def delete(self, task_id, task_type=TaskItem.Type.ExtractRefresh): + def delete(self, task_id: str, task_type: str = TaskItem.Type.ExtractRefresh) -> None: if task_type == TaskItem.Type.DataAcceleration: self.parent_srv.assert_at_least_version("3.8", "Data Acceleration Tasks") @@ -88,4 +94,4 @@ def delete(self, task_id, task_type=TaskItem.Type.ExtractRefresh): raise ValueError(error) url = "{0}/{1}/{2}".format(self.baseurl, self.__normalize_task_type(task_type), task_id) self.delete_request(url) - logger.info("Deleted single task (ID: {0})".format(task_id)) + logger.info("Deleted single task (ID: %s)", task_id) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 7fb9bf9ed..6316527ec 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1032,6 +1032,16 @@ def run_req(self, xml_request, task_item): def create_extract_req(self, xml_request: ET.Element, extract_item: "TaskItem") -> bytes: extract_element = ET.SubElement(xml_request, "extractRefresh") + # Main attributes + extract_element.attrib["type"] = extract_item.task_type + + if extract_item.target is not None: + target_element = ET.SubElement(extract_element, extract_item.target.type) + target_element.attrib["id"] = extract_item.target.id + + if extract_item.schedule_item is None: + return ET.tostring(xml_request) + # Schedule attributes schedule_element = ET.SubElement(xml_request, "schedule") @@ -1043,17 +1053,11 @@ def create_extract_req(self, xml_request: ET.Element, extract_item: "TaskItem") frequency_element.attrib["end"] = str(interval_item.end_time) if hasattr(interval_item, "interval") and interval_item.interval: intervals_element = ET.SubElement(frequency_element, "intervals") - for interval in interval_item._interval_type_pairs(): + for interval in interval_item._interval_type_pairs(): # type: ignore expression, value = interval single_interval_element = ET.SubElement(intervals_element, "interval") single_interval_element.attrib[expression] = value - # Main attributes - extract_element.attrib["type"] = extract_item.task_type - - target_element = ET.SubElement(extract_element, extract_item.target.type) - target_element.attrib["id"] = extract_item.target.id - return ET.tostring(xml_request) diff --git a/test/assets/tasks_without_schedule.xml b/test/assets/tasks_without_schedule.xml new file mode 100644 index 000000000..e669bf67f --- /dev/null +++ b/test/assets/tasks_without_schedule.xml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/test/test_task.py b/test/test_task.py index 4eb2c02e2..4e0157dfd 100644 --- a/test/test_task.py +++ b/test/test_task.py @@ -1,6 +1,7 @@ import os import unittest from datetime import time +from pathlib import Path import requests_mock @@ -8,7 +9,7 @@ from tableauserverclient.datetime_helpers import parse_datetime from tableauserverclient.models.task_item import TaskItem -TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") +TEST_ASSET_DIR = Path(__file__).parent / "assets" GET_XML_NO_WORKBOOK = os.path.join(TEST_ASSET_DIR, "tasks_no_workbook_or_datasource.xml") GET_XML_WITH_WORKBOOK = os.path.join(TEST_ASSET_DIR, "tasks_with_workbook.xml") @@ -17,6 +18,7 @@ GET_XML_DATAACCELERATION_TASK = os.path.join(TEST_ASSET_DIR, "tasks_with_dataacceleration_task.xml") GET_XML_RUN_NOW_RESPONSE = os.path.join(TEST_ASSET_DIR, "tasks_run_now_response.xml") GET_XML_CREATE_TASK_RESPONSE = os.path.join(TEST_ASSET_DIR, "tasks_create_extract_task.xml") +GET_XML_WITHOUT_SCHEDULE = TEST_ASSET_DIR / "tasks_without_schedule.xml" class TaskTests(unittest.TestCase): @@ -86,6 +88,15 @@ def test_get_task_with_schedule(self): self.assertEqual("workbook", task.target.type) self.assertEqual("b60b4efd-a6f7-4599-beb3-cb677e7abac1", task.schedule_id) + def test_get_task_without_schedule(self): + with requests_mock.mock() as m: + m.get(self.baseurl, text=GET_XML_WITHOUT_SCHEDULE.read_text()) + all_tasks, pagination_item = self.server.tasks.get() + + task = all_tasks[0] + self.assertEqual("c7a9327e-1cda-4504-b026-ddb43b976d1d", task.target.id) + self.assertEqual("datasource", task.target.type) + def test_delete(self): with requests_mock.mock() as m: m.delete(self.baseurl + "/c7a9327e-1cda-4504-b026-ddb43b976d1d", status_code=204)