Skip to content

Make pytask even lazier. #496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
when a product annotation is used with the argument name `produces`. And, allow
`produces` to intake any node.
- {pull}`490` refactors and better tests parsing of dependencies.
- {pull}`496` makes pytask even lazier. Now, when a task produces a node whose hash
remains the same, the subsequent tasks are not executed. The previous behavior of
re-running them remained from when pytask relied on timestamps.

## 0.4.2 - 2023-11-8
## 0.4.2 - 2023-11-08

- {pull}`449` simplifies the code building the plugin manager.
- {pull}`451` improves `collect_command.py` and renames `graph.py` to `dag_command.py`.
Expand Down
1 change: 0 additions & 1 deletion docs/source/reference_guides/hookspecs.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ your plugin.
```{eval-rst}
.. autofunction:: pytask_dag
.. autofunction:: pytask_dag_create_dag
.. autofunction:: pytask_dag_select_execution_dag
.. autofunction:: pytask_dag_log

```
Expand Down
10 changes: 9 additions & 1 deletion src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from _pytask.console import is_jupyter
from _pytask.exceptions import CollectionError
from _pytask.mark import MarkGenerator
from _pytask.mark_utils import get_all_marks
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PPathNode
Expand Down Expand Up @@ -246,6 +247,13 @@ def pytask_collect_task(

"""
if (name.startswith("task_") or has_mark(obj, "task")) and is_task_function(obj):
if has_mark(obj, "try_first") and has_mark(obj, "try_last"):
msg = (
"The task cannot have mixed priorities. Do not apply "
"'@pytask.mark.try_first' and '@pytask.mark.try_last' at the same time."
)
raise ValueError(msg)

path_nodes = Path.cwd() if path is None else path.parent
dependencies = parse_dependencies_from_task_function(
session, path, name, path_nodes, obj
Expand All @@ -254,7 +262,7 @@ def pytask_collect_task(
session, path, name, path_nodes, obj
)

markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
markers = get_all_marks(obj)

# Get the underlying function to avoid having different states of the function,
# e.g. due to pytask_meta, in different layers of the wrapping.
Expand Down
61 changes: 0 additions & 61 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@
from _pytask.console import format_task_name
from _pytask.console import render_to_string
from _pytask.console import TASK_ICON
from _pytask.dag_utils import node_and_neighbors
from _pytask.dag_utils import task_and_descending_tasks
from _pytask.dag_utils import TopologicalSorter
from _pytask.database_utils import DatabaseSession
from _pytask.database_utils import State
from _pytask.exceptions import ResolvingDependenciesError
from _pytask.mark import Mark
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PTask
from _pytask.nodes import PythonNode
Expand All @@ -31,7 +25,6 @@
from rich.tree import Tree

if TYPE_CHECKING:
from _pytask.node_protocols import MetaNode
from pathlib import Path
from _pytask.session import Session

Expand All @@ -44,7 +37,6 @@ def pytask_dag(session: Session) -> bool | None:
session=session, tasks=session.tasks
)
session.hook.pytask_dag_modify_dag(session=session, dag=session.dag)
session.hook.pytask_dag_select_execution_dag(session=session, dag=session.dag)

except Exception: # noqa: BLE001
report = DagReport.from_exception(sys.exc_info())
Expand Down Expand Up @@ -101,59 +93,6 @@ def _add_product(dag: nx.DiGraph, task: PTask, node: PNode) -> None:
return dag


@hookimpl
def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
    """Flag the tasks of the DAG which do not have to be executed.

    The tasks are visited in the order given by the topological sorter. A task whose
    own state or whose neighboring nodes have changed is kept for execution together
    with all tasks descending from it. Every other task receives a ``skip_unchanged``
    marker.

    """
    sorter = TopologicalSorter.from_dag(dag)
    # Signatures of tasks that are already known to require execution.
    marked_for_execution: set[str] = set()

    while sorter.is_active():
        signature = sorter.get_ready()[0]

        if signature not in marked_for_execution:
            task = dag.nodes[signature]["task"]
            if _have_task_or_neighbors_changed(session, dag, task):
                # A change propagates to the task itself and everything downstream.
                downstream = task_and_descending_tasks(signature, dag)
                marked_for_execution.update(downstream)
            else:
                unchanged_marker = Mark("skip_unchanged", (), {})
                dag.nodes[signature]["task"].markers.append(unchanged_marker)

        sorter.done(signature)


def _have_task_or_neighbors_changed(
    session: Session, dag: nx.DiGraph, task: PTask
) -> bool:
    """Tell whether the task or any of its dependencies or products has changed.

    Each node returned by ``node_and_neighbors`` is passed to the
    ``pytask_dag_has_node_changed`` hook. The search stops at the first node for which
    the hook reports a change.

    """
    for name in node_and_neighbors(dag, task.signature):
        attributes = dag.nodes[name]
        # The entry holds a task under "task"; otherwise fall back to "node".
        candidate = attributes.get("task") or attributes.get("node")
        if session.hook.pytask_dag_has_node_changed(
            session=session, dag=dag, task=task, node=candidate
        ):
            return True
    return False


@hookimpl(trylast=True)
def pytask_dag_has_node_changed(task: PTask, node: MetaNode) -> bool:
    """Tell whether ``node`` differs from the state stored for ``task``.

    A node counts as changed when it has no current state, when the database holds no
    entry for the pair of task and node signatures, or when the stored hash is
    different from the current state.

    """
    current_state = node.state()
    # ``state()`` yields ``None`` for a node which does not exist.
    if current_state is None:
        return True

    with DatabaseSession() as session:
        stored = session.get(State, (task.signature, node.signature))

    # Without an entry in the database there is nothing to compare against.
    return True if stored is None else current_state != stored.hash_


def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
"""Check if DAG has cycles."""
try:
Expand Down
27 changes: 4 additions & 23 deletions src/_pytask/dag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from typing import TYPE_CHECKING

import networkx as nx
from _pytask.console import format_strings_as_flat_tree
from _pytask.console import format_task_name
from _pytask.console import TASK_ICON
from _pytask.mark_utils import has_mark
from attrs import define
from attrs import field
Expand Down Expand Up @@ -54,8 +51,11 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a
DAG.

The task node needs to be yielded in the middle so that the predecessors are checked
first and then the rest of the nodes.

"""
return itertools.chain([node], dag.predecessors(node), dag.successors(node))
return itertools.chain(dag.predecessors(node), [node], dag.successors(node))


@define
Expand Down Expand Up @@ -166,25 +166,6 @@ def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
}
for task in tasks
}
tasks_w_mixed_priorities = [
name for name, p in priorities.items() if p["try_first"] and p["try_last"]
]

if tasks_w_mixed_priorities:
name_to_task = {task.signature: task for task in tasks}
reduced_names = []
for name in tasks_w_mixed_priorities:
reduced_name = format_task_name(name_to_task[name], "no_link")
reduced_names.append(reduced_name.plain)

text = format_strings_as_flat_tree(
reduced_names, "Tasks with mixed priorities", TASK_ICON
)
msg = (
f"'try_first' and 'try_last' cannot be applied on the same task. See the "
f"following tasks for errors:\n\n{text}"
)
raise ValueError(msg)

# Recode to numeric values for sorting.
numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}
Expand Down
19 changes: 19 additions & 0 deletions src/_pytask/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from sqlalchemy.orm import sessionmaker

if TYPE_CHECKING:
from _pytask.node_protocols import MetaNode
from _pytask.node_protocols import PTask
from _pytask.session import Session


Expand Down Expand Up @@ -62,3 +64,20 @@ def update_states_in_database(session: Session, task_signature: str) -> None:
node = session.dag.nodes[name].get("task") or session.dag.nodes[name]["node"]
hash_ = node.state()
_create_or_update_state(task_signature, node.signature, hash_)


def has_node_changed(task: PTask, node: MetaNode) -> bool:
    """Report whether a dependency or product differs from its stored state.

    The current state of ``node`` is compared with the hash saved in the database under
    the signatures of ``task`` and ``node``. A node without a state or without a
    database entry is reported as changed.

    """
    state_now = node.state()

    # A state of ``None`` means the node does not exist; skip the database lookup.
    if state_now is not None:
        with DatabaseSession() as session:
            record = session.get(State, (task.signature, node.signature))

        if record is not None:
            return state_now != record.hash_

    return True
49 changes: 33 additions & 16 deletions src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from _pytask.console import format_strings_as_flat_tree
from _pytask.console import unify_styles
from _pytask.dag_utils import descending_tasks
from _pytask.dag_utils import node_and_neighbors
from _pytask.dag_utils import TopologicalSorter
from _pytask.database_utils import has_node_changed
from _pytask.database_utils import update_states_in_database
from _pytask.exceptions import ExecutionError
from _pytask.exceptions import NodeLoadError
Expand All @@ -28,6 +30,7 @@
from _pytask.node_protocols import PTask
from _pytask.outcomes import count_outcomes
from _pytask.outcomes import Exit
from _pytask.outcomes import SkippedUnchanged
from _pytask.outcomes import TaskOutcome
from _pytask.outcomes import WouldBeExecuted
from _pytask.reports import ExecutionReport
Expand Down Expand Up @@ -124,28 +127,42 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
2. Create the directory where the product will be placed.

"""
for dependency in session.dag.predecessors(task.signature):
node = session.dag.nodes[dependency]["node"]
if not node.state():
msg = f"{task.name!r} requires missing node {node.name!r}."
if IS_FILE_SYSTEM_CASE_SENSITIVE:
msg += (
"\n\n(Hint: Your file-system is case-sensitive. Check the paths' "
"capitalization carefully.)"
)
raise NodeNotFoundError(msg)
if has_mark(task, "would_be_executed"):
raise WouldBeExecuted

dag = session.dag

needs_to_be_executed = session.config["force"]
if not needs_to_be_executed:
predecessors = set(dag.predecessors(task.signature)) | {task.signature}
for node_signature in node_and_neighbors(dag, task.signature):
node = dag.nodes[node_signature].get("task") or dag.nodes[
node_signature
].get("node")
if node_signature in predecessors and not node.state():
msg = f"{task.name!r} requires missing node {node.name!r}."
if IS_FILE_SYSTEM_CASE_SENSITIVE:
msg += (
"\n\n(Hint: Your file-system is case-sensitive. Check the "
"paths' capitalization carefully.)"
)
raise NodeNotFoundError(msg)

has_changed = has_node_changed(task=task, node=node)
if has_changed:
needs_to_be_executed = True
break

if not needs_to_be_executed:
raise SkippedUnchanged

# Create directory for product if it does not exist. Maybe this should be a `setup`
# method for the node classes.
for product in session.dag.successors(task.signature):
node = session.dag.nodes[product]["node"]
for product in dag.successors(task.signature):
node = dag.nodes[product]["node"]
if isinstance(node, PPathNode):
node.path.parent.mkdir(parents=True, exist_ok=True)

would_be_executed = has_mark(task, "would_be_executed")
if would_be_executed:
raise WouldBeExecuted


def _safe_load(node: PNode, task: PTask, is_product: bool) -> Any:
try:
Expand Down
23 changes: 0 additions & 23 deletions src/_pytask/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


if TYPE_CHECKING:
from _pytask.node_protocols import MetaNode
from _pytask.models import NodeInfo
from _pytask.node_protocols import PNode
import click
Expand Down Expand Up @@ -245,28 +244,6 @@ def pytask_dag_modify_dag(session: Session, dag: nx.DiGraph) -> None:
"""


@hookspec
def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
    """Select the subgraph which needs to be executed.

    This hook determines which of the tasks have to be re-run because something has
    changed.

    The hook receives the current ``session`` and the ``dag`` and is annotated to
    return nothing.

    """


@hookspec(firstresult=True)
def pytask_dag_has_node_changed(
    session: Session, dag: nx.DiGraph, task: PTask, node: MetaNode
) -> bool:
    """Indicate whether a single dependency or product of a task has changed.

    The hook is called with one ``task`` and one ``node`` at a time and answers whether
    the node has changed with respect to that task.

    The specification is declared with ``firstresult=True``, so the first
    implementation which returns a value other than ``None`` decides the result.

    """


@hookspec
def pytask_dag_log(session: Session, report: DagReport) -> None:
"""Log errors during resolving dependencies."""
Expand Down
3 changes: 2 additions & 1 deletion src/_pytask/mark/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable
from typing import Mapping

from _pytask.mark_utils import get_all_marks
from _pytask.models import CollectionMetadata
from _pytask.typing import is_task_function
from attrs import define
Expand Down Expand Up @@ -122,7 +123,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> MarkDecorator:

def get_unpacked_marks(obj: Callable[..., Any]) -> list[Mark]:
"""Obtain the unpacked marks that are stored on an object."""
mark_list = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
mark_list = get_all_marks(obj)
return normalize_mark_list(mark_list)


Expand Down
12 changes: 11 additions & 1 deletion src/_pytask/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from _pytask.config import hookimpl
from _pytask.dag_utils import node_and_neighbors
from _pytask.database_utils import has_node_changed
from _pytask.database_utils import update_states_in_database
from _pytask.mark_utils import has_mark
from _pytask.outcomes import Persisted
Expand Down Expand Up @@ -46,7 +47,16 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
)

if all_nodes_exist:
raise Persisted
any_node_changed = any(
has_node_changed(
task=task,
node=session.dag.nodes[name].get("task")
or session.dag.nodes[name]["node"],
)
for name in node_and_neighbors(session.dag, task.signature)
)
if any_node_changed:
raise Persisted


@hookimpl
Expand Down
18 changes: 18 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,21 @@ def task_example() -> Annotated[int, 1]: ...
result = runner.invoke(cli, [tmp_path.as_posix()])
assert result.exit_code == ExitCode.COLLECTION_FAILED
assert "The return annotation of the task" in result.output


@pytest.mark.end_to_end()
def test_scheduling_w_mixed_priorities(runner, tmp_path):
    """Check that a task with both priority markers fails during collection."""
    # The task module applies ``try_last`` and ``try_first`` to the same task.
    source = """
    import pytask

    @pytask.mark.try_last
    @pytask.mark.try_first
    def task_mixed(): pass
    """
    tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

    result = runner.invoke(cli, [tmp_path.as_posix()])

    # The run has to stop with a collection failure which carries the error message.
    assert result.exit_code == ExitCode.COLLECTION_FAILED
    assert "Could not collect" in result.output
    assert "The task cannot have" in result.output
Loading