Skip to content

Commit 1814f9b

Browse files
authored
Merge pull request #1116 from ethho/dev-tests
PLAT-138 Follow Up to #1114
2 parents 6d77392 + eff463d commit 1814f9b

16 files changed

+564
-269
lines changed

tests/__init__.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,24 @@
33
import pytest
44
import os
55

6-
PREFIX = "djtest"
6+
PREFIX = os.environ.get("DJ_TEST_DB_PREFIX", "djtest")
7+
8+
# Connection for testing
9+
CONN_INFO = dict(
10+
host=os.environ.get("DJ_TEST_HOST", "fakeservices.datajoint.io"),
11+
user=os.environ.get("DJ_TEST_USER", "datajoint"),
12+
password=os.environ.get("DJ_TEST_PASSWORD", "datajoint"),
13+
)
714

815
CONN_INFO_ROOT = dict(
9-
host=os.getenv("DJ_HOST"),
10-
user=os.getenv("DJ_USER"),
11-
password=os.getenv("DJ_PASS"),
16+
host=os.environ.get("DJ_HOST", "fakeservices.datajoint.io"),
17+
user=os.environ.get("DJ_USER", "root"),
18+
password=os.environ.get("DJ_PASS", "simple"),
19+
)
20+
21+
S3_CONN_INFO = dict(
22+
endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"),
23+
access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"),
24+
secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"),
25+
bucket=os.environ.get("S3_BUCKET", "datajoint.test"),
1226
)

tests/conftest.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,38 @@
11
import datajoint as dj
22
from packaging import version
33
import os
4+
import minio
5+
import urllib3
6+
import certifi
7+
import shutil
48
import pytest
5-
from . import PREFIX, schema, schema_simple, schema_advanced
9+
import networkx as nx
10+
import json
11+
from pathlib import Path
12+
import tempfile
13+
from datajoint import errors
14+
from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH
15+
from . import (
16+
PREFIX,
17+
CONN_INFO,
18+
S3_CONN_INFO,
19+
schema,
20+
schema_simple,
21+
schema_advanced,
22+
schema_adapted,
23+
)
624

7-
namespace = locals()
25+
26+
@pytest.fixture(scope="session")
27+
def monkeysession():
28+
with pytest.MonkeyPatch.context() as mp:
29+
yield mp
30+
31+
32+
@pytest.fixture(scope="module")
33+
def monkeymodule():
34+
with pytest.MonkeyPatch.context() as mp:
35+
yield mp
836

937

1038
@pytest.fixture(scope="session")
@@ -64,11 +92,12 @@ def connection_test(connection_root):
6492
connection.close()
6593

6694

67-
@pytest.fixture(scope="module")
95+
@pytest.fixture
6896
def schema_any(connection_test):
6997
schema_any = dj.Schema(
70-
PREFIX + "_test1", schema.__dict__, connection=connection_test
98+
PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test
7199
)
100+
assert schema.LOCALS_ANY, "LOCALS_ANY is empty"
72101
schema_any(schema.TTest)
73102
schema_any(schema.TTest2)
74103
schema_any(schema.TTest3)
@@ -109,10 +138,10 @@ def schema_any(connection_test):
109138
schema_any.drop()
110139

111140

112-
@pytest.fixture(scope="module")
141+
@pytest.fixture
113142
def schema_simp(connection_test):
114143
schema = dj.Schema(
115-
PREFIX + "_relational", schema_simple.__dict__, connection=connection_test
144+
PREFIX + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test
116145
)
117146
schema(schema_simple.IJ)
118147
schema(schema_simple.JI)
@@ -136,10 +165,12 @@ def schema_simp(connection_test):
136165
schema.drop()
137166

138167

139-
@pytest.fixture(scope="module")
168+
@pytest.fixture
140169
def schema_adv(connection_test):
141170
schema = dj.Schema(
142-
PREFIX + "_advanced", schema_advanced.__dict__, connection=connection_test
171+
PREFIX + "_advanced",
172+
schema_advanced.LOCALS_ADVANCED,
173+
connection=connection_test,
143174
)
144175
schema(schema_advanced.Person)
145176
schema(schema_advanced.Parent)
@@ -152,3 +183,30 @@ def schema_adv(connection_test):
152183
schema(schema_advanced.GlobalSynapse)
153184
yield schema
154185
schema.drop()
186+
187+
188+
@pytest.fixture
189+
def httpClient():
190+
# Initialize httpClient with relevant timeout.
191+
httpClient = urllib3.PoolManager(
192+
timeout=30,
193+
cert_reqs="CERT_REQUIRED",
194+
ca_certs=certifi.where(),
195+
retries=urllib3.Retry(
196+
total=3, backoff_factor=0.2, status_forcelist=[500, 502, 503, 504]
197+
),
198+
)
199+
yield httpClient
200+
201+
202+
@pytest.fixture
203+
def minioClient():
204+
# Initialize minioClient with an endpoint and access/secret keys.
205+
minioClient = minio.Minio(
206+
S3_CONN_INFO["endpoint"],
207+
access_key=S3_CONN_INFO["access_key"],
208+
secret_key=S3_CONN_INFO["secret_key"],
209+
secure=True,
210+
http_client=httpClient,
211+
)
212+
yield minioClient

tests/schema.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import datajoint as dj
88
import inspect
99

10-
LOCALS_ANY = locals()
11-
1210

1311
class TTest(dj.Lookup):
1412
"""
@@ -33,15 +31,15 @@ class TTest2(dj.Manual):
3331

3432
class TTest3(dj.Manual):
3533
definition = """
36-
key : int
34+
key : int
3735
---
3836
value : varchar(300)
3937
"""
4038

4139

4240
class NullableNumbers(dj.Manual):
4341
definition = """
44-
key : int
42+
key : int
4543
---
4644
fvalue = null : float
4745
dvalue = null : double
@@ -450,3 +448,7 @@ class Longblob(dj.Manual):
450448
---
451449
data: longblob
452450
"""
451+
452+
453+
LOCALS_ANY = {k: v for k, v in locals().items() if inspect.isclass(v)}
454+
__all__ = list(LOCALS_ANY)

tests/schema_adapted.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import datajoint as dj
2+
import inspect
3+
import networkx as nx
4+
import json
5+
from pathlib import Path
6+
import tempfile
7+
8+
9+
class GraphAdapter(dj.AttributeAdapter):
10+
attribute_type = "longblob" # this is how the attribute will be declared
11+
12+
@staticmethod
13+
def get(obj):
14+
# convert edge list into a graph
15+
return nx.Graph(obj)
16+
17+
@staticmethod
18+
def put(obj):
19+
# convert graph object into an edge list
20+
assert isinstance(obj, nx.Graph)
21+
return list(obj.edges)
22+
23+
24+
class LayoutToFilepath(dj.AttributeAdapter):
25+
"""
26+
An adapted data type that saves a graph layout into fixed filepath
27+
"""
28+
29+
attribute_type = "filepath@repo-s3"
30+
31+
@staticmethod
32+
def get(path):
33+
with open(path, "r") as f:
34+
return json.load(f)
35+
36+
@staticmethod
37+
def put(layout):
38+
path = Path(dj.config["stores"]["repo-s3"]["stage"], "layout.json")
39+
with open(str(path), "w") as f:
40+
json.dump(layout, f)
41+
return path
42+
43+
44+
class Connectivity(dj.Manual):
45+
definition = """
46+
connid : int
47+
---
48+
conn_graph = null : <graph>
49+
"""
50+
51+
52+
class Layout(dj.Manual):
53+
definition = """
54+
# stores graph layout
55+
-> Connectivity
56+
---
57+
layout: <layout_to_filepath>
58+
"""
59+
60+
61+
LOCALS_ADAPTED = {k: v for k, v in locals().items() if inspect.isclass(v)}
62+
__all__ = list(LOCALS_ADAPTED)

tests/schema_advanced.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import datajoint as dj
2-
3-
LOCALS_ADVANCED = locals()
2+
import inspect
43

54

65
class Person(dj.Manual):
@@ -135,3 +134,7 @@ class GlobalSynapse(dj.Manual):
135134
-> Cell.proj(pre_slice="slice", pre_cell="cell")
136135
-> Cell.proj(post_slice="slice", post_cell="cell")
137136
"""
137+
138+
139+
LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)}
140+
__all__ = list(LOCALS_ADVANCED)

tests/schema_simple.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import faker
1010
import numpy as np
1111
from datetime import date, timedelta
12-
13-
LOCALS_SIMPLE = locals()
12+
import inspect
1413

1514

1615
class IJ(dj.Lookup):
@@ -237,8 +236,8 @@ class ReservedWord(dj.Manual):
237236
# Test of SQL reserved words
238237
key : int
239238
---
240-
in : varchar(25)
241-
from : varchar(25)
239+
in : varchar(25)
240+
from : varchar(25)
242241
int : int
243242
select : varchar(25)
244243
"""
@@ -260,3 +259,7 @@ class OutfitPiece(dj.Part, dj.Lookup):
260259
piece: varchar(20)
261260
"""
262261
contents = [(0, "jeans"), (0, "sneakers"), (0, "polo")]
262+
263+
264+
LOCALS_SIMPLE = {k: v for k, v in locals().items() if inspect.isclass(v)}
265+
__all__ = list(LOCALS_SIMPLE)

0 commit comments

Comments
 (0)