Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
schema_advanced,
schema_adapted,
schema_external,
schema_uuid as schema_uuid_module,
)


Expand Down Expand Up @@ -307,6 +308,20 @@ def schema_ext(connection_test, stores_config, enable_filepath_feature):
schema.drop()


@pytest.fixture
def schema_uuid(connection_test):
    """Schema with the UUID test tables declared; dropped on teardown."""
    uuid_schema = dj.Schema(
        PREFIX + "_test1",
        context=schema_uuid_module.LOCALS_UUID,
        connection=connection_test,
    )
    for table in (
        schema_uuid_module.Basic,
        schema_uuid_module.Topic,
        schema_uuid_module.Item,
    ):
        uuid_schema(table)
    yield uuid_schema
    uuid_schema.drop()


@pytest.fixture(scope="session")
def http_client():
# Initialize httpClient with relevant timeout.
Expand Down
51 changes: 51 additions & 0 deletions tests/schema_aggr_regress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import datajoint as dj
import itertools
import inspect


class R(dj.Lookup):
    """Lookup of single-character primary keys for the aggregation regression tests."""

    definition = """
    r : char(1)
    """
    # zip over a single iterable yields 1-tuples: one row per letter
    # (note the sequence skips "E").
    contents = zip("ABCDFGHIJKLMNOPQRST")


class Q(dj.Lookup):
    """Subset of R: primary key is solely the foreign key to R."""

    definition = """
    -> R
    """
    # 1-tuples; letters A-H without "E", all present in R.contents.
    contents = zip("ABCDFGH")


class S(dj.Lookup):
    """Child of R with an extra integer key: (r, s) for r in "ABCDF", s in 0..9."""

    definition = """
    -> R
    s : int
    """
    # Cartesian product: 5 letters x 10 ints = 50 rows.
    contents = itertools.product("ABCDF", range(10))


class A(dj.Lookup):
    """Integer-keyed lookup with ids 0..9."""

    definition = """
    id: int
    """
    # zip(range(10)) produces 1-tuples (0,), (1,), ... (9,).
    contents = zip(range(10))


class B(dj.Lookup):
    """Child of A with a second key attribute; rows (0, 5), (1, 6), ... (4, 9)."""

    definition = """
    -> A
    id2: int
    """
    contents = zip(range(5), range(5, 10))


class X(dj.Lookup):
    """Standalone lookup with ids 0..9; has no dependency on the other tables."""

    definition = """
    id: int
    """
    contents = zip(range(10))


# Collect every class defined above into a context dict used when
# declaring the schema; also serves as the module's public API.
LOCALS_AGGR_REGRESS = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_AGGR_REGRESS)
50 changes: 50 additions & 0 deletions tests/schema_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import uuid
import inspect
import datajoint as dj
from . import PREFIX, CONN_INFO

top_level_namespace_id = uuid.UUID("00000000-0000-0000-0000-000000000000")


class Basic(dj.Manual):
    """Minimal table with a uuid primary key and one integer attribute."""

    definition = """
    item : uuid
    ---
    number : int
    """


class Topic(dj.Manual):
    """Topics identified by a uuid5 hash of their name."""

    definition = """
    # A topic for items
    topic_id : uuid # internal identification of a topic, reflects topic name
    ---
    topic : varchar(8000) # full topic name used to generate the topic id
    """

    def add(self, topic):
        """Insert a new topic, deriving its uuid5 id from the topic name."""
        self.insert1(
            dict(topic_id=uuid.uuid5(top_level_namespace_id, topic), topic=topic)
        )


class Item(dj.Computed):
    """Computed items: four fixed words per topic, each with a uuid5 id
    namespaced under the parent topic's uuid. Topic is referenced as a
    secondary (non-primary) attribute."""

    definition = """
    item_id : uuid # internal identification of
    ---
    -> Topic
    word : varchar(8000)
    """

    key_source = Topic  # test key source that is not instantiated

    def make(self, key):
        # One row per word; item_id derived from the topic's uuid so that
        # identical (topic, word) pairs always hash to the same id.
        for word in ("Habenula", "Hippocampus", "Hypothalamus", "Hypophysis"):
            self.insert1(
                dict(key, word=word, item_id=uuid.uuid5(key["topic_id"], word))
            )


# Collect the table classes above into a declaration context; also the
# module's public API.
LOCALS_UUID = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_UUID)
130 changes: 130 additions & 0 deletions tests/test_aggr_regressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Regression tests for issues 386, 449, 484, and 558 — all related to processing complex aggregations and projections.
"""

import pytest
import datajoint as dj
from . import PREFIX
import uuid
from .schema_uuid import Topic, Item, top_level_namespace_id
from .schema_aggr_regress import R, Q, S, A, B, X, LOCALS_AGGR_REGRESS


@pytest.fixture(scope="function")
def schema_aggr_reg(connection_test):
    """Fresh schema with R, Q, S declared; dropped after each test."""
    regress_schema = dj.Schema(
        PREFIX + "_aggr_regress",
        context=LOCALS_AGGR_REGRESS,
        connection=connection_test,
    )
    for table_class in (R, Q, S):
        regress_schema(table_class)
    yield regress_schema
    regress_schema.drop()


@pytest.fixture(scope="function")
def schema_aggr_reg_with_abx(connection_test):
    """Fresh schema with R, Q, S plus A, B, X declared; dropped after each test."""
    regress_schema = dj.Schema(
        PREFIX + "_aggr_regress_with_abx",
        context=LOCALS_AGGR_REGRESS,
        connection=connection_test,
    )
    for table_class in (R, Q, S, A, B, X):
        regress_schema(table_class)
    yield regress_schema
    regress_schema.drop()


def test_issue386(schema_aggr_reg):
    """
    --------------- ISSUE 386 -------------------
    Issue 386: aggregated attributes were lost when the aggregation was used
    as a restrictor, e.g. Q & (R.aggr(S, n='count(*)') & 'n=2')
    raised: Unknown column 'n' in HAVING.
    """
    aggregated = R.aggr(S, n="count(*)") & "n=10"
    (Q & aggregated).fetch()


def test_issue449(schema_aggr_reg):
    """
    ---------------- ISSUE 449 ------------------
    Issue 449: incorrect GROUP BY attributes after joining with a dj.U().
    """
    joined = dj.U("n") * R.aggr(S, n="max(s)")
    joined.fetch()


def test_issue484(schema_aggr_reg):
    """
    ---------------- ISSUE 484 -----------------
    Issue 484: universal-set aggregation, fetched directly and nested.
    """
    # Exercise both fetch and fetch1 on a dj.U() aggregation.
    universal_max = dj.U().aggr(S, n="max(s)")
    universal_max.fetch("n")
    universal_max.fetch1("n")
    # Aggregation of an aggregation.
    nested = dj.U().aggr(dj.U().aggr(S, n="avg(s)"), m="max(n)")
    nested.fetch()


def test_union_join(schema_aggr_reg_with_abx):
    """
    This test fails if it runs after TestIssue558.

    https://github.com/datajoint/datajoint-python/issues/930
    """
    A.insert(zip([100, 200, 300, 400, 500, 600]))
    B.insert([(100, 11), (200, 22), (300, 33), (400, 44)])
    lower = B & "id < 300"
    upper = B & "id > 300"

    # Lookup contents (0..4 paired with 5..9) plus the inserted rows that
    # survive the union's restriction: id 300 is excluded by both halves.
    expected_data = [
        {"id": pk, "id2": value}
        for pk, value in [
            (0, 5),
            (1, 6),
            (2, 7),
            (3, 8),
            (4, 9),
            (100, 11),
            (200, 22),
            (400, 44),
        ]
    ]

    assert ((lower + upper) * A).fetch(as_dict=True) == expected_data


class TestIssue558:
    """
    --------------- ISSUE 558 ------------------
    Issue 558: DataJoint saves subqueries and often combines a restriction
    followed by a projection into a single SELECT statement, which in several
    unusual cases produced unexpected results.
    """

    def test_issue558_part1(self, schema_aggr_reg_with_abx):
        # Projecting a constant onto an anti-join must not change its row count.
        antijoin = A - B
        assert len(antijoin) == len(antijoin.proj(id2="3"))

    def test_issue558_part2(self, schema_aggr_reg_with_abx):
        # Restriction by a dict followed by projection must keep the row count.
        restriction = dict(id=3, id2=5)
        restricted = X & restriction
        assert len(restricted) == len(restricted.proj(id2="3"))


def test_left_join_len(schema_uuid):
    """len() of a left join must agree with the number of fetched rows."""
    topics = Topic()
    topics.add("jeff")
    Item.populate()
    # Topics added after populate have no Items, so the left join pads them.
    topics.add("jeff2")
    topics.add("jeff3")
    jeff_id = uuid.uuid5(top_level_namespace_id, "jeff")
    left_joined = Topic.join(Item - dict(topic_id=jeff_id), left=True)
    assert len(left_joined) == len(left_joined.fetch())
116 changes: 116 additions & 0 deletions tests/test_alter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest
import re
import datajoint as dj
from . import schema as schema_any_module, PREFIX


class Experiment(dj.Imported):
    """Carries two alternative definitions so the alter tests can switch the
    table from ``original_definition`` to ``definition1`` and back."""

    # Definition used for the initial declaration.
    original_definition = """ # information about experiments
    -> Subject
    experiment_id :smallint # experiment number for this subject
    ---
    experiment_date :date # date when experiment was started
    -> [nullable] User
    data_path="" :varchar(255) # file path to recorded data
    notes="" :varchar(2048) # e.g. purpose of experiment
    entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp
    """

    # Altered definition: changes data_path's type, adds extra, and
    # renames/redefines the notes attribute.
    definition1 = """ # Experiment
    -> Subject
    experiment_id :smallint # experiment number for this subject
    ---
    data_path : int # some number
    extra=null : longblob # just testing
    -> [nullable] User
    subject_notes=null :varchar(2048) # {notes} e.g. purpose of experiment
    entry_time=CURRENT_TIMESTAMP :timestamp # automatic timestamp
    """


class Parent(dj.Manual):
    """Master table whose part tables carry alternative definitions for the
    part-table alter tests (issue 936)."""

    definition = """
    parent_id: int
    """

    class Child(dj.Part):
        definition = """
        -> Parent
        """
        # Altered definition adding a nullable secondary attribute.
        definition_new = """
        -> master
        ---
        child_id=null: int
        """

    class Grandchild(dj.Part):
        definition = """
        -> master.Child
        """
        # Altered definition adding a nullable secondary attribute.
        definition_new = """
        -> master.Child
        ---
        grandchild_id=null: int
        """


# Declaration contexts: LOCALS_ALTER holds only the alter-test tables;
# COMBINED_CONTEXT overlays them onto the shared schema module's context
# so foreign-key references (Subject, User) resolve.
LOCALS_ALTER = {"Experiment": Experiment, "Parent": Parent}
COMBINED_CONTEXT = {
    **schema_any_module.LOCALS_ANY,
    **LOCALS_ALTER,
}


@pytest.fixture
def schema_alter(connection_test, schema_any):
    """Overwrite the Experiment and Parent nodes on schema_any; drop on teardown."""
    for table in (Experiment, Parent):
        schema_any(table, context=LOCALS_ALTER)
    yield schema_any
    schema_any.drop()


class TestAlter:
    def test_alter(self, schema_alter):
        """Altering to definition1 changes the DDL; altering back restores it."""

        def show_create():
            # Current CREATE TABLE statement for Experiment.
            return schema_alter.connection.query(
                f"SHOW CREATE TABLE {Experiment.full_table_name}"
            ).fetchone()[1]

        before = show_create()
        Experiment.definition = Experiment.definition1
        Experiment.alter(prompt=False, context=COMBINED_CONTEXT)
        after = show_create()
        assert before != after
        Experiment.definition = Experiment.original_definition
        Experiment().alter(prompt=False, context=COMBINED_CONTEXT)
        restored = show_create()
        assert after != restored
        assert before == restored

    def verify_alter(self, schema_alter, table, attribute_sql):
        """Alter ``table`` to definition_new and check that removing the added
        column from the new DDL recovers the original DDL."""
        before = schema_alter.connection.query(
            f"SHOW CREATE TABLE {table.full_table_name}"
        ).fetchone()[1]
        table.definition = table.definition_new
        table.alter(prompt=False)
        after = schema_alter.connection.query(
            f"SHOW CREATE TABLE {table.full_table_name}"
        ).fetchone()[1]
        assert re.sub(f"{attribute_sql},\n  ", "", after) == before

    def test_alter_part(self, schema_alter):
        """
        https://github.com/datajoint/datajoint-python/issues/936
        """
        for part, added_sql in (
            (Parent.Child, "`child_id` .* DEFAULT NULL"),
            (Parent.Grandchild, "`grandchild_id` .* DEFAULT NULL"),
        ):
            self.verify_alter(schema_alter, table=part, attribute_sql=added_sql)
2 changes: 1 addition & 1 deletion tests/test_erd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_make_image(schema_simp):

def test_part_table_parsing(schema_simp):
    # https://github.com/datajoint/datajoint-python/issues/882
    diagram = dj.Di(schema_simp, context=LOCALS_SIMPLE)
    nodes = diagram._make_graph().nodes()
    assert "OutfitLaunch" in nodes
    assert "OutfitLaunch.OutfitPiece" in nodes