Skip to content

Commit b90e194

Browse files
tomachamotl
authored andcommitted
Workaround for sqlalchemy sslmode error
1 parent 3816479 commit b90e194

File tree

1 file changed

+45
-22
lines changed

1 file changed

+45
-22
lines changed

cratedb_toolkit/util/database.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import typing as t
99
from pathlib import Path
1010

11+
import requests
1112
import sqlalchemy as sa
1213
import sqlparse
1314
from boltons.urlutils import URL
15+
from crate import client as crate_client
1416
from cratedb_sqlparse import sqlparse as sqlparse_cratedb
1517
from sqlalchemy.exc import ProgrammingError
1618
from sqlalchemy.sql.elements import AsBoolean
@@ -52,9 +54,16 @@ def __init__(self, dburi: str, echo: bool = False, internal: bool = False, jwt:
5254
raise ValueError("Database URI must be specified")
5355
if dburi.startswith("crate://"):
5456
self.dburi = dburi
57+
# Detect native override
58+
self.native = "native=true" in dburi.lower()
59+
self.native_url = dburi.replace("crate://", "https://").split("?")[0]
60+
self.dburi_clean = dburi.replace("?native=True", "").replace("&native=True", "")
5561
else:
5662
address = DatabaseAddress.from_string(dburi)
5763
self.dburi = address.dburi
64+
self.native = False
65+
self.dburi_clean = self.dburi
66+
5867
self.internal = internal
5968
self.jwt = jwt
6069
self.ctx: contextlib.AbstractContextManager
@@ -63,10 +72,16 @@ def __init__(self, dburi: str, echo: bool = False, internal: bool = False, jwt:
6372
else:
6473
self.ctx = contextlib.nullcontext()
6574
with self.ctx:
66-
self.engine = sa.create_engine(self.dburi, echo=echo)
67-
# TODO: Make that go away.
68-
logger.debug(f"Connecting to CrateDB: {dburi}")
69-
self.connection = self.engine.connect()
75+
if self.native:
76+
self.native_connection = crate_client.connect(
77+
[self.dburi_clean.replace("crate://", "https://")],
78+
verify_ssl_cert=False,
79+
)
80+
logger.debug(f"[Native] Connecting to CrateDB: {self.dburi_clean}")
81+
else:
82+
self.engine = sa.create_engine(self.dburi_clean, echo=echo)
83+
logger.debug(f"[SQLAlchemy] Connecting to CrateDB: {self.dburi_clean}")
84+
self.connection = self.engine.connect()
7085

7186
@staticmethod
7287
def quote_relation_name(ident: str) -> str:
@@ -137,33 +152,41 @@ def run_sql(
137152
return None
138153

139154
def run_sql_real(self, sql: str, parameters: t.Mapping[str, str] = None, records: bool = False):
140-
"""
141-
Invoke an SQL statement and return results.
142-
"""
143155
results = []
144156
for statement in sqlparse.split(sql):
145157
if self.internal:
146158
statement += self.internal_tag
147-
# FIXME: Persistent self.connection risks leaks & thread-unsafety.
148-
# https://github.com/crate/cratedb-toolkit/pull/81#discussion_r2071499204
149-
with self.ctx:
150-
result = self.connection.execute(sa.text(statement), parameters)
151-
data: t.Any
152-
if result.returns_rows:
159+
160+
if self.native:
161+
# Make a native HTTP POST request
162+
response = requests.post(
163+
self.native_url + "/_sql",
164+
verify=False,
165+
json={"stmt": statement, "args": list(parameters.values()) if parameters else []},
166+
headers={"Content-Type": "application/json"},
167+
)
168+
response.raise_for_status()
169+
result = response.json()
170+
153171
if records:
154-
rows = result.mappings().fetchall()
155-
data = [dict(row.items()) for row in rows]
172+
data = [dict(zip(result["cols"], row)) for row in result["rows"]]
156173
else:
157-
data = result.fetchall()
174+
data = result["rows"]
158175
else:
159-
data = None
176+
with self.ctx:
177+
result = self.connection.execute(sa.text(statement), parameters)
178+
if result.returns_rows:
179+
if records:
180+
rows = result.mappings().fetchall()
181+
data = [dict(row.items()) for row in rows]
182+
else:
183+
data = result.fetchall()
184+
else:
185+
data = None
186+
160187
results.append(data)
161188

162-
# Backward-compatibility.
163-
if len(results) == 1:
164-
return results[0]
165-
else:
166-
return results
189+
return results[0] if len(results) == 1 else results
167190

168191
def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise", where: str = ""):
169192
"""

0 commit comments

Comments
 (0)