diff --git a/tests/conftest.py b/tests/conftest.py index 5a38eef90..a9474b502 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,6 +56,15 @@ def enable_filepath_feature(monkeypatch): monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) +@pytest.fixture(scope="session") +def db_creds_test() -> Dict: + return dict( + host=os.getenv("DJ_TEST_HOST", "fakeservices.datajoint.io"), + user=os.getenv("DJ_TEST_USER", "datajoint"), + password=os.getenv("DJ_TEST_PASSWORD", "datajoint"), + ) + + @pytest.fixture(scope="session") def db_creds_root() -> Dict: return dict( @@ -142,12 +151,9 @@ def connection_root(connection_root_bare): @pytest.fixture(scope="session") -def connection_test(connection_root): +def connection_test(connection_root, db_creds_test): """Test user database connection.""" database = f"{PREFIX}%%" - credentials = dict( - host=os.getenv("DJ_HOST"), user="datajoint", password="datajoint" - ) permission = "ALL PRIVILEGES" # Create MySQL users @@ -157,14 +163,14 @@ def connection_test(connection_root): # create user if necessary on mysql8 connection_root.query( f""" - CREATE USER IF NOT EXISTS '{credentials["user"]}'@'%%' - IDENTIFIED BY '{credentials["password"]}'; + CREATE USER IF NOT EXISTS '{db_creds_test["user"]}'@'%%' + IDENTIFIED BY '{db_creds_test["password"]}'; """ ) connection_root.query( f""" GRANT {permission} ON `{database}`.* - TO '{credentials["user"]}'@'%%'; + TO '{db_creds_test["user"]}'@'%%'; """ ) else: @@ -173,14 +179,14 @@ def connection_test(connection_root): connection_root.query( f""" GRANT {permission} ON `{database}`.* - TO '{credentials["user"]}'@'%%' - IDENTIFIED BY '{credentials["password"]}'; + TO '{db_creds_test["user"]}'@'%%' + IDENTIFIED BY '{db_creds_test["password"]}'; """ ) - connection = dj.Connection(**credentials) + connection = dj.Connection(**db_creds_test) yield connection - connection_root.query(f"""DROP USER `{credentials["user"]}`""") + connection_root.query(f"""DROP USER `{db_creds_test["user"]}`""") connection.close() diff --git a/tests/test_tls.py b/tests/test_tls.py new file mode 100644 index 000000000..22558af5b --- /dev/null +++ b/tests/test_tls.py @@ -0,0 +1,32 @@ +import pytest +import datajoint as dj +from pymysql.err import OperationalError + + +def test_secure_connection(db_creds_test, connection_test): + result = ( + dj.conn(reset=True, **db_creds_test) + .query("SHOW STATUS LIKE 'Ssl_cipher';") + .fetchone()[1] + ) + assert len(result) > 0 + + +def test_insecure_connection(db_creds_test, connection_test): + result = ( + dj.conn(use_tls=False, reset=True, **db_creds_test) + .query("SHOW STATUS LIKE 'Ssl_cipher';") + .fetchone()[1] + ) + assert result == "" + + +def test_reject_insecure(db_creds_test, connection_test): + with pytest.raises(OperationalError): + dj.conn( + db_creds_test["host"], + user="djssl", + password="djssl", + use_tls=False, + reset=True, + ).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1]