diff --git a/tests/conftest.py b/tests/conftest.py index 0b1465241..f0a7a58b6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ schema_advanced, schema_adapted, schema_external, + schema_uuid as schema_uuid_module, ) @@ -307,6 +308,20 @@ def schema_ext(connection_test, stores_config, enable_filepath_feature): schema.drop() +@pytest.fixture +def schema_uuid(connection_test): + schema = dj.Schema( + PREFIX + "_test1", + context=schema_uuid_module.LOCALS_UUID, + connection=connection_test, + ) + schema(schema_uuid_module.Basic) + schema(schema_uuid_module.Topic) + schema(schema_uuid_module.Item) + yield schema + schema.drop() + + @pytest.fixture(scope="session") def http_client(): # Initialize httpClient with relevant timeout. diff --git a/tests/schema_aggr_regress.py b/tests/schema_aggr_regress.py new file mode 100644 index 000000000..9b85bfffb --- /dev/null +++ b/tests/schema_aggr_regress.py @@ -0,0 +1,51 @@ +import datajoint as dj +import itertools +import inspect + + +class R(dj.Lookup): + definition = """ + r : char(1) + """ + contents = zip("ABCDFGHIJKLMNOPQRST") + + +class Q(dj.Lookup): + definition = """ + -> R + """ + contents = zip("ABCDFGH") + + +class S(dj.Lookup): + definition = """ + -> R + s : int + """ + contents = itertools.product("ABCDF", range(10)) + + +class A(dj.Lookup): + definition = """ + id: int + """ + contents = zip(range(10)) + + +class B(dj.Lookup): + definition = """ + -> A + id2: int + """ + contents = zip(range(5), range(5, 10)) + + +class X(dj.Lookup): + definition = """ + id: int + """ + contents = zip(range(10)) + + +LOCALS_AGGR_REGRESS = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_AGGR_REGRESS) diff --git a/tests/schema_uuid.py b/tests/schema_uuid.py new file mode 100644 index 000000000..6bf994b5b --- /dev/null +++ b/tests/schema_uuid.py @@ -0,0 +1,50 @@ +import uuid +import inspect +import datajoint as dj +from . import PREFIX, CONN_INFO + +top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000") + + +class Basic(dj.Manual): + definition = """ + item : uuid + --- + number : int + """ + + +class Topic(dj.Manual): + definition = """ + # A topic for items + topic_id : uuid # internal identification of a topic, reflects topic name + --- + topic : varchar(8000) # full topic name used to generate the topic id + """ + + def add(self, topic): + """add a new topic with a its UUID""" + self.insert1( + dict(topic_id=uuid.uuid5(top_level_namespace_id, topic), topic=topic) + ) + + +class Item(dj.Computed): + definition = """ + item_id : uuid # internal identification of + --- + -> Topic + word : varchar(8000) + """ + + key_source = Topic # test key source that is not instantiated + + def make(self, key): + for word in ("Habenula", "Hippocampus", "Hypothalamus", "Hypophysis"): + self.insert1( + dict(key, word=word, item_id=uuid.uuid5(key["topic_id"], word)) + ) + + +LOCALS_UUID = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_UUID) diff --git a/tests/test_aggr_regressions.py b/tests/test_aggr_regressions.py new file mode 100644 index 000000000..b4d4e0802 --- /dev/null +++ b/tests/test_aggr_regressions.py @@ -0,0 +1,130 @@ +""" +Regression tests for issues 386, 449, 484, and 558 — all related to processing complex aggregations and projections. +""" + +import pytest +import datajoint as dj +from . import PREFIX +import uuid +from .schema_uuid import Topic, Item, top_level_namespace_id +from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS + + +@pytest.fixture(scope="function") +def schema_aggr_reg(connection_test): + context = LOCALS_AGGR_REGRESS + schema = dj.Schema( + PREFIX + "_aggr_regress", + context=context, + connection=connection_test, + ) + schema(R) + schema(Q) + schema(S) + yield schema + schema.drop() + + +@pytest.fixture(scope="function") +def schema_aggr_reg_with_abx(connection_test): + context = LOCALS_AGGR_REGRESS + schema = dj.Schema( + PREFIX + "_aggr_regress_with_abx", + context=context, + connection=connection_test, + ) + schema(R) + schema(Q) + schema(S) + schema(A) + schema(B) + schema(X) + yield schema + schema.drop() + + +def test_issue386(schema_aggr_reg): + """ + --------------- ISSUE 386 ------------------- + Issue 386 resulted from the loss of aggregated attributes when the aggregation was used as the restrictor + Q & (R.aggr(S, n='count(*)') & 'n=2') + Error: Unknown column 'n' in HAVING + """ + result = R.aggr(S, n="count(*)") & "n=10" + result = Q & result + result.fetch() + + +def test_issue449(schema_aggr_reg): + """ + ---------------- ISSUE 449 ------------------ + Issue 449 arises from incorrect group by attributes after joining with a dj.U() + """ + result = dj.U("n") * R.aggr(S, n="max(s)") + result.fetch() + + +def test_issue484(schema_aggr_reg): + """ + ---------------- ISSUE 484 ----------------- + Issue 484 + """ + q = dj.U().aggr(S, n="max(s)") + n = q.fetch("n") + n = q.fetch1("n") + q = dj.U().aggr(S, n="avg(s)") + result = dj.U().aggr(q, m="max(n)") + result.fetch() + + +def test_union_join(schema_aggr_reg_with_abx): + """ + This test fails if it runs after TestIssue558. + + https://github.com/datajoint/datajoint-python/issues/930 + """ + A.insert(zip([100, 200, 300, 400, 500, 600])) + B.insert([(100, 11), (200, 22), (300, 33), (400, 44)]) + q1 = B & "id < 300" + q2 = B & "id > 300" + + expected_data = [ + {"id": 0, "id2": 5}, + {"id": 1, "id2": 6}, + {"id": 2, "id2": 7}, + {"id": 3, "id2": 8}, + {"id": 4, "id2": 9}, + {"id": 100, "id2": 11}, + {"id": 200, "id2": 22}, + {"id": 400, "id2": 44}, + ] + + assert ((q1 + q2) * A).fetch(as_dict=True) == expected_data + + +class TestIssue558: + """ + --------------- ISSUE 558 ------------------ + Issue 558 resulted from the fact that DataJoint saves subqueries and often combines a restriction followed + by a projection into a single SELECT statement, which in several unusual cases produces unexpected results. + """ + + def test_issue558_part1(self, schema_aggr_reg_with_abx): + q = (A - B).proj(id2="3") + assert len(A - B) == len(q) + + def test_issue558_part2(self, schema_aggr_reg_with_abx): + d = dict(id=3, id2=5) + assert len(X & d) == len((X & d).proj(id2="3")) + + +def test_left_join_len(schema_uuid): + Topic().add("jeff") + Item.populate() + Topic().add("jeff2") + Topic().add("jeff3") + q = Topic.join( + Item - dict(topic_id=uuid.uuid5(top_level_namespace_id, "jeff")), left=True + ) + qf = q.fetch() + assert len(q) == len(qf) diff --git a/tests/test_alter.py b/tests/test_alter.py new file mode 100644 index 000000000..a78a07f26 --- /dev/null +++ b/tests/test_alter.py @@ -0,0 +1,116 @@ +import pytest +import re +import datajoint as dj +from . import schema as schema_any_module, PREFIX + + +class Experiment(dj.Imported): + original_definition = """ # information about experiments + -> Subject + experiment_id :smallint # experiment number for this subject + --- + experiment_date :date # date when experiment was started + -> [nullable] User + data_path="" :varchar(255) # file path to recorded data + notes="" :varchar(2048) # e.g. purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + definition1 = """ # Experiment + -> Subject + experiment_id :smallint # experiment number for this subject + --- + data_path : int # some number + extra=null : longblob # just testing + -> [nullable] User + subject_notes=null :varchar(2048) # {notes} e.g. purpose of experiment + entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp + """ + + +class Parent(dj.Manual): + definition = """ + parent_id: int + """ + + class Child(dj.Part): + definition = """ + -> Parent + """ + definition_new = """ + -> master + --- + child_id=null: int + """ + + class Grandchild(dj.Part): + definition = """ + -> master.Child + """ + definition_new = """ + -> master.Child + --- + grandchild_id=null: int + """ + + +LOCALS_ALTER = {"Experiment": Experiment, "Parent": Parent} +COMBINED_CONTEXT = { + **schema_any_module.LOCALS_ANY, + **LOCALS_ALTER, +} + + +@pytest.fixture +def schema_alter(connection_test, schema_any): + # Overwrite Experiment and Parent nodes + schema_any(Experiment, context=LOCALS_ALTER) + schema_any(Parent, context=LOCALS_ALTER) + yield schema_any + schema_any.drop() + + +class TestAlter: + def test_alter(self, schema_alter): + original = schema_alter.connection.query( + "SHOW CREATE TABLE " + Experiment.full_table_name + ).fetchone()[1] + Experiment.definition = Experiment.definition1 + Experiment.alter(prompt=False, context=COMBINED_CONTEXT) + altered = schema_alter.connection.query( + "SHOW CREATE TABLE " + Experiment.full_table_name + ).fetchone()[1] + assert original != altered + Experiment.definition = Experiment.original_definition + Experiment().alter(prompt=False, context=COMBINED_CONTEXT) + restored = schema_alter.connection.query( + "SHOW CREATE TABLE " + Experiment.full_table_name + ).fetchone()[1] + assert altered != restored + assert original == restored + + def verify_alter(self, schema_alter, table, attribute_sql): + definition_original = schema_alter.connection.query( + f"SHOW CREATE TABLE {table.full_table_name}" + ).fetchone()[1] + table.definition = table.definition_new + table.alter(prompt=False) + definition_new = schema_alter.connection.query( + f"SHOW CREATE TABLE {table.full_table_name}" + ).fetchone()[1] + assert ( + re.sub(f"{attribute_sql},\n ", "", definition_new) == definition_original + ) + + def test_alter_part(self, schema_alter): + """ + https://github.com/datajoint/datajoint-python/issues/936 + """ + self.verify_alter( + schema_alter, table=Parent.Child, attribute_sql="`child_id` .* DEFAULT NULL" + ) + self.verify_alter( + schema_alter, + table=Parent.Grandchild, + attribute_sql="`grandchild_id` .* DEFAULT NULL", + ) diff --git a/tests/test_erd.py b/tests/test_erd.py index aebf62eaf..8a2d1d3ac 100644 --- a/tests/test_erd.py +++ b/tests/test_erd.py @@ -58,7 +58,7 @@ def test_make_image(schema_simp): def test_part_table_parsing(schema_simp): # https://github.com/datajoint/datajoint-python/issues/882 - erd = dj.Di(schema_simp) + erd = dj.Di(schema_simp, context=LOCALS_SIMPLE) graph = erd._make_graph() assert "OutfitLaunch" in graph.nodes() assert "OutfitLaunch.OutfitPiece" in graph.nodes()