diff --git a/tests/conftest.py b/tests/conftest.py
index 9d697ef47..3979efe50 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,7 +11,11 @@
 import json
 from pathlib import Path
 from datajoint import errors
-from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH
+from datajoint.errors import (
+    ADAPTED_TYPE_SWITCH,
+    FILEPATH_FEATURE_SWITCH,
+    DataJointError,
+)
 from . import (
     PREFIX,
     CONN_INFO,
@@ -227,6 +231,10 @@ def schema_any(connection_test):
         PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test
     )
     assert schema.LOCALS_ANY, "LOCALS_ANY is empty"
+    try:
+        schema_any.jobs.delete()
+    except DataJointError:
+        pass
     schema_any(schema.TTest)
     schema_any(schema.TTest2)
     schema_any(schema.TTest3)
@@ -264,6 +272,10 @@ def schema_any(connection_test):
     schema_any(schema.Stimulus)
     schema_any(schema.Longblob)
     yield schema_any
+    try:
+        schema_any.jobs.delete()
+    except DataJointError:
+        pass
     schema_any.drop()
 
 
diff --git a/tests/schema.py b/tests/schema.py
index 5a60b1c0b..81e5ac44c 100644
--- a/tests/schema.py
+++ b/tests/schema.py
@@ -258,7 +258,7 @@ class SimpleSource(dj.Lookup):
     definition = """
     id : int # id
     """
-    contents = ((x,) for x in range(10))
+    contents = [(x,) for x in range(10)]
 
 
 class SigIntTable(dj.Computed):
diff --git a/tests/test_fetch_same.py b/tests/test_fetch_same.py
new file mode 100644
index 000000000..4935bb037
--- /dev/null
+++ b/tests/test_fetch_same.py
@@ -0,0 +1,72 @@
+import pytest
+from . import PREFIX, CONN_INFO
+import numpy as np
+import datajoint as dj
+
+
+class ProjData(dj.Manual):
+    definition = """
+    id : int
+    ---
+    resp : float
+    sim : float
+    big : longblob
+    blah : varchar(10)
+    """
+
+
+@pytest.fixture
+def schema_fetch_same(connection_root):
+    schema = dj.Schema(
+        PREFIX + "_fetch_same",
+        context=dict(ProjData=ProjData),
+        connection=connection_root,
+    )
+    schema(ProjData)
+    ProjData().insert(
+        [
+            {"id": 0, "resp": 20.33, "sim": 45.324, "big": 3, "blah": "yes"},
+            {
+                "id": 1,
+                "resp": 94.3,
+                "sim": 34.23,
+                "big": {"key1": np.random.randn(20, 10)},
+                "blah": "si",
+            },
+            {
+                "id": 2,
+                "resp": 1.90,
+                "sim": 10.23,
+                "big": np.random.randn(4, 2),
+                "blah": "sim",
+            },
+        ]
+    )
+    yield schema
+    schema.drop()
+
+
+@pytest.fixture
+def projdata():
+    yield ProjData()
+
+
+class TestFetchSame:
+    def test_object_conversion_one(self, schema_fetch_same, projdata):
+        new = projdata.proj(sub="resp").fetch("sub")
+        assert new.dtype == np.float64
+
+    def test_object_conversion_two(self, schema_fetch_same, projdata):
+        [sub, add] = projdata.proj(sub="resp", add="sim").fetch("sub", "add")
+        assert sub.dtype == np.float64
+        assert add.dtype == np.float64
+
+    def test_object_conversion_all(self, schema_fetch_same, projdata):
+        new = projdata.proj(sub="resp", add="sim").fetch()
+        assert new["sub"].dtype == np.float64
+        assert new["add"].dtype == np.float64
+
+    def test_object_no_convert(self, schema_fetch_same, projdata):
+        new = projdata.fetch()
+        assert new["big"].dtype == "object"
+        assert new["blah"].dtype == "object"
diff --git a/tests/test_jobs.py b/tests/test_jobs.py
new file mode 100644
index 000000000..37974ac86
--- /dev/null
+++ b/tests/test_jobs.py
@@ -0,0 +1,151 @@
+import pytest
+from . import schema
+from datajoint.jobs import ERROR_MESSAGE_LENGTH, TRUNCATION_APPENDIX
+import random
+import string
+import datajoint as dj
+
+
+@pytest.fixture
+def subjects():
+    yield schema.Subject()
+
+
+def test_reserve_job(schema_any, subjects):
+    assert subjects
+    table_name = "fake_table"
+
+    # reserve jobs
+    for key in subjects.fetch("KEY"):
+        assert schema_any.jobs.reserve(table_name, key), "failed to reserve a job"
+
+    # refuse jobs
+    for key in subjects.fetch("KEY"):
+        assert not schema_any.jobs.reserve(
+            table_name, key
+        ), "failed to respect reservation"
+
+    # complete jobs
+    for key in subjects.fetch("KEY"):
+        schema_any.jobs.complete(table_name, key)
+    assert not schema_any.jobs, "failed to free jobs"
+
+    # reserve jobs again
+    for key in subjects.fetch("KEY"):
+        assert schema_any.jobs.reserve(table_name, key), "failed to reserve new jobs"
+
+    # finish with error
+    for key in subjects.fetch("KEY"):
+        schema_any.jobs.error(table_name, key, "error message")
+
+    # refuse jobs with errors
+    for key in subjects.fetch("KEY"):
+        assert not schema_any.jobs.reserve(
+            table_name, key
+        ), "failed to ignore error jobs"
+
+    # clear error jobs
+    (schema_any.jobs & dict(status="error")).delete()
+    assert not schema_any.jobs, "failed to clear error jobs"
+
+
+def test_restrictions(schema_any):
+    jobs = schema_any.jobs
+    jobs.delete()
+    jobs.reserve("a", {"key": "a1"})
+    jobs.reserve("a", {"key": "a2"})
+    jobs.reserve("b", {"key": "b1"})
+    jobs.error("a", {"key": "a2"}, "error")
+    jobs.error("b", {"key": "b1"}, "error")
+
+    assert len(jobs & {"table_name": "a"}) == 2
+    assert len(jobs & {"status": "error"}) == 2
+    assert len(jobs & {"table_name": "a", "status": "error"}) == 1
+    jobs.delete()
+
+
+def test_sigint(schema_any):
+    try:
+        schema.SigIntTable().populate(reserve_jobs=True)
+    except KeyboardInterrupt:
+        pass
+
+    assert len(schema_any.jobs.fetch()), "SigInt jobs table is empty"
+    status, error_message = schema_any.jobs.fetch1("status", "error_message")
+    assert status == "error"
+    assert error_message == "KeyboardInterrupt"
+
+
+def test_sigterm(schema_any):
+    try:
+        schema.SigTermTable().populate(reserve_jobs=True)
+    except SystemExit:
+        pass
+
+    assert len(schema_any.jobs.fetch()), "SigTerm jobs table is empty"
+    status, error_message = schema_any.jobs.fetch1("status", "error_message")
+    assert status == "error"
+    assert error_message == "SystemExit: SIGTERM received"
+
+
+def test_suppress_dj_errors(schema_any):
+    """test_suppress_dj_errors: dj errors suppressible w/o native py blobs"""
+    with dj.config(enable_python_native_blobs=False):
+        schema.ErrorClass.populate(reserve_jobs=True, suppress_errors=True)
+    assert len(schema.DjExceptionName()) == len(schema_any.jobs) > 0
+
+
+def test_long_error_message(schema_any, subjects):
+    # create long error message
+    long_error_message = "".join(
+        random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH + 100)
+    )
+    short_error_message = "".join(
+        random.choice(string.ascii_letters) for _ in range(ERROR_MESSAGE_LENGTH // 2)
+    )
+    assert subjects
+    table_name = "fake_table"
+
+    key = subjects.fetch("KEY")[0]
+
+    # test long error message
+    schema_any.jobs.reserve(table_name, key)
+    schema_any.jobs.error(table_name, key, long_error_message)
+    error_message = schema_any.jobs.fetch1("error_message")
+    assert (
+        len(error_message) == ERROR_MESSAGE_LENGTH
+    ), "error message is longer than max allowed"
+    assert error_message.endswith(
+        TRUNCATION_APPENDIX
+    ), "appropriate ending missing for truncated error message"
+    schema_any.jobs.delete()
+
+    # test short error message
+    schema_any.jobs.reserve(table_name, key)
+    schema_any.jobs.error(table_name, key, short_error_message)
+    error_message = schema_any.jobs.fetch1("error_message")
+    assert error_message == short_error_message, "error messages do not agree"
+    assert not error_message.endswith(
+        TRUNCATION_APPENDIX
+    ), "error message should not be truncated"
+    schema_any.jobs.delete()
+
+
+def test_long_error_stack(schema_any, subjects):
+    # create long error stack
+    STACK_SIZE = (
+        89942  # Does not fit into small blob (should be 64k, but found to be higher)
+    )
+    long_error_stack = "".join(
+        random.choice(string.ascii_letters) for _ in range(STACK_SIZE)
+    )
+    assert subjects
+    table_name = "fake_table"
+
+    key = subjects.fetch("KEY")[0]
+
+    # test long error stack
+    schema_any.jobs.reserve(table_name, key)
+    schema_any.jobs.error(table_name, key, "error message", long_error_stack)
+    error_stack = schema_any.jobs.fetch1("error_stack")
+    assert error_stack == long_error_stack, "error stacks do not agree"