Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import networkx as nx
import json
from pathlib import Path
import tempfile
from datajoint import errors
from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH
from . import (
Expand Down Expand Up @@ -176,23 +175,52 @@ def connection_test(connection_root):


@pytest.fixture(scope="session")
def stores_config():
def stores_config(tmpdir_factory):
stores_config = {
"raw": dict(protocol="file", location=tempfile.mkdtemp()),
"raw": dict(protocol="file", location=tmpdir_factory.mktemp("raw")),
"repo": dict(
stage=tempfile.mkdtemp(), protocol="file", location=tempfile.mkdtemp()
stage=tmpdir_factory.mktemp("repo"),
protocol="file",
location=tmpdir_factory.mktemp("repo"),
),
"repo-s3": dict(
S3_CONN_INFO, protocol="s3", location="dj/repo", stage=tempfile.mkdtemp()
S3_CONN_INFO,
protocol="s3",
location="dj/repo",
stage=tmpdir_factory.mktemp("repo-s3"),
),
"local": dict(
protocol="file", location=tmpdir_factory.mktemp("local"), subfolding=(1, 1)
),
"local": dict(protocol="file", location=tempfile.mkdtemp(), subfolding=(1, 1)),
"share": dict(
S3_CONN_INFO, protocol="s3", location="dj/store/repo", subfolding=(2, 4)
),
}
return stores_config


@pytest.fixture
def mock_stores(stores_config):
og_stores_config = dj.config.get("stores")
dj.config["stores"] = stores_config
yield
if og_stores_config is None:
del dj.config["stores"]
else:
dj.config["stores"] = og_stores_config


@pytest.fixture
def mock_cache(tmpdir_factory):
og_cache = dj.config.get("cache")
dj.config["cache"] = tmpdir_factory.mktemp("cache")
yield
if og_cache is None:
del dj.config["cache"]
else:
dj.config["cache"] = og_cache


@pytest.fixture
def schema_any(connection_test):
schema_any = dj.Schema(
Expand Down Expand Up @@ -287,15 +315,12 @@ def schema_adv(connection_test):


@pytest.fixture
def schema_ext(connection_test, stores_config, enable_filepath_feature):
def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache):
schema = dj.Schema(
PREFIX + "_extern",
context=schema_external.LOCALS_EXTERNAL,
connection=connection_test,
)
dj.config["stores"] = stores_config
dj.config["cache"] = tempfile.mkdtemp()

schema(schema_external.Simple)
schema(schema_external.SimpleRemote)
schema(schema_external.Seed)
Expand Down
131 changes: 131 additions & 0 deletions tests/test_external.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import numpy as np
from numpy.testing import assert_array_equal
from datajoint.external import ExternalTable
from datajoint.blob import pack, unpack
import datajoint as dj
from .schema_external import SimpleRemote, Simple
import os


def test_external_put(schema_ext, mock_stores, mock_cache):
"""
external storage put and get and remove
"""
ext = ExternalTable(
schema_ext.connection, store="raw", database=schema_ext.database
)
initial_length = len(ext)
input_ = np.random.randn(3, 7, 8)
count = 7
extra = 3
for i in range(count):
hash1 = ext.put(pack(input_))
for i in range(extra):
hash2 = ext.put(pack(np.random.randn(4, 3, 2)))

fetched_hashes = ext.fetch("hash")
assert all(hash in fetched_hashes for hash in (hash1, hash2))
assert len(ext) == initial_length + 1 + extra

output_ = unpack(ext.get(hash1))
assert_array_equal(input_, output_)


class TestLeadingSlash:
def test_s3_leading_slash(self, schema_ext, mock_stores, mock_cache, minio_client):
"""
s3 external storage configured with leading slash
"""
self._leading_slash(schema_ext, index=100, store="share")

def test_file_leading_slash(
self, schema_ext, mock_stores, mock_cache, minio_client
):
"""
File external storage configured with leading slash
"""
self._leading_slash(schema_ext, index=200, store="local")

def _leading_slash(self, schema_ext, index, store):
oldConfig = dj.config["stores"][store]["location"]
value = np.array([1, 2, 3])

id = index
dj.config["stores"][store]["location"] = "leading/slash/test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 1
dj.config["stores"][store]["location"] = "/leading/slash/test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 2
dj.config["stores"][store]["location"] = "leading\\slash\\test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 3
dj.config["stores"][store]["location"] = "f:\\leading\\slash\\test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 4
dj.config["stores"][store]["location"] = "f:\\leading/slash\\test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 5
dj.config["stores"][store]["location"] = "/"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 6
dj.config["stores"][store]["location"] = "C:\\"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 7
dj.config["stores"][store]["location"] = ""
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

dj.config["stores"][store]["location"] = oldConfig


def test_remove_fail(schema_ext, mock_stores, mock_cache, minio_client):
"""
https://github.com/datajoint/datajoint-python/issues/953
"""
assert dj.config["stores"]["local"]["location"]

data = dict(simple=2, item=[1, 2, 3])
Simple.insert1(data)
path1 = dj.config["stores"]["local"]["location"] + "/djtest_extern/4/c/"
currentMode = int(oct(os.stat(path1).st_mode), 8)
os.chmod(path1, 0o40555)
(Simple & "simple=2").delete()
listOfErrors = schema_ext.external["local"].delete(delete_external_files=True)

assert (
len(schema_ext.external["local"] & dict(hash=listOfErrors[0][0])) == 1
), "unexpected number of rows in external table"
# ---------------------CLEAN UP--------------------
os.chmod(path1, currentMode)
listOfErrors = schema_ext.external["local"].delete(delete_external_files=True)
50 changes: 50 additions & 0 deletions tests/test_external_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from numpy.testing import assert_almost_equal
import datajoint as dj
from . import schema_external


def test_heading(schema_ext, mock_stores):
heading = schema_external.Simple().heading
assert "item" in heading
assert heading["item"].is_external


def test_insert_and_fetch(schema_ext, mock_stores, mock_cache):
original_list = [1, 3, 8]
schema_external.Simple().insert1(dict(simple=1, item=original_list))
# test fetch
q = (schema_external.Simple() & {"simple": 1}).fetch("item")[0]
assert list(q) == original_list
# test fetch1 as a tuple
q = (schema_external.Simple() & {"simple": 1}).fetch1("item")
assert list(q) == original_list
# test fetch1 as a dict
q = (schema_external.Simple() & {"simple": 1}).fetch1()
assert list(q["item"]) == original_list
# test without cache
previous_cache = dj.config["cache"]
dj.config["cache"] = None
q = (schema_external.Simple() & {"simple": 1}).fetch1()
assert list(q["item"]) == original_list
# test with cache
dj.config["cache"] = previous_cache
q = (schema_external.Simple() & {"simple": 1}).fetch1()
assert list(q["item"]) == original_list


def test_populate(schema_ext, mock_stores):
image = schema_external.Image()
image.populate()
remaining, total = image.progress()
assert (
total == len(schema_external.Dimension() * schema_external.Seed())
and remaining == 0
)
for img, neg, dimensions in zip(
*(image * schema_external.Dimension()).fetch("img", "neg", "dimensions")
):
assert list(img.shape) == list(dimensions)
assert_almost_equal(img, -neg)
image.delete()
for external_table in image.external.values():
external_table.delete(display_progress=False, delete_external_files=True)