Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import networkx as nx
import json
from pathlib import Path
import tempfile
from datajoint import errors
from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH
from . import (
Expand Down Expand Up @@ -176,23 +175,52 @@ def connection_test(connection_root):


@pytest.fixture(scope="session")
def stores_config():
def stores_config(tmpdir_factory):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain tmpdir_factory for me?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's similar to the builtin tmpdir, except that it's a fixture that will ensure that the temp dirs for different tests have different base directories. Not as important in this case, since the fixture is session-scoped, but if we changed the scope to function this would be important.

https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-factory-fixture

stores_config = {
"raw": dict(protocol="file", location=tempfile.mkdtemp()),
"raw": dict(protocol="file", location=tmpdir_factory.mktemp("raw")),
"repo": dict(
stage=tempfile.mkdtemp(), protocol="file", location=tempfile.mkdtemp()
stage=tmpdir_factory.mktemp("repo"),
protocol="file",
location=tmpdir_factory.mktemp("repo"),
),
"repo-s3": dict(
S3_CONN_INFO, protocol="s3", location="dj/repo", stage=tempfile.mkdtemp()
S3_CONN_INFO,
protocol="s3",
location="dj/repo",
stage=tmpdir_factory.mktemp("repo-s3"),
),
"local": dict(
protocol="file", location=tmpdir_factory.mktemp("local"), subfolding=(1, 1)
),
"local": dict(protocol="file", location=tempfile.mkdtemp(), subfolding=(1, 1)),
"share": dict(
S3_CONN_INFO, protocol="s3", location="dj/store/repo", subfolding=(2, 4)
),
}
return stores_config


@pytest.fixture
def mock_stores(stores_config):
og_stores_config = dj.config.get("stores")
dj.config["stores"] = stores_config
yield
if og_stores_config is None:
del dj.config["stores"]
else:
dj.config["stores"] = og_stores_config


@pytest.fixture
def mock_cache(tmpdir_factory):
og_cache = dj.config.get("cache")
dj.config["cache"] = tmpdir_factory.mktemp("cache")
yield
if og_cache is None:
del dj.config["cache"]
else:
dj.config["cache"] = og_cache


@pytest.fixture
def schema_any(connection_test):
schema_any = dj.Schema(
Expand Down Expand Up @@ -287,15 +315,12 @@ def schema_adv(connection_test):


@pytest.fixture
def schema_ext(connection_test, stores_config, enable_filepath_feature):
def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache):
schema = dj.Schema(
PREFIX + "_extern",
context=schema_external.LOCALS_EXTERNAL,
connection=connection_test,
)
dj.config["stores"] = stores_config
dj.config["cache"] = tempfile.mkdtemp()

schema(schema_external.Simple)
schema(schema_external.SimpleRemote)
schema(schema_external.Seed)
Expand Down
131 changes: 131 additions & 0 deletions tests/test_external.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import numpy as np
from numpy.testing import assert_array_equal
from datajoint.external import ExternalTable
from datajoint.blob import pack, unpack
import datajoint as dj
from .schema_external import SimpleRemote, Simple
import os


def test_external_put(schema_ext, mock_stores, mock_cache):
"""
external storage put and get and remove
"""
ext = ExternalTable(
schema_ext.connection, store="raw", database=schema_ext.database
)
initial_length = len(ext)
input_ = np.random.randn(3, 7, 8)
count = 7
extra = 3
for i in range(count):
hash1 = ext.put(pack(input_))
for i in range(extra):
hash2 = ext.put(pack(np.random.randn(4, 3, 2)))

fetched_hashes = ext.fetch("hash")
assert all(hash in fetched_hashes for hash in (hash1, hash2))
assert len(ext) == initial_length + 1 + extra

output_ = unpack(ext.get(hash1))
assert_array_equal(input_, output_)


class TestLeadingSlash:
def test_s3_leading_slash(self, schema_ext, mock_stores, mock_cache, minio_client):
"""
s3 external storage configured with leading slash
"""
self._leading_slash(schema_ext, index=100, store="share")

def test_file_leading_slash(
self, schema_ext, mock_stores, mock_cache, minio_client
):
"""
File external storage configured with leading slash
"""
self._leading_slash(schema_ext, index=200, store="local")

def _leading_slash(self, schema_ext, index, store):
oldConfig = dj.config["stores"][store]["location"]
value = np.array([1, 2, 3])

id = index
dj.config["stores"][store]["location"] = "leading/slash/test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 1
dj.config["stores"][store]["location"] = "/leading/slash/test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 2
dj.config["stores"][store]["location"] = "leading\\slash\\test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 3
dj.config["stores"][store]["location"] = "f:\\leading\\slash\\test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 4
dj.config["stores"][store]["location"] = "f:\\leading/slash\\test"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 5
dj.config["stores"][store]["location"] = "/"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 6
dj.config["stores"][store]["location"] = "C:\\"
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

id = index + 7
dj.config["stores"][store]["location"] = ""
SimpleRemote.insert([{"simple": id, "item": value}])
assert np.array_equal(
value, (SimpleRemote & "simple={}".format(id)).fetch1("item")
)

dj.config["stores"][store]["location"] = oldConfig


def test_remove_fail(schema_ext, mock_stores, mock_cache, minio_client):
"""
https://github.com/datajoint/datajoint-python/issues/953
"""
assert dj.config["stores"]["local"]["location"]

data = dict(simple=2, item=[1, 2, 3])
Simple.insert1(data)
path1 = dj.config["stores"]["local"]["location"] + "/djtest_extern/4/c/"
currentMode = int(oct(os.stat(path1).st_mode), 8)
os.chmod(path1, 0o40555)
(Simple & "simple=2").delete()
listOfErrors = schema_ext.external["local"].delete(delete_external_files=True)

assert (
len(schema_ext.external["local"] & dict(hash=listOfErrors[0][0])) == 1
), "unexpected number of rows in external table"
# ---------------------CLEAN UP--------------------
os.chmod(path1, currentMode)
listOfErrors = schema_ext.external["local"].delete(delete_external_files=True)