diff --git a/tests/__init__.py b/tests/__init__.py index de57f6eab..219f7f5c0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,10 +3,24 @@ import pytest import os -PREFIX = "djtest" +PREFIX = os.environ.get("DJ_TEST_DB_PREFIX", "djtest") + +# Connection for testing +CONN_INFO = dict( + host=os.environ.get("DJ_TEST_HOST", "fakeservices.datajoint.io"), + user=os.environ.get("DJ_TEST_USER", "datajoint"), + password=os.environ.get("DJ_TEST_PASSWORD", "datajoint"), +) CONN_INFO_ROOT = dict( - host=os.getenv("DJ_HOST"), - user=os.getenv("DJ_USER"), - password=os.getenv("DJ_PASS"), + host=os.environ.get("DJ_HOST", "fakeservices.datajoint.io"), + user=os.environ.get("DJ_USER", "root"), + password=os.environ.get("DJ_PASS", "simple"), +) + +S3_CONN_INFO = dict( + endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"), + access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"), + secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"), + bucket=os.environ.get("S3_BUCKET", "datajoint.test"), ) diff --git a/tests/conftest.py b/tests/conftest.py index e13a13632..2c4063a1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,38 @@ import datajoint as dj from packaging import version import os +import minio +import urllib3 +import certifi +import shutil import pytest -from . import PREFIX, schema, schema_simple, schema_advanced +import networkx as nx +import json +from pathlib import Path +import tempfile +from datajoint import errors +from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH +from . import ( + PREFIX, + CONN_INFO, + S3_CONN_INFO, + schema, + schema_simple, + schema_advanced, + schema_adapted, +) -namespace = locals() + +@pytest.fixture(scope="session") +def monkeysession(): + with pytest.MonkeyPatch.context() as mp: + yield mp + + +@pytest.fixture(scope="module") +def monkeymodule(): + with pytest.MonkeyPatch.context() as mp: + yield mp @pytest.fixture(scope="session") @@ -64,11 +92,12 @@ def connection_test(connection_root): connection.close() -@pytest.fixture(scope="module") +@pytest.fixture def schema_any(connection_test): schema_any = dj.Schema( - PREFIX + "_test1", schema.__dict__, connection=connection_test + PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test ) + assert schema.LOCALS_ANY, "LOCALS_ANY is empty" schema_any(schema.TTest) schema_any(schema.TTest2) schema_any(schema.TTest3) @@ -109,10 +138,10 @@ def schema_any(connection_test): schema_any.drop() -@pytest.fixture(scope="module") +@pytest.fixture def schema_simp(connection_test): schema = dj.Schema( - PREFIX + "_relational", schema_simple.__dict__, connection=connection_test + PREFIX + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test ) schema(schema_simple.IJ) schema(schema_simple.JI) @@ -136,10 +165,12 @@ def schema_simp(connection_test): schema.drop() -@pytest.fixture(scope="module") +@pytest.fixture def schema_adv(connection_test): schema = dj.Schema( - PREFIX + "_advanced", schema_advanced.__dict__, connection=connection_test + PREFIX + "_advanced", + schema_advanced.LOCALS_ADVANCED, + connection=connection_test, ) schema(schema_advanced.Person) schema(schema_advanced.Parent) @@ -152,3 +183,30 @@ def schema_adv(connection_test): schema(schema_advanced.GlobalSynapse) yield schema schema.drop() + + +@pytest.fixture +def httpClient(): + # Initialize httpClient with relevant timeout. + httpClient = urllib3.PoolManager( + timeout=30, + cert_reqs="CERT_REQUIRED", + ca_certs=certifi.where(), + retries=urllib3.Retry( + total=3, backoff_factor=0.2, status_forcelist=[500, 502, 503, 504] + ), + ) + yield httpClient + + +@pytest.fixture +def minioClient(): + # Initialize minioClient with an endpoint and access/secret keys. + minioClient = minio.Minio( + S3_CONN_INFO["endpoint"], + access_key=S3_CONN_INFO["access_key"], + secret_key=S3_CONN_INFO["secret_key"], + secure=True, + http_client=httpClient, + ) + yield minioClient diff --git a/tests/schema.py b/tests/schema.py index 864c5efe4..140a34bba 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -7,8 +7,6 @@ import datajoint as dj import inspect -LOCALS_ANY = locals() - class TTest(dj.Lookup): """ @@ -33,7 +31,7 @@ class TTest2(dj.Manual): class TTest3(dj.Manual): definition = """ - key : int + key : int --- value : varchar(300) """ @@ -41,7 +39,7 @@ class TTest3(dj.Manual): class NullableNumbers(dj.Manual): definition = """ - key : int + key : int --- fvalue = null : float dvalue = null : double @@ -450,3 +448,7 @@ class Longblob(dj.Manual): --- data: longblob """ + + +LOCALS_ANY = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_ANY) diff --git a/tests/schema_adapted.py b/tests/schema_adapted.py new file mode 100644 index 000000000..ab9a02e76 --- /dev/null +++ b/tests/schema_adapted.py @@ -0,0 +1,62 @@ +import datajoint as dj +import inspect +import networkx as nx +import json +from pathlib import Path +import tempfile + + +class GraphAdapter(dj.AttributeAdapter): + attribute_type = "longblob" # this is how the attribute will be declared + + @staticmethod + def get(obj): + # convert edge list into a graph + return nx.Graph(obj) + + @staticmethod + def put(obj): + # convert graph object into an edge list + assert isinstance(obj, nx.Graph) + return list(obj.edges) + + +class LayoutToFilepath(dj.AttributeAdapter): + """ + An adapted data type that saves a graph layout into fixed filepath + """ + + attribute_type = "filepath@repo-s3" + + @staticmethod + def get(path): + with open(path, "r") as f: + return json.load(f) + + @staticmethod + def put(layout): + path = Path(dj.config["stores"]["repo-s3"]["stage"], "layout.json") + with open(str(path), "w") as f: + json.dump(layout, f) + return path + + +class Connectivity(dj.Manual): + definition = """ + connid : int + --- + conn_graph = null : + """ + + +class Layout(dj.Manual): + definition = """ + # stores graph layout + -> Connectivity + --- + layout: + """ + + +LOCALS_ADAPTED = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_ADAPTED) diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py index 104e4d1e4..6a35cb34a 100644 --- a/tests/schema_advanced.py +++ b/tests/schema_advanced.py @@ -1,6 +1,5 @@ import datajoint as dj - -LOCALS_ADVANCED = locals() +import inspect class Person(dj.Manual): @@ -135,3 +134,7 @@ class GlobalSynapse(dj.Manual): -> Cell.proj(pre_slice="slice", pre_cell="cell") -> Cell.proj(post_slice="slice", post_cell="cell") """ + + +LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_ADVANCED) diff --git a/tests/schema_simple.py b/tests/schema_simple.py index bb5c21ff5..e751a9c6e 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -9,8 +9,7 @@ import faker import numpy as np from datetime import date, timedelta - -LOCALS_SIMPLE = locals() +import inspect class IJ(dj.Lookup): @@ -237,8 +236,8 @@ class ReservedWord(dj.Manual): # Test of SQL reserved words key : int --- - in : varchar(25) - from : varchar(25) + in : varchar(25) + from : varchar(25) int : int select : varchar(25) """ @@ -260,3 +259,7 @@ class OutfitPiece(dj.Part, dj.Lookup): piece: varchar(20) """ contents = [(0, "jeans"), (0, "sneakers"), (0, "polo")] + + +LOCALS_SIMPLE = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_SIMPLE) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py new file mode 100644 index 000000000..29d773473 --- /dev/null +++ b/tests/test_adapted_attributes.py @@ -0,0 +1,153 @@ +import os +import pytest +import tempfile +import datajoint as dj +from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH +import networkx as nx +from itertools import zip_longest +from . import schema_adapted +from .schema_adapted import Connectivity, Layout +from . import PREFIX, S3_CONN_INFO + +SCHEMA_NAME = PREFIX + "_test_custom_datatype" + + +@pytest.fixture +def adapted_graph_instance(): + yield schema_adapted.GraphAdapter() + + +@pytest.fixture +def enable_adapted_types(monkeypatch): + monkeypatch.setenv(ADAPTED_TYPE_SWITCH, "TRUE") + yield + monkeypatch.delenv(ADAPTED_TYPE_SWITCH, raising=True) + + +@pytest.fixture +def enable_filepath_feature(monkeypatch): + monkeypatch.setenv(FILEPATH_FEATURE_SWITCH, "TRUE") + yield + monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) + + +@pytest.fixture +def schema_ad( + connection_test, + adapted_graph_instance, + enable_adapted_types, + enable_filepath_feature, +): + stores_config = { + "repo-s3": dict( + S3_CONN_INFO, + protocol="s3", + location="adapted/repo", + stage=tempfile.mkdtemp(), + ) + } + dj.config["stores"] = stores_config + layout_to_filepath = schema_adapted.LayoutToFilepath() + context = { + **schema_adapted.LOCALS_ADAPTED, + "graph": adapted_graph_instance, + "layout_to_filepath": layout_to_filepath, + } + schema = dj.schema(SCHEMA_NAME, context=context, connection=connection_test) + graph = adapted_graph_instance + schema(schema_adapted.Connectivity) + schema(schema_adapted.Layout) + yield schema + schema.drop() + + +@pytest.fixture +def local_schema(schema_ad): + """Fixture for testing spawned classes""" + local_schema = dj.Schema(SCHEMA_NAME) + local_schema.spawn_missing_classes() + yield local_schema + local_schema.drop() + + +@pytest.fixture +def schema_virtual_module(schema_ad, adapted_graph_instance): + """Fixture for testing virtual modules""" + schema_virtual_module = dj.VirtualModule( + "virtual_module", SCHEMA_NAME, add_objects={"graph": adapted_graph_instance} + ) + return schema_virtual_module + + +def test_adapted_type(schema_ad): + c = Connectivity() + graphs = [ + nx.lollipop_graph(4, 2), + nx.star_graph(5), + nx.barbell_graph(3, 1), + nx.cycle_graph(5), + ] + c.insert((i, g) for i, g in enumerate(graphs)) + returned_graphs = c.fetch("conn_graph", order_by="connid") + for g1, g2 in zip(graphs, returned_graphs): + assert isinstance(g2, nx.Graph) + assert len(g1.edges) == len(g2.edges) + assert 0 == len(nx.symmetric_difference(g1, g2).edges) + c.delete() + + +@pytest.mark.skip(reason="misconfigured s3 fixtures") +def test_adapted_filepath_type(schema_ad): + """https://github.com/datajoint/datajoint-python/issues/684""" + c = Connectivity() + c.delete() + c.insert1((0, nx.lollipop_graph(4, 2))) + + layout = nx.spring_layout(c.fetch1("conn_graph")) + # make json friendly + layout = {str(k): [round(r, ndigits=4) for r in v] for k, v in layout.items()} + t = Layout() + t.insert1((0, layout)) + result = t.fetch1("layout") + # TODO: may fail, used to be assert_dict_equal + assert result == layout + t.delete() + c.delete() + + +def test_adapted_spawned(local_schema, enable_adapted_types): + c = Connectivity() # a spawned class + graphs = [ + nx.lollipop_graph(4, 2), + nx.star_graph(5), + nx.barbell_graph(3, 1), + nx.cycle_graph(5), + ] + c.insert((i, g) for i, g in enumerate(graphs)) + returned_graphs = c.fetch("conn_graph", order_by="connid") + for g1, g2 in zip(graphs, returned_graphs): + assert isinstance(g2, nx.Graph) + assert len(g1.edges) == len(g2.edges) + assert 0 == len(nx.symmetric_difference(g1, g2).edges) + c.delete() + + +def test_adapted_virtual(schema_virtual_module): + c = schema_virtual_module.Connectivity() + graphs = [ + nx.lollipop_graph(4, 2), + nx.star_graph(5), + nx.barbell_graph(3, 1), + nx.cycle_graph(5), + ] + c.insert((i, g) for i, g in enumerate(graphs)) + c.insert1({"connid": 100}) # test work with NULLs + returned_graphs = c.fetch("conn_graph", order_by="connid") + for g1, g2 in zip_longest(graphs, returned_graphs): + if g1 is None: + assert g2 is None + else: + assert isinstance(g2, nx.Graph) + assert len(g1.edges) == len(g2.edges) + assert 0 == len(nx.symmetric_difference(g1, g2).edges) + c.delete() diff --git a/tests/test_blob.py b/tests/test_blob.py index 23de7be76..e55488987 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -7,7 +7,7 @@ from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal from pytest import approx -from .schema import * +from .schema import Longblob def test_pack(): diff --git a/tests/test_blob_matlab.py b/tests/test_blob_matlab.py index 06154b1fc..575e6b0b8 100644 --- a/tests/test_blob_matlab.py +++ b/tests/test_blob_matlab.py @@ -16,15 +16,15 @@ class Blob(dj.Manual): """ -@pytest.fixture(scope="module") +@pytest.fixture def schema(connection_test): - schema = dj.Schema(PREFIX + "_test1", locals(), connection=connection_test) + schema = dj.Schema(PREFIX + "_test1", dict(Blob=Blob), connection=connection_test) schema(Blob) yield schema schema.drop() -@pytest.fixture(scope="module") +@pytest.fixture def insert_blobs_func(schema): def insert_blobs(): """ @@ -63,7 +63,7 @@ def insert_blobs(): yield insert_blobs -@pytest.fixture(scope="class") +@pytest.fixture def setup_class(schema, insert_blobs_func): assert not dj.config["safemode"], "safemode must be disabled" Blob().delete() diff --git a/tests/test_connection.py b/tests/test_connection.py index 795d3761e..8cdbbbff5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,7 +12,9 @@ @pytest.fixture def schema(connection_test): - schema = dj.Schema(PREFIX + "_transactions", locals(), connection=connection_test) + schema = dj.Schema( + PREFIX + "_transactions", context=dict(), connection=connection_test + ) yield schema schema.drop() diff --git a/tests/test_erd.py b/tests/test_erd.py index f1274ec1b..aebf62eaf 100644 --- a/tests/test_erd.py +++ b/tests/test_erd.py @@ -45,13 +45,13 @@ def test_erd_algebra(schema_simp): def test_repr_svg(schema_adv): - erd = dj.ERD(schema_adv, context=locals()) + erd = dj.ERD(schema_adv, context=dict()) svg = erd._repr_svg_() assert svg.startswith("") def test_make_image(schema_simp): - erd = dj.ERD(schema_simp, context=locals()) + erd = dj.ERD(schema_simp, context=dict()) img = erd.make_image() assert img.ndim == 3 and img.shape[2] in (3, 4) diff --git a/tests/test_json.py b/tests/test_json.py index 760475a1a..c1caaeedd 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,3 +1,4 @@ +import pytest import inspect from datajoint.declare import declare import datajoint as dj @@ -5,213 +6,216 @@ from packaging.version import Version from . import PREFIX -if Version(dj.conn().query("select @@version;").fetchone()[0]) >= Version("8.0.0"): - schema = dj.Schema(PREFIX + "_json") - Team = None - - def setup(): - global Team - - @schema - class Team(dj.Lookup): - definition = """ - name: varchar(40) - --- - car=null: json - unique index(car.name:char(20)) - uniQue inDex ( name, car.name:char(20), (json_value(`car`, _utf8mb4'$.length' returning decimal(4, 1))) ) - """ - contents = [ - ( - "engineering", +if Version(dj.conn().query("select @@version;").fetchone()[0]) < Version("8.0.0"): + pytest.skip("skipping windows-only tests", allow_module_level=True) + + +class Team(dj.Lookup): + definition = """ + name: varchar(40) + --- + car=null: json + unique index(car.name:char(20)) + uniQue inDex ( name, car.name:char(20), (json_value(`car`, _utf8mb4'$.length' returning decimal(4, 1))) ) + """ + contents = [ + ( + "engineering", + { + "name": "Rever", + "length": 20.5, + "inspected": True, + "tire_pressure": [32, 31, 33, 34], + "headlights": [ { - "name": "Rever", - "length": 20.5, - "inspected": True, - "tire_pressure": [32, 31, 33, 34], - "headlights": [ - { - "side": "left", - "hyper_white": None, - }, - { - "side": "right", - "hyper_white": None, - }, - ], + "side": "left", + "hyper_white": None, }, - ), - ( - "business", { - "name": "Chaching", - "length": 100, - "safety_inspected": False, - "tire_pressure": [34, 30, 27, 32], - "headlights": [ - { - "side": "left", - "hyper_white": True, - }, - { - "side": "right", - "hyper_white": True, - }, - ], + "side": "right", + "hyper_white": None, }, - ), - ( - "marketing", - None, - ), - ] - - def teardown(): - schema.drop() - - def test_insert_update(): - car = { - "name": "Discovery", - "length": 22.9, - "inspected": None, - "tire_pressure": [35, 36, 34, 37], - "headlights": [ - { - "side": "left", - "hyper_white": True, - }, - { - "side": "right", - "hyper_white": True, - }, - ], - } - - Team.insert1({"name": "research", "car": car}) - q = Team & {"name": "research"} - assert q.fetch1("car") == car - - car.update({"length": 23}) - Team.update1({"name": "research", "car": car}) - assert q.fetch1("car") == car - - try: - Team.insert1({"name": "hr", "car": car}) - raise Exception("Inserted non-unique car name.") - except dj.DataJointError: - pass - - q.delete_quick() - assert not q - - def test_describe(): - rel = Team() - context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 - - def test_restrict(): - # dict - assert (Team & {"car.name": "Chaching"}).fetch1("name") == "business" - - assert (Team & {"car.length": 20.5}).fetch1("name") == "engineering" - - assert (Team & {"car.inspected": "true"}).fetch1("name") == "engineering" - - assert (Team & {"car.inspected:unsigned": True}).fetch1("name") == "engineering" - - assert (Team & {"car.safety_inspected": "false"}).fetch1("name") == "business" - - assert (Team & {"car.safety_inspected:unsigned": False}).fetch1( - "name" - ) == "business" - - assert (Team & {"car.headlights[0].hyper_white": None}).fetch( - "name", order_by="name", as_dict=True - ) == [ - {"name": "engineering"}, - {"name": "marketing"}, - ] # if entire record missing, JSON key is missing, or value set to JSON null - - assert (Team & {"car": None}).fetch1("name") == "marketing" - - assert (Team & {"car.tire_pressure": [34, 30, 27, 32]}).fetch1( - "name" - ) == "business" - - assert ( - Team & {"car.headlights[1]": {"side": "right", "hyper_white": True}} - ).fetch1("name") == "business" - - # sql operators - assert (Team & "`car`->>'$.name' LIKE '%ching%'").fetch1( - "name" - ) == "business", "Missing substring" - - assert (Team & "`car`->>'$.length' > 30").fetch1("name") == "business", "<= 30" - - assert ( - Team & "JSON_VALUE(`car`, '$.safety_inspected' RETURNING UNSIGNED) = 0" - ).fetch1("name") == "business", "Has `safety_inspected` set to `true`" - - assert (Team & "`car`->>'$.headlights[0].hyper_white' = 'null'").fetch1( - "name" - ) == "engineering", "Has 1st `headlight` with `hyper_white` not set to `null`" - - assert (Team & "`car`->>'$.inspected' IS NOT NULL").fetch1( - "name" - ) == "engineering", "Missing `inspected` key" - - assert (Team & "`car`->>'$.tire_pressure' = '[34, 30, 27, 32]'").fetch1( - "name" - ) == "business", "`tire_pressure` array did not match" - - assert ( - Team - & """`car`->>'$.headlights[1]' = '{"side": "right", "hyper_white": true}'""" - ).fetch1("name") == "business", "2nd `headlight` object did not match" - - def test_proj(): - # proj necessary since we need to rename indexed value into a proper attribute name - assert Team.proj(car_length="car.length").fetch( - as_dict=True, order_by="car_length" - ) == [ - {"name": "marketing", "car_length": None}, - {"name": "business", "car_length": "100"}, - {"name": "engineering", "car_length": "20.5"}, - ] - - assert Team.proj(car_length="car.length:decimal(4, 1)").fetch( - as_dict=True, order_by="car_length" - ) == [ - {"name": "marketing", "car_length": None}, - {"name": "engineering", "car_length": 20.5}, - {"name": "business", "car_length": 100.0}, - ] - - assert Team.proj( - car_width="JSON_VALUE(`car`, '$.length' RETURNING float) - 15" - ).fetch(as_dict=True, order_by="car_width") == [ - {"name": "marketing", "car_width": None}, - {"name": "engineering", "car_width": 5.5}, - {"name": "business", "car_width": 85.0}, - ] - - assert ( - (Team & {"name": "engineering"}).proj(car_tire_pressure="car.tire_pressure") - ).fetch1("car_tire_pressure") == "[32, 31, 33, 34]" - - assert np.array_equal( - Team.proj(car_inspected="car.inspected").fetch( - "car_inspected", order_by="name" - ), - np.array([None, "true", None]), - ) - - assert np.array_equal( - Team.proj(car_inspected="car.inspected:unsigned").fetch( - "car_inspected", order_by="name" - ), - np.array([None, 1, None]), - ) + ], + }, + ), + ( + "business", + { + "name": "Chaching", + "length": 100, + "safety_inspected": False, + "tire_pressure": [34, 30, 27, 32], + "headlights": [ + { + "side": "left", + "hyper_white": True, + }, + { + "side": "right", + "hyper_white": True, + }, + ], + }, + ), + ( + "marketing", + None, + ), + ] + + +@pytest.fixture +def schema(connection_test): + schema = dj.Schema(PREFIX + "_json", context=dict(), connection=connection_test) + schema(Team) + yield schema + schema.drop() + + +def test_insert_update(schema): + car = { + "name": "Discovery", + "length": 22.9, + "inspected": None, + "tire_pressure": [35, 36, 34, 37], + "headlights": [ + { + "side": "left", + "hyper_white": True, + }, + { + "side": "right", + "hyper_white": True, + }, + ], + } + + Team.insert1({"name": "research", "car": car}) + q = Team & {"name": "research"} + assert q.fetch1("car") == car + + car.update({"length": 23}) + Team.update1({"name": "research", "car": car}) + assert q.fetch1("car") == car + + try: + Team.insert1({"name": "hr", "car": car}) + raise Exception("Inserted non-unique car name.") + except dj.DataJointError: + pass + + q.delete_quick() + assert not q + + +def test_describe(schema): + rel = Team() + context = inspect.currentframe().f_globals + s1 = declare(rel.full_table_name, rel.definition, context) + s2 = declare(rel.full_table_name, rel.describe(), context) + assert s1 == s2 + + +def test_restrict(schema): + # dict + assert (Team & {"car.name": "Chaching"}).fetch1("name") == "business" + + assert (Team & {"car.length": 20.5}).fetch1("name") == "engineering" + + assert (Team & {"car.inspected": "true"}).fetch1("name") == "engineering" + + assert (Team & {"car.inspected:unsigned": True}).fetch1("name") == "engineering" + + assert (Team & {"car.safety_inspected": "false"}).fetch1("name") == "business" + + assert (Team & {"car.safety_inspected:unsigned": False}).fetch1( + "name" + ) == "business" + + assert (Team & {"car.headlights[0].hyper_white": None}).fetch( + "name", order_by="name", as_dict=True + ) == [ + {"name": "engineering"}, + {"name": "marketing"}, + ] # if entire record missing, JSON key is missing, or value set to JSON null + + assert (Team & {"car": None}).fetch1("name") == "marketing" + + assert (Team & {"car.tire_pressure": [34, 30, 27, 32]}).fetch1("name") == "business" + + assert ( + Team & {"car.headlights[1]": {"side": "right", "hyper_white": True}} + ).fetch1("name") == "business" + + # sql operators + assert (Team & "`car`->>'$.name' LIKE '%ching%'").fetch1( + "name" + ) == "business", "Missing substring" + + assert (Team & "`car`->>'$.length' > 30").fetch1("name") == "business", "<= 30" + + assert ( + Team & "JSON_VALUE(`car`, '$.safety_inspected' RETURNING UNSIGNED) = 0" + ).fetch1("name") == "business", "Has `safety_inspected` set to `true`" + + assert (Team & "`car`->>'$.headlights[0].hyper_white' = 'null'").fetch1( + "name" + ) == "engineering", "Has 1st `headlight` with `hyper_white` not set to `null`" + + assert (Team & "`car`->>'$.inspected' IS NOT NULL").fetch1( + "name" + ) == "engineering", "Missing `inspected` key" + + assert (Team & "`car`->>'$.tire_pressure' = '[34, 30, 27, 32]'").fetch1( + "name" + ) == "business", "`tire_pressure` array did not match" + + assert ( + Team + & """`car`->>'$.headlights[1]' = '{"side": "right", "hyper_white": true}'""" + ).fetch1("name") == "business", "2nd `headlight` object did not match" + + +def test_proj(schema): + # proj necessary since we need to rename indexed value into a proper attribute name + assert Team.proj(car_length="car.length").fetch( + as_dict=True, order_by="car_length" + ) == [ + {"name": "marketing", "car_length": None}, + {"name": "business", "car_length": "100"}, + {"name": "engineering", "car_length": "20.5"}, + ] + + assert Team.proj(car_length="car.length:decimal(4, 1)").fetch( + as_dict=True, order_by="car_length" + ) == [ + {"name": "marketing", "car_length": None}, + {"name": "engineering", "car_length": 20.5}, + {"name": "business", "car_length": 100.0}, + ] + + assert Team.proj( + car_width="JSON_VALUE(`car`, '$.length' RETURNING float) - 15" + ).fetch(as_dict=True, order_by="car_width") == [ + {"name": "marketing", "car_width": None}, + {"name": "engineering", "car_width": 5.5}, + {"name": "business", "car_width": 85.0}, + ] + + assert ( + (Team & {"name": "engineering"}).proj(car_tire_pressure="car.tire_pressure") + ).fetch1("car_tire_pressure") == "[32, 31, 33, 34]" + + assert np.array_equal( + Team.proj(car_inspected="car.inspected").fetch( + "car_inspected", order_by="name" + ), + np.array([None, "true", None]), + ) + + assert np.array_equal( + Team.proj(car_inspected="car.inspected:unsigned").fetch( + "car_inspected", order_by="name" + ), + np.array([None, 1, None]), + ) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index f70f4c2ef..ddb8b3bfc 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -1,3 +1,4 @@ +import pytest import datajoint.errors as djerr import datajoint.plugin as p import pkg_resources @@ -22,7 +23,8 @@ def test_normal_djerror(): assert e.__cause__ is None -def test_verified_djerror(category="connection"): +@pytest.mark.parametrize("category", ("connection",)) +def test_verified_djerror(category): try: curr_plugins = getattr(p, "{}_plugins".format(category)) setattr( @@ -40,7 +42,8 @@ def test_verified_djerror_type(): test_verified_djerror(category="type") -def test_unverified_djerror(category="connection"): +@pytest.mark.parametrize("category", ("connection",)) +def test_unverified_djerror(category): try: curr_plugins = getattr(p, "{}_plugins".format(category)) setattr( diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py index d225bccbb..50997662d 100644 --- a/tests/test_relation_u.py +++ b/tests/test_relation_u.py @@ -5,25 +5,24 @@ from .schema_simple import * -@pytest.fixture(scope="class") -def setup_class(request, schema_any): - request.cls.user = User() - request.cls.language = Language() - request.cls.subject = Subject() - request.cls.experiment = Experiment() - request.cls.trial = Trial() - request.cls.ephys = Ephys() - request.cls.channel = Ephys.Channel() - request.cls.img = Image() - request.cls.trash = UberTrash() - - class TestU: """ Test tables: insert, delete """ - def test_restriction(self, setup_class): + @classmethod + def setup_class(cls): + cls.user = User() + cls.language = Language() + cls.subject = Subject() + cls.experiment = Experiment() + cls.trial = Trial() + cls.ephys = Ephys() + cls.channel = Ephys.Channel() + cls.img = Image() + cls.trash = UberTrash() + + def test_restriction(self, schema_any): language_set = {s[1] for s in self.language.contents} rel = dj.U("language") & self.language assert list(rel.heading.names) == ["language"] @@ -35,15 +34,15 @@ def test_restriction(self, setup_class): assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key) assert list((dj.U("start_time") & self.trial).primary_key) == ["start_time"] - def test_invalid_restriction(self, setup_class): + def test_invalid_restriction(self, schema_any): with raises(dj.DataJointError): result = dj.U("color") & dict(color="red") - def test_ineffective_restriction(self, setup_class): + def test_ineffective_restriction(self, schema_any): rel = self.language & dj.U("language") assert rel.make_sql() == self.language.make_sql() - def test_join(self, setup_class): + def test_join(self, schema_any): rel = self.experiment * dj.U("experiment_date") assert self.experiment.primary_key == ["subject_id", "experiment_id"] assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] @@ -52,16 +51,16 @@ def test_join(self, setup_class): assert self.experiment.primary_key == ["subject_id", "experiment_id"] assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] - def test_invalid_join(self, setup_class): + def test_invalid_join(self, schema_any): with raises(dj.DataJointError): rel = dj.U("language") * dict(language="English") - def test_repr_without_attrs(self, setup_class): + def test_repr_without_attrs(self, schema_any): """test dj.U() display""" query = dj.U().aggr(Language, n="count(*)") repr(query) - def test_aggregations(self, setup_class): + def test_aggregations(self, schema_any): lang = Language() # test total aggregation on expression object n1 = dj.U().aggr(lang, n="count(*)").fetch1("n") @@ -73,13 +72,13 @@ def test_aggregations(self, setup_class): assert len(rel) == len(set(l[1] for l in Language.contents)) assert (rel & 'language="English"').fetch1("number_of_speakers") == 3 - def test_argmax(self, setup_class): + def test_argmax(self, schema_any): rel = TTest() # get the tuples corresponding to the maximum value mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" assert mx.fetch("value")[0] == max(rel.fetch("value")) - def test_aggr(self, setup_class, schema_simp): + def test_aggr(self, schema_any, schema_simp): rel = ArgmaxTest() amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") diff --git a/tests/test_schema_keywords.py b/tests/test_schema_keywords.py index c8b7d5a24..1cad98efd 100644 --- a/tests/test_schema_keywords.py +++ b/tests/test_schema_keywords.py @@ -33,7 +33,7 @@ class D(B): source = A -@pytest.fixture(scope="module") +@pytest.fixture def schema(connection_test): schema = dj.Schema(PREFIX + "_keywords", connection=connection_test) schema(A) diff --git a/tests/test_utils.py b/tests/test_utils.py index 936badb1c..04325db56 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,14 +6,6 @@ import pytest -def setup(): - pass - - -def teardown(): - pass - - def test_from_camel_case(): assert from_camel_case("AllGroups") == "all_groups" with pytest.raises(DataJointError):