Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def enable_filepath_feature(monkeypatch):
monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True)


@pytest.fixture(scope="session")
def db_creds_test() -> Dict:
return dict(
host=os.getenv("DJ_TEST_HOST", "fakeservices.datajoint.io"),
user=os.getenv("DJ_TEST_USER", "datajoint"),
password=os.getenv("DJ_TEST_PASSWORD", "datajoint"),
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another part of the move towards putting these credentials in fixtures.


@pytest.fixture(scope="session")
def db_creds_root() -> Dict:
return dict(
Expand Down Expand Up @@ -142,12 +151,9 @@ def connection_root(connection_root_bare):


@pytest.fixture(scope="session")
def connection_test(connection_root):
def connection_test(connection_root, db_creds_test):
"""Test user database connection."""
database = f"{PREFIX}%%"
credentials = dict(
host=os.getenv("DJ_HOST"), user="datajoint", password="datajoint"
)
permission = "ALL PRIVILEGES"

# Create MySQL users
Expand All @@ -157,14 +163,14 @@ def connection_test(connection_root):
# create user if necessary on mysql8
connection_root.query(
f"""
CREATE USER IF NOT EXISTS '{credentials["user"]}'@'%%'
IDENTIFIED BY '{credentials["password"]}';
CREATE USER IF NOT EXISTS '{db_creds_test["user"]}'@'%%'
IDENTIFIED BY '{db_creds_test["password"]}';
"""
)
connection_root.query(
f"""
GRANT {permission} ON `{database}`.*
TO '{credentials["user"]}'@'%%';
TO '{db_creds_test["user"]}'@'%%';
"""
)
else:
Expand All @@ -173,14 +179,14 @@ def connection_test(connection_root):
connection_root.query(
f"""
GRANT {permission} ON `{database}`.*
TO '{credentials["user"]}'@'%%'
IDENTIFIED BY '{credentials["password"]}';
TO '{db_creds_test["user"]}'@'%%'
IDENTIFIED BY '{db_creds_test["password"]}';
"""
)

connection = dj.Connection(**credentials)
connection = dj.Connection(**db_creds_test)
yield connection
connection_root.query(f"""DROP USER `{credentials["user"]}`""")
connection_root.query(f"""DROP USER `{db_creds_test["user"]}`""")
connection.close()


Expand Down
248 changes: 248 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import types
import pytest
import inspect
import datajoint as dj
from unittest.mock import patch
from inspect import getmembers
from . import schema
from . import PREFIX


class Ephys(dj.Imported):
definition = """ # This is already declare in ./schema.py
"""


def relation_selector(attr):
try:
return issubclass(attr, dj.Table)
except TypeError:
return False


def part_selector(attr):
try:
return issubclass(attr, dj.Part)
except TypeError:
return False


@pytest.fixture
def schema_empty_module(schema_any, schema_empty):
"""
Mock the module tests_old.schema_empty.
The test `test_namespace_population` will check that the module contains all the
classes in schema_any, after running `spawn_missing_classes`.
"""
namespace_dict = {
"_": schema_any,
"schema": schema_empty,
"Ephys": Ephys,
}
module = types.ModuleType("schema_empty")

# Add classes to the module's namespace
for k, v in namespace_dict.items():
setattr(module, k, v)

return module
Copy link
Contributor Author

@ethho ethho Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an interesting problem. In tests_old/schema_empty.py, spawn_missing_classes relies on the schema.schema (called the schema_any fixture in new tests) being activated on import. Because this is now in the schema_any fixture, we weren't able to create schema_empty.schema the same way in the pytest suite (in pytest, the value of the fixture is only available after everything's been imported). Instead, I used the types.ModuleType to mock a module that mimics the same behavior in tests_old/schema_empty.py.



@pytest.fixture
def schema_empty(connection_test, schema_any):
context = {**schema.LOCALS_ANY, "Ephys": Ephys}
schema_empty = dj.Schema(
PREFIX + "_test1", context=context, connection=connection_test
)
schema_empty(Ephys)
# load the rest of the classes
schema_empty.spawn_missing_classes(context=context)
yield schema_empty
schema_empty.drop()


def test_schema_size_on_disk(schema_any):
number_of_bytes = schema_any.size_on_disk
assert isinstance(number_of_bytes, int)


def test_schema_list(schema_any):
schemas = dj.list_schemas()
assert schema_any.database in schemas


def test_drop_unauthorized():
info_schema = dj.schema("information_schema")
with pytest.raises(dj.errors.AccessError):
info_schema.drop()


def test_namespace_population(schema_empty_module):
"""
With the schema_empty_module fixture, this test
mimics the behavior of `spawn_missing_classes`, as if the schema
was declared in a separate module and `spawn_missing_classes` was called in that namespace.
"""
# Spawn missing classes in the caller's (self) namespace.
schema_empty_module.schema.context = None
schema_empty_module.schema.spawn_missing_classes(context=None)
# Then add them to the mock module's namespace.
for k, v in locals().items():
if inspect.isclass(v):
setattr(schema_empty_module, k, v)

for name, rel in getmembers(schema, relation_selector):
assert hasattr(
schema_empty_module, name
), "{name} not found in schema_empty".format(name=name)
assert (
rel.__base__ is getattr(schema_empty_module, name).__base__
), "Wrong tier for {name}".format(name=name)

for name_part in dir(rel):
if name_part[0].isupper() and part_selector(getattr(rel, name_part)):
assert (
getattr(rel, name_part).__base__ is dj.Part
), "Wrong tier for {name}".format(name=name_part)


def test_undecorated_table():
"""
Undecorated user table classes should raise an informative exception upon first use
"""

class UndecoratedClass(dj.Manual):
definition = ""

a = UndecoratedClass()
with pytest.raises(dj.DataJointError):
print(a.full_table_name)


def test_reject_decorated_part(schema_any):
"""
Decorating a dj.Part table should raise an informative exception.
"""

class A(dj.Manual):
definition = ...

class B(dj.Part):
definition = ...

with pytest.raises(dj.DataJointError):
schema_any(A.B)
schema_any(A)


def test_unauthorized_database(db_creds_test):
"""
an attempt to create a database to which user has no privileges should raise an informative exception.
"""
with pytest.raises(dj.DataJointError):
dj.Schema(
"unauthorized_schema", connection=dj.conn(reset=True, **db_creds_test)
)


def test_drop_database(db_creds_test):
schema = dj.Schema(
PREFIX + "_drop_test", connection=dj.conn(reset=True, **db_creds_test)
)
assert schema.exists
schema.drop()
assert not schema.exists
schema.drop() # should do nothing


def test_overlapping_name(connection_test):
test_schema = dj.Schema(PREFIX + "_overlapping_schema", connection=connection_test)

@test_schema
class Unit(dj.Manual):
definition = """
id: int # simple id
"""

# hack to update the locals dictionary
locals()

@test_schema
class Cell(dj.Manual):
definition = """
type: varchar(32) # type of cell
"""

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit inconsistent in the way we add Cell to test_schema, but this is exactly the way it was written in the old tests, so I'm just going to leave it.

class Unit(dj.Part):
definition = """
-> master
-> Unit
"""

test_schema.drop()


def test_list_tables(schema_simp):
"""
https://github.com/datajoint/datajoint-python/issues/838
"""
assert set(
[
"reserved_word",
"#l",
"#a",
"__d",
"__b",
"__b__c",
"__e",
"__e__f",
"#outfit_launch",
"#outfit_launch__outfit_piece",
"#i_j",
"#j_i",
"#t_test_update",
"#data_a",
"#data_b",
"f",
"#argmax_test",
"#website",
"profile",
"profile__website",
]
) == set(schema_simp.list_tables())


def test_schema_save_any(schema_any):
assert "class Experiment(dj.Imported)" in schema_any.code


def test_schema_save_empty(schema_empty):
assert "class Experiment(dj.Imported)" in schema_empty.code


def test_uppercase_schema(db_creds_root):
"""
https://github.com/datajoint/datajoint-python/issues/564
"""
dj.conn(**db_creds_root, reset=True)
schema1 = dj.Schema("Schema_A")

@schema1
class Subject(dj.Manual):
definition = """
name: varchar(32)
"""

Schema_A = dj.VirtualModule("Schema_A", "Schema_A")

schema2 = dj.Schema("schema_b")

@schema2
class Recording(dj.Manual):
definition = """
-> Schema_A.Subject
id: smallint
"""

schema2.drop()
schema1.drop()