diff --git a/docs/source/changes.md b/docs/source/changes.md index 8cfff71f..fc439ba7 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -51,6 +51,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and in pytask-parallel. - {pull}`611` removes the initial task execution status from `pytask_execute_task_log_start`. +- {pull}`612` adds validation for data catalog names. ## 0.4.7 - 2024-03-19 diff --git a/src/_pytask/data_catalog.py b/src/_pytask/data_catalog.py index 615fc021..d2993e60 100644 --- a/src/_pytask/data_catalog.py +++ b/src/_pytask/data_catalog.py @@ -9,6 +9,7 @@ import hashlib import inspect import pickle +import re from pathlib import Path from typing import Any @@ -61,13 +62,26 @@ class DataCatalog: default_node: type[PNode] = PickleNode entries: dict[str, PNode | PProvisionalNode] = field(factory=dict) - name: str = "default" + name: str = field(default="default") path: Path | None = None _session_config: dict[str, Any] = field( factory=lambda *x: {"check_casing_of_paths": True} # noqa: ARG005 ) _instance_path: Path = field(factory=_get_parent_path_of_data_catalog_module) + @name.validator + def _check(self, attribute: str, value: str) -> None: # noqa: ARG002 + _rich_traceback_omit = True + if not isinstance(value, str): + msg = "The name of a data catalog must be a string." + raise TypeError(msg) + if not re.match(r"[a-zA-Z0-9-_]+", value): + msg = ( + "The name of a data catalog must be a string containing only letters, " + "numbers, hyphens, and underscores." + ) + raise ValueError(msg) + def __attrs_post_init__(self) -> None: root_path, _ = find_project_root_and_config((self._instance_path,)) self._session_config["paths"] = (root_path,) diff --git a/tests/test_data_catalog.py b/tests/test_data_catalog.py index d6050cfe..28b7fdf5 100644 --- a/tests/test_data_catalog.py +++ b/tests/test_data_catalog.py @@ -234,3 +234,17 @@ def task_add_content( result = runner.invoke(cli, [tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK assert tmp_path.joinpath("output.txt").read_text() == "Hello, World!" + + +@pytest.mark.end_to_end() +def test_data_catalog_has_invalid_name(runner, tmp_path): + source = """ + from pytask import DataCatalog + + data_catalog = DataCatalog(name="?1") + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.COLLECTION_FAILED + assert "The name of a data catalog" in result.stdout