diff --git a/CHANGES.md b/CHANGES.md
index 8d29f36..379b7f2 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -8,6 +8,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 
 ## 0.4.2 - 2024-xx-xx
 
 - {pull}`85` simplifies code since loky is a dependency.
+- {pull}`89` restructures the package.
 
 ## 0.4.1 - 2024-01-12
 
diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py
index 7e4dd2e..271662c 100644
--- a/src/pytask_parallel/config.py
+++ b/src/pytask_parallel/config.py
@@ -2,40 +2,46 @@
 
 from __future__ import annotations
 
-import enum
 import os
 from typing import Any
 
 from pytask import hookimpl
 
+from pytask_parallel import execute
+from pytask_parallel import processes
+from pytask_parallel import threads
 from pytask_parallel.backends import ParallelBackend
 
 
 @hookimpl
 def pytask_parse_config(config: dict[str, Any]) -> None:
     """Parse the configuration."""
+    __tracebackhide__ = True
+
     if config["n_workers"] == "auto":
         config["n_workers"] = max(os.cpu_count() - 1, 1)
 
-    if (
-        isinstance(config["parallel_backend"], str)
-        and config["parallel_backend"] in ParallelBackend._value2member_map_  # noqa: SLF001
-    ):
+    try:
         config["parallel_backend"] = ParallelBackend(config["parallel_backend"])
-    elif (
-        isinstance(config["parallel_backend"], enum.Enum)
-        and config["parallel_backend"] in ParallelBackend
-    ):
-        pass
-    else:
-        msg = f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}."
-        raise ValueError(msg)
+    except ValueError:
+        msg = (
+            f"Invalid value for 'parallel_backend'. Got {config['parallel_backend']}. "
+            f"Choose one of {', '.join([e.value for e in ParallelBackend])}."
+        )
+        raise ValueError(msg) from None
 
     config["delay"] = 0.1
 
 
 @hookimpl
 def pytask_post_parse(config: dict[str, Any]) -> None:
-    """Disable parallelization if debugging is enabled."""
+    """Register the parallel backend if debugging is not enabled."""
     if config["pdb"] or config["trace"] or config["dry_run"]:
         config["n_workers"] = 1
+
+    if config["n_workers"] > 1:
+        config["pm"].register(execute)
+        if config["parallel_backend"] == ParallelBackend.THREADS:
+            config["pm"].register(threads)
+        else:
+            config["pm"].register(processes)
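The hunk above lets the `Enum` constructor do the validation, so the accepted values and the error message live in one place. A minimal standalone sketch of that pattern (the enum members here are illustrative, not the package's full list):

```python
from enum import Enum


class Backend(Enum):
    """Stand-in for pytask_parallel.backends.ParallelBackend."""

    PROCESSES = "processes"
    THREADS = "threads"


def parse_backend(value: str) -> Backend:
    """Mirror the try/except in pytask_parse_config."""
    try:
        return Backend(value)
    except ValueError:
        msg = (
            f"Invalid value for 'parallel_backend'. Got {value}. "
            f"Choose one of {', '.join(e.value for e in Backend)}."
        )
        raise ValueError(msg) from None


assert parse_backend("threads") is Backend.THREADS
```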
diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py
index 837b666..7145da4 100644
--- a/src/pytask_parallel/execute.py
+++ b/src/pytask_parallel/execute.py
@@ -2,66 +2,29 @@
 
 from __future__ import annotations
 
-import inspect
 import sys
 import time
-import warnings
-from functools import partial
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Callable
 
-import cloudpickle
 from attrs import define
 from attrs import field
 from pytask import ExecutionReport
-from pytask import Mark
 from pytask import PNode
-from pytask import PTask
 from pytask import PythonNode
 from pytask import Session
-from pytask import Task
-from pytask import WarningReport
-from pytask import console
-from pytask import get_marks
 from pytask import hookimpl
-from pytask import parse_warning_filter
-from pytask import remove_internal_traceback_frames_from_exc_info
-from pytask import warning_record_to_str
-from pytask.tree_util import PyTree
-from pytask.tree_util import tree_leaves
 from pytask.tree_util import tree_map
-from pytask.tree_util import tree_structure
-from rich.traceback import Traceback
 
 from pytask_parallel.backends import PARALLEL_BACKEND_BUILDER
 from pytask_parallel.backends import ParallelBackend
 
 if TYPE_CHECKING:
     from concurrent.futures import Future
-    from pathlib import Path
-    from types import ModuleType
     from types import TracebackType
 
-    from rich.console import ConsoleOptions
-
 
 @hookimpl
-def pytask_post_parse(config: dict[str, Any]) -> None:
-    """Register the parallel backend."""
-    if config["parallel_backend"] == ParallelBackend.THREADS:
-        config["pm"].register(DefaultBackendNameSpace)
-    else:
-        config["pm"].register(ProcessesNameSpace)
-
-    if PARALLEL_BACKEND_BUILDER[config["parallel_backend"]] is None:
-        raise
-    config["_parallel_executor"] = PARALLEL_BACKEND_BUILDER[
-        config["parallel_backend"]
-    ]()
-
-
-@hookimpl(tryfirst=True)
 def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR0915
     """Execute tasks with a parallel backend.
@@ -75,132 +38,122 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
 
     """
     __tracebackhide__ = True
 
-    if session.config["n_workers"] > 1:
-        reports = session.execution_reports
-        running_tasks: dict[str, Future[Any]] = {}
+    reports = session.execution_reports
+    running_tasks: dict[str, Future[Any]] = {}
 
-        parallel_backend = PARALLEL_BACKEND_BUILDER[
-            session.config["parallel_backend"]
-        ]()
+    parallel_backend = PARALLEL_BACKEND_BUILDER[session.config["parallel_backend"]]()
 
-        with parallel_backend(max_workers=session.config["n_workers"]) as executor:
-            session.config["_parallel_executor"] = executor
-            sleeper = _Sleeper()
+    with parallel_backend(max_workers=session.config["n_workers"]) as executor:
+        session.config["_parallel_executor"] = executor
+        sleeper = _Sleeper()
 
-            while session.scheduler.is_active():
-                try:
-                    newly_collected_reports = []
-                    n_new_tasks = session.config["n_workers"] - len(running_tasks)
+        while session.scheduler.is_active():
+            try:
+                newly_collected_reports = []
+                n_new_tasks = session.config["n_workers"] - len(running_tasks)
 
-                    ready_tasks = (
-                        list(session.scheduler.get_ready(n_new_tasks))
-                        if n_new_tasks >= 1
-                        else []
-                    )
+                ready_tasks = (
+                    list(session.scheduler.get_ready(n_new_tasks))
+                    if n_new_tasks >= 1
+                    else []
+                )
 
-                    for task_name in ready_tasks:
-                        task = session.dag.nodes[task_name]["task"]
-                        session.hook.pytask_execute_task_log_start(
+                for task_name in ready_tasks:
+                    task = session.dag.nodes[task_name]["task"]
+                    session.hook.pytask_execute_task_log_start(
+                        session=session, task=task
+                    )
+                    try:
+                        session.hook.pytask_execute_task_setup(
                             session=session, task=task
                         )
-                        try:
-                            session.hook.pytask_execute_task_setup(
-                                session=session, task=task
+                    except Exception:  # noqa: BLE001
+                        report = ExecutionReport.from_task_and_exception(
+                            task, sys.exc_info()
+                        )
+                        newly_collected_reports.append(report)
+                        session.scheduler.done(task_name)
+                    else:
+                        running_tasks[task_name] = session.hook.pytask_execute_task(
+                            session=session, task=task
+                        )
+                        sleeper.reset()
+
+                if not ready_tasks:
+                    sleeper.increment()
+
+                for task_name in list(running_tasks):
+                    future = running_tasks[task_name]
+                    if future.done():
+                        # An exception was thrown before the task was executed.
+                        if future.exception() is not None:
+                            exc_info = _parse_future_exception(future.exception())
+                            warning_reports = []
+                        # A task raised an exception.
+                        else:
+                            (python_nodes, warning_reports, task_exception) = (
+                                future.result()
                             )
-                        except Exception:  # noqa: BLE001
-                            report = ExecutionReport.from_task_and_exception(
-                                task, sys.exc_info()
+                            session.warnings.extend(warning_reports)
+                            exc_info = (
+                                _parse_future_exception(future.exception())
+                                or task_exception
                             )
-                            newly_collected_reports.append(report)
+
+                        if exc_info is not None:
+                            task = session.dag.nodes[task_name]["task"]
+                            newly_collected_reports.append(
+                                ExecutionReport.from_task_and_exception(task, exc_info)
+                            )
+                            running_tasks.pop(task_name)
                             session.scheduler.done(task_name)
                         else:
-                            running_tasks[task_name] = session.hook.pytask_execute_task(
-                                session=session, task=task
-                            )
-                            sleeper.reset()
-
-                    if not ready_tasks:
-                        sleeper.increment()
-
-                    for task_name in list(running_tasks):
-                        future = running_tasks[task_name]
-                        if future.done():
-                            # An exception was thrown before the task was executed.
-                            if future.exception() is not None:
-                                exc_info = _parse_future_exception(future.exception())
-                                warning_reports = []
-                            # A task raised an exception.
-                            else:
-                                (
-                                    python_nodes,
-                                    warning_reports,
-                                    task_exception,
-                                ) = future.result()
-                                session.warnings.extend(warning_reports)
-                                exc_info = (
-                                    _parse_future_exception(future.exception())
-                                    or task_exception
+                            task = session.dag.nodes[task_name]["task"]
+
+                            # Update PythonNodes with the values from the future if
+                            # not threads.
+                            if (
+                                session.config["parallel_backend"]
+                                != ParallelBackend.THREADS
+                            ):
+                                task.produces = tree_map(
+                                    _update_python_node, task.produces, python_nodes
                                 )
 
-                            if exc_info is not None:
-                                task = session.dag.nodes[task_name]["task"]
-                                newly_collected_reports.append(
-                                    ExecutionReport.from_task_and_exception(
-                                        task, exc_info
-                                    )
+                            try:
+                                session.hook.pytask_execute_task_teardown(
+                                    session=session, task=task
+                                )
+                            except Exception:  # noqa: BLE001
+                                report = ExecutionReport.from_task_and_exception(
+                                    task, sys.exc_info()
                                 )
-                                running_tasks.pop(task_name)
-                                session.scheduler.done(task_name)
                             else:
-                                task = session.dag.nodes[task_name]["task"]
-
-                                # Update PythonNodes with the values from the future if
-                                # not threads.
-                                if (
-                                    session.config["parallel_backend"]
-                                    != ParallelBackend.THREADS
-                                ):
-                                    task.produces = tree_map(
-                                        _update_python_node,
-                                        task.produces,
-                                        python_nodes,
-                                    )
-
-                                try:
-                                    session.hook.pytask_execute_task_teardown(
-                                        session=session, task=task
-                                    )
-                                except Exception:  # noqa: BLE001
-                                    report = ExecutionReport.from_task_and_exception(
-                                        task, sys.exc_info()
-                                    )
-                                else:
-                                    report = ExecutionReport.from_task(task)
-
-                                running_tasks.pop(task_name)
-                                newly_collected_reports.append(report)
-                                session.scheduler.done(task_name)
-                        else:
-                            pass
+                                report = ExecutionReport.from_task(task)
 
-                    for report in newly_collected_reports:
-                        session.hook.pytask_execute_task_process_report(
-                            session=session, report=report
-                        )
-                        session.hook.pytask_execute_task_log_end(
-                            session=session, task=task, report=report
-                        )
-                        reports.append(report)
+                            running_tasks.pop(task_name)
+                            newly_collected_reports.append(report)
+                            session.scheduler.done(task_name)
+                    else:
+                        pass
 
-                    if session.should_stop:
-                        break
-                    sleeper.sleep()
+                for report in newly_collected_reports:
+                    session.hook.pytask_execute_task_process_report(
+                        session=session, report=report
+                    )
+                    session.hook.pytask_execute_task_log_end(
+                        session=session, task=task, report=report
+                    )
+                    reports.append(report)
 
-                except KeyboardInterrupt:
+                if session.should_stop:
                     break
+                sleeper.sleep()
 
-        return True
-    return None
+            except KeyboardInterrupt:
+                break
+
+    return True
 
 
 def _update_python_node(x: PNode, y: PythonNode | None) -> PNode:
@@ -216,215 +169,6 @@ def _parse_future_exception(
     return None if exc is None else (type(exc), exc, exc.__traceback__)
 
 
-class ProcessesNameSpace:
-    """The name space for hooks related to processes."""
-
-    @staticmethod
-    @hookimpl(tryfirst=True)
-    def pytask_execute_task(session: Session, task: PTask) -> Future[Any] | None:
-        """Execute a task.
-
-        Take a task, pickle it and send the bytes over to another process.
-
-        """
-        if session.config["n_workers"] > 1:
-            kwargs = _create_kwargs_for_task(task)
-
-            # Task modules are dynamically loaded and added to `sys.modules`. Thus,
-            # cloudpickle believes the module of the task function is also importable in
-            # the child process. We have to register the module as dynamic again, so
-            # that cloudpickle will pickle it with the function. See cloudpickle#417,
-            # pytask#373 and pytask#374.
-            task_module = _get_module(task.function, getattr(task, "path", None))
-            cloudpickle.register_pickle_by_value(task_module)
-
-            return session.config["_parallel_executor"].submit(
-                _execute_task,
-                task=task,
-                kwargs=kwargs,
-                show_locals=session.config["show_locals"],
-                console_options=console.options,
-                session_filterwarnings=session.config["filterwarnings"],
-                task_filterwarnings=get_marks(task, "filterwarnings"),
-            )
-        return None
-
-
-def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None:  # noqa: ARG001
-    msg = (
-        "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the "
-        "execution of tasks with pytask-parallel. Please, remove the breakpoint or run "
-        "the task without parallelization to debug it."
-    )
-    raise RuntimeError(msg)
-
-
-def _patch_set_trace_and_breakpoint() -> None:
-    """Patch :func:`pdb.set_trace` and :func:`breakpoint`.
-
-    Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in
-    a subprocess and print a better exception message.
-
-    """
-    import pdb  # noqa: T100
-    import sys
-
-    pdb.set_trace = _raise_exception_on_breakpoint
-    sys.breakpointhook = _raise_exception_on_breakpoint
-
-
-def _execute_task(  # noqa: PLR0913
-    task: PTask,
-    kwargs: dict[str, Any],
-    show_locals: bool,  # noqa: FBT001
-    console_options: ConsoleOptions,
-    session_filterwarnings: tuple[str, ...],
-    task_filterwarnings: tuple[Mark, ...],
-) -> tuple[
-    PyTree[PythonNode | None],
-    list[WarningReport],
-    tuple[type[BaseException], BaseException, str] | None,
-]:
-    """Unserialize and execute task.
-
-    This function receives bytes and unpickles them to a task which is them execute in a
-    spawned process or thread.
-
-    """
-    __tracebackhide__ = True
-    _patch_set_trace_and_breakpoint()
-
-    with warnings.catch_warnings(record=True) as log:
-        for arg in session_filterwarnings:
-            warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
-
-        # apply filters from "filterwarnings" marks
-        for mark in task_filterwarnings:
-            for arg in mark.args:
-                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
-
-        try:
-            out = task.execute(**kwargs)
-        except Exception:  # noqa: BLE001
-            exc_info = sys.exc_info()
-            processed_exc_info = _process_exception(
-                exc_info, show_locals, console_options
-            )
-        else:
-            _handle_task_function_return(task, out)
-            processed_exc_info = None
-
-        task_display_name = getattr(task, "display_name", task.name)
-        warning_reports = []
-        for warning_message in log:
-            fs_location = warning_message.filename, warning_message.lineno
-            warning_reports.append(
-                WarningReport(
-                    message=warning_record_to_str(warning_message),
-                    fs_location=fs_location,
-                    id_=task_display_name,
-                )
-            )
-
-        python_nodes = tree_map(
-            lambda x: x if isinstance(x, PythonNode) else None, task.produces
-        )
-
-        return python_nodes, warning_reports, processed_exc_info
-
-
-def _process_exception(
-    exc_info: tuple[type[BaseException], BaseException, TracebackType | None],
-    show_locals: bool,  # noqa: FBT001
-    console_options: ConsoleOptions,
-) -> tuple[type[BaseException], BaseException, str]:
-    """Process the exception and convert the traceback to a string."""
-    exc_info = remove_internal_traceback_frames_from_exc_info(exc_info)
-    traceback = Traceback.from_exception(*exc_info, show_locals=show_locals)
-    segments = console.render(traceback, options=console_options)
-    text = "".join(segment.text for segment in segments)
-    return (*exc_info[:2], text)
-
-
-def _handle_task_function_return(task: PTask, out: Any) -> None:
-    if "return" not in task.produces:
-        return
-
-    structure_out = tree_structure(out)
-    structure_return = tree_structure(task.produces["return"])
-    # strict must be false when none is leaf.
-    if not structure_return.is_prefix(structure_out, strict=False):
-        msg = (
-            "The structure of the return annotation is not a subtree of "
-            "the structure of the function return.\n\nFunction return: "
-            f"{structure_out}\n\nReturn annotation: {structure_return}"
-        )
-        raise ValueError(msg)
-
-    nodes = tree_leaves(task.produces["return"])
-    values = structure_return.flatten_up_to(out)
-    for node, value in zip(nodes, values):
-        node.save(value)
-
-
-class DefaultBackendNameSpace:
-    """The name space for hooks related to threads."""
-
-    @staticmethod
-    @hookimpl(tryfirst=True)
-    def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
-        """Execute a task.
-
-        Since threads have shared memory, it is not necessary to pickle and unpickle the
-        task.
-
-        """
-        if session.config["n_workers"] > 1:
-            kwargs = _create_kwargs_for_task(task)
-            return session.config["_parallel_executor"].submit(
-                _mock_processes_for_threads, task=task, **kwargs
-            )
-        return None
-
-
-def _mock_processes_for_threads(
-    task: PTask, **kwargs: Any
-) -> tuple[
-    None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
-]:
-    """Mock execution function such that it returns the same as for processes.
-
-    The function for processes returns ``warning_reports`` and an ``exception``. With
-    threads, these object are collected by the main and not the subprocess. So, we just
-    return placeholders.
-
-    """
-    __tracebackhide__ = True
-    try:
-        out = task.function(**kwargs)
-    except Exception:  # noqa: BLE001
-        exc_info = sys.exc_info()
-    else:
-        _handle_task_function_return(task, out)
-        exc_info = None
-    return None, [], exc_info
-
-
-def _create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
-    """Create kwargs for task function."""
-    parameters = inspect.signature(task.function).parameters
-
-    kwargs = {}
-    for name, value in task.depends_on.items():
-        kwargs[name] = tree_map(lambda x: x.load(), value)
-
-    for name, value in task.produces.items():
-        if name in parameters:
-            kwargs[name] = tree_map(lambda x: x.load(), value)
-
-    return kwargs
-
-
 @define(kw_only=True)
 class _Sleeper:
     """A sleeper that always sleeps a bit and up to 1 second if you don't wake it up.
@@ -446,22 +190,3 @@ def increment(self) -> None:
 
     def sleep(self) -> None:
         time.sleep(self.timings[self.timing_idx])
-
-
-def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
-    """Get the module of a python function.
-
-    ``functools.partial`` obfuscates the module of the function and
-    ``inspect.getmodule`` returns :mod`functools`. Therefore, we recover the original
-    function.
-
-    We use the path from the task module to aid the search although it is not clear
-    whether it helps.
-
-    """
-    if isinstance(func, partial):
-        func = func.func
-
-    if path:
-        return inspect.getmodule(func, path.as_posix())
-    return inspect.getmodule(func)
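The rewritten `pytask_execute_build` above is a submit-and-poll loop: fill free worker slots with ready tasks, then sweep the running futures with `future.done()` so one slow task never blocks collecting the others. A toy, self-contained version of the same pattern (the `square` function is hypothetical; the real loop drives pytask's scheduler, hooks, and `_Sleeper` backoff instead of a plain list):

```python
import time
from concurrent.futures import Future, ProcessPoolExecutor


def square(x: int) -> int:
    return x * x


def run_all(inputs: list[int], n_workers: int = 2) -> list[int]:
    results: list[int] = []
    running: dict[int, Future[int]] = {}
    pending = list(inputs)
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        while pending or running:
            # Fill free slots, akin to session.scheduler.get_ready(n_new_tasks).
            while pending and len(running) < n_workers:
                x = pending.pop()
                running[x] = executor.submit(square, x)
            # Collect finished futures without blocking on the rest.
            for key in list(running):
                if running[key].done():
                    results.append(running.pop(key).result())
            time.sleep(0.01)  # crude stand-in for the _Sleeper backoff
    return results


if __name__ == "__main__":
    print(sorted(run_all([1, 2, 3, 4])))  # [1, 4, 9, 16]
```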
diff --git a/src/pytask_parallel/plugin.py b/src/pytask_parallel/plugin.py
index 8ccf789..353cdef 100644
--- a/src/pytask_parallel/plugin.py
+++ b/src/pytask_parallel/plugin.py
@@ -8,7 +8,6 @@
 
 from pytask_parallel import build
 from pytask_parallel import config
-from pytask_parallel import execute
 from pytask_parallel import logging
 
 if TYPE_CHECKING:
@@ -20,5 +19,4 @@ def pytask_add_hooks(pm: PluginManager) -> None:
     """Register plugins."""
     pm.register(build)
     pm.register(config)
-    pm.register(execute)
     pm.register(logging)
diff --git a/src/pytask_parallel/processes.py b/src/pytask_parallel/processes.py
new file mode 100644
index 0000000..c6de65a
--- /dev/null
+++ b/src/pytask_parallel/processes.py
@@ -0,0 +1,180 @@
+"""Contains functions related to processes and loky."""
+
+from __future__ import annotations
+
+import inspect
+import sys
+import warnings
+from functools import partial
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Callable
+
+import cloudpickle
+from pytask import Mark
+from pytask import PTask
+from pytask import PythonNode
+from pytask import Session
+from pytask import WarningReport
+from pytask import console
+from pytask import get_marks
+from pytask import hookimpl
+from pytask import parse_warning_filter
+from pytask import remove_internal_traceback_frames_from_exc_info
+from pytask import warning_record_to_str
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_map
+from rich.traceback import Traceback
+
+from pytask_parallel.utils import create_kwargs_for_task
+from pytask_parallel.utils import handle_task_function_return
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+    from pathlib import Path
+    from types import ModuleType
+    from types import TracebackType
+
+    from rich.console import ConsoleOptions
+
+
+@hookimpl
+def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
+    """Execute a task.
+
+    Take a task, pickle it and send the bytes over to another process.
+
+    """
+    kwargs = create_kwargs_for_task(task)
+
+    # Task modules are dynamically loaded and added to `sys.modules`. Thus, cloudpickle
+    # believes the module of the task function is also importable in the child process.
+    # We have to register the module as dynamic again, so that cloudpickle will pickle
+    # it with the function. See cloudpickle#417, pytask#373 and pytask#374.
+    task_module = _get_module(task.function, getattr(task, "path", None))
+    cloudpickle.register_pickle_by_value(task_module)
+
+    return session.config["_parallel_executor"].submit(
+        _execute_task,
+        task=task,
+        kwargs=kwargs,
+        show_locals=session.config["show_locals"],
+        console_options=console.options,
+        session_filterwarnings=session.config["filterwarnings"],
+        task_filterwarnings=get_marks(task, "filterwarnings"),
+    )
+
+
+def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None:  # noqa: ARG001
+    msg = (
+        "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the "
+        "execution of tasks with pytask-parallel. Please, remove the breakpoint or run "
+        "the task without parallelization to debug it."
+    )
+    raise RuntimeError(msg)
+
+
+def _patch_set_trace_and_breakpoint() -> None:
+    """Patch :func:`pdb.set_trace` and :func:`breakpoint`.
+
+    Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in
+    a subprocess and print a better exception message.
+
+    """
+    import pdb  # noqa: T100
+    import sys
+
+    pdb.set_trace = _raise_exception_on_breakpoint
+    sys.breakpointhook = _raise_exception_on_breakpoint
+
+
+def _execute_task(  # noqa: PLR0913
+    task: PTask,
+    kwargs: dict[str, Any],
+    show_locals: bool,  # noqa: FBT001
+    console_options: ConsoleOptions,
+    session_filterwarnings: tuple[str, ...],
+    task_filterwarnings: tuple[Mark, ...],
+) -> tuple[
+    PyTree[PythonNode | None],
+    list[WarningReport],
+    tuple[type[BaseException], BaseException, str] | None,
+]:
+    """Unserialize and execute task.
+
+    This function receives bytes and unpickles them to a task which is then executed in
+    a spawned process or thread.
+
+    """
+    __tracebackhide__ = True
+    _patch_set_trace_and_breakpoint()
+
+    with warnings.catch_warnings(record=True) as log:
+        for arg in session_filterwarnings:
+            warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        # Apply filters from "filterwarnings" marks.
+        for mark in task_filterwarnings:
+            for arg in mark.args:
+                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        try:
+            out = task.execute(**kwargs)
+        except Exception:  # noqa: BLE001
+            exc_info = sys.exc_info()
+            processed_exc_info = _process_exception(
+                exc_info, show_locals, console_options
+            )
+        else:
+            handle_task_function_return(task, out)
+            processed_exc_info = None
+
+        task_display_name = getattr(task, "display_name", task.name)
+        warning_reports = []
+        for warning_message in log:
+            fs_location = warning_message.filename, warning_message.lineno
+            warning_reports.append(
+                WarningReport(
+                    message=warning_record_to_str(warning_message),
+                    fs_location=fs_location,
+                    id_=task_display_name,
+                )
+            )
+
+        python_nodes = tree_map(
+            lambda x: x if isinstance(x, PythonNode) else None, task.produces
+        )
+
+        return python_nodes, warning_reports, processed_exc_info
+
+
+def _process_exception(
+    exc_info: tuple[type[BaseException], BaseException, TracebackType | None],
+    show_locals: bool,  # noqa: FBT001
+    console_options: ConsoleOptions,
+) -> tuple[type[BaseException], BaseException, str]:
+    """Process the exception and convert the traceback to a string."""
+    exc_info = remove_internal_traceback_frames_from_exc_info(exc_info)
+    traceback = Traceback.from_exception(*exc_info, show_locals=show_locals)
+    segments = console.render(traceback, options=console_options)
+    text = "".join(segment.text for segment in segments)
+    return (*exc_info[:2], text)
+
+
+def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
+    """Get the module of a python function.
+
+    ``functools.partial`` obfuscates the module of the function and
+    ``inspect.getmodule`` returns :mod:`functools`. Therefore, we recover the original
+    function.
+
+    We use the path from the task module to aid the search although it is not clear
+    whether it helps.
+
+    """
+    if isinstance(func, partial):
+        func = func.func
+
+    if path:
+        return inspect.getmodule(func, path.as_posix())
+    return inspect.getmodule(func)
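On the `register_pickle_by_value` comment above: task modules are loaded from file paths the worker process cannot import, so cloudpickle must serialize the module by value instead of by reference. A sketch of the failure mode it prevents (the path and module name are hypothetical):

```python
import importlib.util

import cloudpickle


def load_module_from_path(path: str, name: str):
    """Import a module the way task modules are loaded: from a file path."""
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# module = load_module_from_path("tasks/task_example.py", "task_example")
# cloudpickle.register_pickle_by_value(module)
# # Without the registration, unpickling in the worker raises
# # ModuleNotFoundError because "task_example" is not importable there.
# payload = cloudpickle.dumps(module.task_example)
```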
diff --git a/src/pytask_parallel/threads.py b/src/pytask_parallel/threads.py
new file mode 100644
index 0000000..4c2db4a
--- /dev/null
+++ b/src/pytask_parallel/threads.py
@@ -0,0 +1,55 @@
+"""Contains functions for the threads backend."""
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask import PTask
+from pytask import Session
+from pytask import hookimpl
+
+from pytask_parallel.utils import create_kwargs_for_task
+from pytask_parallel.utils import handle_task_function_return
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+    from types import TracebackType
+
+
+@hookimpl
+def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
+    """Execute a task.
+
+    Since threads have shared memory, it is not necessary to pickle and unpickle the
+    task.
+
+    """
+    kwargs = create_kwargs_for_task(task)
+    return session.config["_parallel_executor"].submit(
+        _mock_processes_for_threads, task=task, **kwargs
+    )
+
+
+def _mock_processes_for_threads(
+    task: PTask, **kwargs: Any
+) -> tuple[
+    None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
+]:
+    """Mock execution function such that it returns the same as for processes.
+
+    The function for processes returns ``warning_reports`` and an ``exception``. With
+    threads, these objects are collected by the main process and not the subprocess.
+    So, we just return placeholders.
+
+    """
+    __tracebackhide__ = True
+    try:
+        out = task.function(**kwargs)
+    except Exception:  # noqa: BLE001
+        exc_info = sys.exc_info()
+    else:
+        handle_task_function_return(task, out)
+        exc_info = None
+    return None, [], exc_info
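Since threads share the interpreter's memory, the callable is submitted as-is and only the return shape is normalized to match the process backend. A stdlib-only illustration of that shape (the `run_in_thread` helper is hypothetical):

```python
import sys
from concurrent.futures import ThreadPoolExecutor


def run_in_thread(func, **kwargs):
    """Return the same (nodes, warnings, exc_info) triple as the process path."""
    try:
        func(**kwargs)
    except Exception:  # mirrors _mock_processes_for_threads
        return None, [], sys.exc_info()
    return None, [], None


with ThreadPoolExecutor(max_workers=2) as executor:
    future = executor.submit(run_in_thread, print, end="")
    assert future.result() == (None, [], None)
```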
diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py
new file mode 100644
index 0000000..4472c5e
--- /dev/null
+++ b/src/pytask_parallel/utils.py
@@ -0,0 +1,52 @@
+"""Contains utility functions."""
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING
+from typing import Any
+
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_leaves
+from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure
+
+if TYPE_CHECKING:
+    from pytask import PTask
+
+
+def handle_task_function_return(task: PTask, out: Any) -> None:
+    """Handle the return value of a task function."""
+    if "return" not in task.produces:
+        return
+
+    structure_out = tree_structure(out)
+    structure_return = tree_structure(task.produces["return"])
+    # `strict` must be False when None is a leaf.
+    if not structure_return.is_prefix(structure_out, strict=False):
+        msg = (
+            "The structure of the return annotation is not a subtree of "
+            "the structure of the function return.\n\nFunction return: "
+            f"{structure_out}\n\nReturn annotation: {structure_return}"
+        )
+        raise ValueError(msg)
+
+    nodes = tree_leaves(task.produces["return"])
+    values = structure_return.flatten_up_to(out)
+    for node, value in zip(nodes, values):
+        node.save(value)
+
+
+def create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
+    """Create kwargs for task function."""
+    parameters = inspect.signature(task.function).parameters
+
+    kwargs = {}
+    for name, value in task.depends_on.items():
+        kwargs[name] = tree_map(lambda x: x.load(), value)
+
+    for name, value in task.produces.items():
+        if name in parameters:
+            kwargs[name] = tree_map(lambda x: x.load(), value)
+
+    return kwargs
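An illustration of the prefix check in `handle_task_function_return` with made-up values (it mirrors the calls above; `pytask.tree_util` treats `None` as a leaf, which is why `strict=False` is required):

```python
from pytask.tree_util import tree_structure

out = {"a": 1, "b": (2, 3)}          # what the task function returned
annotation = {"a": None, "b": None}  # the shape of the "return" annotation

structure_out = tree_structure(out)
structure_return = tree_structure(annotation)

# "b" is a single leaf in the annotation, so the annotation is a non-strict
# prefix of the richer return structure.
assert structure_return.is_prefix(structure_out, strict=False)
# flatten_up_to() yields one value per annotation leaf: [1, (2, 3)].
assert structure_return.flatten_up_to(out) == [1, (2, 3)]
```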