diff --git a/CHANGELOG.md b/CHANGELOG.md index 79ae622b4..dba230f1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### 0.14.3 -- TBD - Fixed - Added encapsulating double quotes to comply with [DOT language](https://graphviz.org/doc/info/lang.html) - PR [#1177](https://github.com/datajoint/datajoint-python/pull/1177) +- Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095) ### 0.14.2 -- Aug 19, 2024 - Added - Migrate nosetests to pytest - PR [#1142](https://github.com/datajoint/datajoint-python/pull/1142) diff --git a/datajoint/__init__.py b/datajoint/__init__.py index 096def463..59fac9e28 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -51,6 +51,7 @@ "key", "key_hash", "logger", + "cli", ] from .logging import logger @@ -70,6 +71,7 @@ from .attribute_adapter import AttributeAdapter from . import errors from .errors import DataJointError +from .cli import cli ERD = Di = Diagram # Aliases for Diagram schema = Schema # Aliases for Schema diff --git a/datajoint/cli.py b/datajoint/cli.py new file mode 100644 index 000000000..e2432fe6e --- /dev/null +++ b/datajoint/cli.py @@ -0,0 +1,77 @@ +import argparse +from code import interact +from collections import ChainMap +import datajoint as dj + + +def cli(args: list = None): + """ + Console interface for DataJoint Python + + :param args: List of arguments to be passed in, defaults to reading stdin + :type args: list, optional + """ + parser = argparse.ArgumentParser( + prog="datajoint", + description="DataJoint console interface.", + conflict_handler="resolve", + ) + parser.add_argument( + "-V", "--version", action="version", version=f"{dj.__name__} {dj.__version__}" + ) + parser.add_argument( + "-u", + "--user", + type=str, + default=dj.config["database.user"], + required=False, + help="Datajoint username", + ) + parser.add_argument( + "-p", + "--password", + type=str, + default=dj.config["database.password"], + required=False, + help="Datajoint password", + ) + parser.add_argument( + "-h", + "--host", + type=str, + default=dj.config["database.host"], + required=False, + help="Datajoint host", + ) + parser.add_argument( + "-s", + "--schemas", + nargs="+", + type=str, + required=False, + help="A list of virtual module mappings in `db:schema ...` format", + ) + kwargs = vars(parser.parse_args(args)) + mods = {} + if kwargs["user"]: + dj.config["database.user"] = kwargs["user"] + if kwargs["password"]: + dj.config["database.password"] = kwargs["password"] + if kwargs["host"]: + dj.config["database.host"] = kwargs["host"] + if kwargs["schemas"]: + for vm in kwargs["schemas"]: + d, m = vm.split(":") + mods[m] = dj.create_virtual_module(m, d) + + banner = "dj repl\n" + if mods: + modstr = "\n".join(" - {}".format(m) for m in mods) + banner += "\nschema modules:\n\n" + modstr + "\n" + interact(banner, local=dict(ChainMap(mods, locals(), globals()))) + + raise SystemExit + + +if __name__ == "__main__": + cli() diff --git a/setup.py b/setup.py index 904260681..e280038ce 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,9 @@ "automated research workflows", ], packages=find_packages(exclude=["contrib", "docs", "tests*"]), + entry_points={ + "console_scripts": ["dj=datajoint.cli:cli", "datajoint=datajoint.cli:cli"], + }, install_requires=requirements, python_requires="~={}.{}".format(*min_py_version), setup_requires=["otumat"], # maybe remove due to conflicts? diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 000000000..41459ebc2 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,140 @@ +""" +Collection of test cases to test the dj cli +""" + +import json +import subprocess +import pytest +import datajoint as dj +from . import CONN_INFO_ROOT, PREFIX + + +def test_cli_version(capsys): + with pytest.raises(SystemExit) as pytest_wrapped_e: + dj.cli(args=["-V"]) + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + + captured_output = capsys.readouterr().out + assert captured_output == f"{dj.__name__} {dj.__version__}\n" + + +def test_cli_help(capsys): + with pytest.raises(SystemExit) as pytest_wrapped_e: + dj.cli(args=["--help"]) + assert pytest_wrapped_e.type == SystemExit + assert pytest_wrapped_e.value.code == 0 + + captured_output = capsys.readouterr().out + + assert ( + "\ +usage: datajoint [--help] [-V] [-u USER] [-p PASSWORD] [-h HOST]\n\ + [-s SCHEMAS [SCHEMAS ...]]\n\n\ +\ +DataJoint console interface.\n\n\ +\ +optional arguments:\n\ + --help show this help message and exit\n\ + -V, --version show program's version number and exit\n\ + -u USER, --user USER Datajoint username\n\ + -p PASSWORD, --password PASSWORD\n\ + Datajoint password\n\ + -h HOST, --host HOST Datajoint host\n\ + -s SCHEMAS [SCHEMAS ...], --schemas SCHEMAS [SCHEMAS ...]\n\ + A list of virtual module mappings in `db:schema ...`\n\ + format\n" + == captured_output + ) + + +def test_cli_config(): + process = subprocess.Popen( + ["dj"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + process.stdin.write("dj.config\n") + process.stdin.flush() + + stdout, stderr = process.communicate() + + assert dj.config == json.loads( + stdout[4:519] + .replace("'", '"') + .replace("None", "null") + .replace("True", "true") + .replace("False", "false") + ) + + +def test_cli_args(): + process = subprocess.Popen( + ["dj", "-utest_user", "-ptest_pass", "-htest_host"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + process.stdin.write("dj.config['database.user']\n") + process.stdin.write("dj.config['database.password']\n") + process.stdin.write("dj.config['database.host']\n") + process.stdin.flush() + + stdout, stderr = process.communicate() + assert "test_user" == stdout[5:14] + assert "test_pass" == stdout[21:30] + assert "test_host" == stdout[37:46] + + +def test_cli_schemas(): + schema = dj.Schema(PREFIX + "_cli", locals(), connection=dj.conn(**CONN_INFO_ROOT)) + + @schema + class IJ(dj.Lookup): + definition = """ # tests restrictions + i : int + j : int + """ + contents = list(dict(i=i, j=j + 2) for i in range(3) for j in range(3)) + + process = subprocess.Popen( + ["dj", "-s", "djtest_cli:test_schema"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + process.stdin.write("test_schema.__dict__['__name__']\n") + process.stdin.write("test_schema.__dict__['schema']\n") + process.stdin.write("test_schema.IJ.fetch(as_dict=True)\n") + process.stdin.flush() + + stdout, stderr = process.communicate() + fetch_res = [ + {"i": 0, "j": 2}, + {"i": 0, "j": 3}, + {"i": 0, "j": 4}, + {"i": 1, "j": 2}, + {"i": 1, "j": 3}, + {"i": 1, "j": 4}, + {"i": 2, "j": 2}, + {"i": 2, "j": 3}, + {"i": 2, "j": 4}, + ] + assert ( + "\ +dj repl\n\n\ +\ +schema modules:\n\n\ + - test_schema" + == stderr[159:200] + ) + assert "'test_schema'" == stdout[4:17] + assert "Schema `djtest_cli`" == stdout[22:41] + assert fetch_res == json.loads(stdout[47:209].replace("'", '"'))