Skip to content

Commit 52e715d

Browse files
committed
Adjust connect arguments to enable SSL using SQLAlchemy DB URI
1 parent 064e341 commit 52e715d

File tree

3 files changed

+62
-7
lines changed

3 files changed

+62
-7
lines changed

CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Unreleased
1313

1414
- Adjusted connect arguments to accept credentials within the HTTP URI.
1515

16+
- Adjusted connect arguments to enable SSL using SQLAlchemy DB URI.
17+
1618
2020/09/28 0.26.0
1719
=================
1820

src/crate/client/sqlalchemy/dialect.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sqlalchemy import types as sqltypes
2626
from sqlalchemy.engine import default, reflection
2727
from sqlalchemy.sql import functions
28+
from sqlalchemy.util import asbool, to_list
2829

2930
from .compiler import (
3031
CrateCompiler,
@@ -194,8 +195,16 @@ def connect(self, host=None, port=None, *args, **kwargs):
194195
server = '{0}:{1}'.format(host, port or '4200')
195196
if 'servers' in kwargs:
196197
server = kwargs.pop('servers')
197-
if server:
198-
return self.dbapi.connect(servers=server, **kwargs)
198+
servers = to_list(server)
199+
if servers:
200+
# Evaluate "ssl" connection URI query parameter.
201+
# TODO: Evaluate more parameters like `ssl_ca`, `ssl_key`, `ssl_cert`,
202+
# `ssl_capath` and `ssl_cipher`.
203+
if "ssl" in kwargs:
204+
use_ssl = asbool(kwargs.pop("ssl"))
205+
if use_ssl:
206+
servers = ["https://" + server for server in servers]
207+
return self.dbapi.connect(servers=servers, **kwargs)
199208
return self.dbapi.connect(**kwargs)
200209

201210
def _get_default_schema_name(self, connection):

src/crate/client/sqlalchemy/tests/connection_test.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,59 @@
2121

2222
from unittest import TestCase
2323
import sqlalchemy as sa
24+
from sqlalchemy.exc import NoSuchModuleError
2425

2526

2627
class SqlAlchemyConnectionTest(TestCase):
2728

28-
def setUp(self):
29-
self.engine = sa.create_engine('crate://')
30-
self.connection = self.engine.connect()
29+
def test_connection_server_uri_unknown_sa_plugin(self):
30+
with self.assertRaises(NoSuchModuleError):
31+
sa.create_engine("foobar://otherhost:19201")
3132

3233
def test_default_connection(self):
3334
engine = sa.create_engine('crate://')
3435
conn = engine.raw_connection()
3536
self.assertEqual("<Connection <Client ['http://127.0.0.1:4200']>>",
3637
repr(conn.connection))
3738

38-
def test_connection_server(self):
39+
def test_connection_server_uri_http(self):
3940
engine = sa.create_engine(
4041
"crate://otherhost:19201")
4142
conn = engine.raw_connection()
4243
self.assertEqual("<Connection <Client ['http://otherhost:19201']>>",
4344
repr(conn.connection))
4445

45-
def test_connection_multiple_server(self):
46+
def test_connection_server_uri_https(self):
47+
engine = sa.create_engine(
48+
"crate://otherhost:19201/?ssl=true")
49+
conn = engine.raw_connection()
50+
self.assertEqual("<Connection <Client ['https://otherhost:19201']>>",
51+
repr(conn.connection))
52+
53+
def test_connection_server_uri_invalid_port(self):
54+
with self.assertRaises(ValueError) as context:
55+
sa.create_engine("crate://foo:bar")
56+
self.assertTrue("invalid literal for int() with base 10: 'bar'" in str(context.exception))
57+
58+
def test_connection_server_uri_https_with_trusted_user(self):
59+
engine = sa.create_engine(
60+
"crate://foo@otherhost:19201/?ssl=true")
61+
conn = engine.raw_connection()
62+
self.assertEqual("<Connection <Client ['https://otherhost:19201']>>",
63+
repr(conn.connection))
64+
self.assertEqual(conn.connection.client.username, "foo")
65+
self.assertEqual(conn.connection.client.password, None)
66+
67+
def test_connection_server_uri_https_with_credentials(self):
68+
engine = sa.create_engine(
69+
"crate://foo:bar@otherhost:19201/?ssl=true")
70+
conn = engine.raw_connection()
71+
self.assertEqual("<Connection <Client ['https://otherhost:19201']>>",
72+
repr(conn.connection))
73+
self.assertEqual(conn.connection.client.username, "foo")
74+
self.assertEqual(conn.connection.client.password, "bar")
75+
76+
def test_connection_multiple_server_http(self):
4677
engine = sa.create_engine(
4778
"crate://", connect_args={
4879
'servers': ['localhost:4201', 'localhost:4202']
@@ -53,3 +84,16 @@ def test_connection_multiple_server(self):
5384
"<Connection <Client ['http://localhost:4201', " +
5485
"'http://localhost:4202']>>",
5586
repr(conn.connection))
87+
88+
def test_connection_multiple_server_https(self):
89+
engine = sa.create_engine(
90+
"crate://", connect_args={
91+
'servers': ['localhost:4201', 'localhost:4202'],
92+
'ssl': True,
93+
}
94+
)
95+
conn = engine.raw_connection()
96+
self.assertEqual(
97+
"<Connection <Client ['https://localhost:4201', " +
98+
"'https://localhost:4202']>>",
99+
repr(conn.connection))

0 commit comments

Comments
 (0)