Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
schema_advanced,
schema_adapted,
schema_external,
schema_uuid as schema_uuid_module,
)


Expand Down Expand Up @@ -307,6 +308,20 @@ def schema_ext(connection_test, stores_config, enable_filepath_feature):
schema.drop()


@pytest.fixture
def schema_uuid(connection_test):
schema = dj.Schema(
PREFIX + "_test1",
context=schema_uuid_module.LOCALS_UUID,
connection=connection_test,
)
schema(schema_uuid_module.Basic)
schema(schema_uuid_module.Topic)
schema(schema_uuid_module.Item)
yield schema
schema.drop()


@pytest.fixture(scope="session")
def http_client():
# Initialize httpClient with relevant timeout.
Expand Down
51 changes: 51 additions & 0 deletions tests/schema_aggr_regress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import datajoint as dj
import itertools
import inspect


class R(dj.Lookup):
definition = """
r : char(1)
"""
contents = zip("ABCDFGHIJKLMNOPQRST")


class Q(dj.Lookup):
definition = """
-> R
"""
contents = zip("ABCDFGH")


class S(dj.Lookup):
definition = """
-> R
s : int
"""
contents = itertools.product("ABCDF", range(10))


class A(dj.Lookup):
definition = """
id: int
"""
contents = zip(range(10))


class B(dj.Lookup):
definition = """
-> A
id2: int
"""
contents = zip(range(5), range(5, 10))


class X(dj.Lookup):
definition = """
id: int
"""
contents = zip(range(10))


LOCALS_AGGR_REGRESS = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_AGGR_REGRESS)
50 changes: 50 additions & 0 deletions tests/schema_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import uuid
import inspect
import datajoint as dj
from . import PREFIX, CONN_INFO

top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000")


class Basic(dj.Manual):
definition = """
item : uuid
---
number : int
"""


class Topic(dj.Manual):
definition = """
# A topic for items
topic_id : uuid # internal identification of a topic, reflects topic name
---
topic : varchar(8000) # full topic name used to generate the topic id
"""

def add(self, topic):
"""add a new topic with a its UUID"""
self.insert1(
dict(topic_id=uuid.uuid5(top_level_namespace_id, topic), topic=topic)
)


class Item(dj.Computed):
definition = """
item_id : uuid # internal identification of
---
-> Topic
word : varchar(8000)
"""

key_source = Topic # test key source that is not instantiated

def make(self, key):
for word in ("Habenula", "Hippocampus", "Hypothalamus", "Hypophysis"):
self.insert1(
dict(key, word=word, item_id=uuid.uuid5(key["topic_id"], word))
)


LOCALS_UUID = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_UUID)
130 changes: 130 additions & 0 deletions tests/test_aggr_regressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Regression tests for issues 386, 449, 484, and 558 — all related to processing complex aggregations and projections.
"""

import pytest
import datajoint as dj
from . import PREFIX
import uuid
from .schema_uuid import Topic, Item, top_level_namespace_id
from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS


@pytest.fixture(scope="function")
def schema_aggr_reg(connection_test):
context = LOCALS_AGGR_REGRESS
schema = dj.Schema(
PREFIX + "_aggr_regress",
context=context,
connection=connection_test,
)
schema(R)
schema(Q)
schema(S)
yield schema
schema.drop()


@pytest.fixture(scope="function")
def schema_aggr_reg_with_abx(connection_test):
context = LOCALS_AGGR_REGRESS
schema = dj.Schema(
PREFIX + "_aggr_regress_with_abx",
context=context,
connection=connection_test,
)
schema(R)
schema(Q)
schema(S)
schema(A)
schema(B)
schema(X)
yield schema
schema.drop()


def test_issue386(schema_aggr_reg):
"""
--------------- ISSUE 386 -------------------
Issue 386 resulted from the loss of aggregated attributes when the aggregation was used as the restrictor
Q & (R.aggr(S, n='count(*)') & 'n=2')
Error: Unknown column 'n' in HAVING
"""
result = R.aggr(S, n="count(*)") & "n=10"
result = Q & result
result.fetch()


def test_issue449(schema_aggr_reg):
"""
---------------- ISSUE 449 ------------------
Issue 449 arises from incorrect group by attributes after joining with a dj.U()
"""
result = dj.U("n") * R.aggr(S, n="max(s)")
result.fetch()


def test_issue484(schema_aggr_reg):
"""
---------------- ISSUE 484 -----------------
Issue 484
"""
q = dj.U().aggr(S, n="max(s)")
n = q.fetch("n")
n = q.fetch1("n")
q = dj.U().aggr(S, n="avg(s)")
result = dj.U().aggr(q, m="max(n)")
result.fetch()


def test_union_join(schema_aggr_reg_with_abx):
"""
This test fails if it runs after TestIssue558.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Document this as a known GitHub Issue after we merge dev-tests to trunk.


https://github.com/datajoint/datajoint-python/issues/930
"""
A.insert(zip([100, 200, 300, 400, 500, 600]))
B.insert([(100, 11), (200, 22), (300, 33), (400, 44)])
q1 = B & "id < 300"
q2 = B & "id > 300"

expected_data = [
{"id": 0, "id2": 5},
{"id": 1, "id2": 6},
{"id": 2, "id2": 7},
{"id": 3, "id2": 8},
{"id": 4, "id2": 9},
{"id": 100, "id2": 11},
{"id": 200, "id2": 22},
{"id": 400, "id2": 44},
]

assert ((q1 + q2) * A).fetch(as_dict=True) == expected_data


class TestIssue558:
"""
--------------- ISSUE 558 ------------------
Issue 558 resulted from the fact that DataJoint saves subqueries and often combines a restriction followed
by a projection into a single SELECT statement, which in several unusual cases produces unexpected results.
"""

def test_issue558_part1(self, schema_aggr_reg_with_abx):
q = (A - B).proj(id2="3")
assert len(A - B) == len(q)

def test_issue558_part2(self, schema_aggr_reg_with_abx):
d = dict(id=3, id2=5)
assert len(X & d) == len((X & d).proj(id2="3"))


def test_left_join_len(schema_uuid):
Topic().add("jeff")
Item.populate()
Topic().add("jeff2")
Topic().add("jeff3")
q = Topic.join(
Item - dict(topic_id=uuid.uuid5(top_level_namespace_id, "jeff")), left=True
)
qf = q.fetch()
assert len(q) == len(qf)
2 changes: 1 addition & 1 deletion tests/test_erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_make_image(schema_simp):

def test_part_table_parsing(schema_simp):
# https://github.com/datajoint/datajoint-python/issues/882
erd = dj.Di(schema_simp)
erd = dj.Di(schema_simp, context=LOCALS_SIMPLE)
graph = erd._make_graph()
assert "OutfitLaunch" in graph.nodes()
assert "OutfitLaunch.OutfitPiece" in graph.nodes()