Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import typing as t
from pathlib import Path

import requests
import sqlalchemy as sa
import sqlparse
from boltons.urlutils import URL
from crate import client as crate_client
from cratedb_sqlparse import sqlparse as sqlparse_cratedb
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql.elements import AsBoolean
Expand Down Expand Up @@ -52,9 +54,16 @@ def __init__(self, dburi: str, echo: bool = False, internal: bool = False, jwt:
raise ValueError("Database URI must be specified")
if dburi.startswith("crate://"):
self.dburi = dburi
# Detect native override
self.native = "native=true" in dburi.lower()
self.native_url = dburi.replace("crate://", "https://").split("?")[0]
self.dburi_clean = dburi.replace("?native=True", "").replace("&native=True", "")
else:
address = DatabaseAddress.from_string(dburi)
self.dburi = address.dburi
self.native = False
self.dburi_clean = self.dburi

self.internal = internal
self.jwt = jwt
self.ctx: contextlib.AbstractContextManager
Expand All @@ -63,10 +72,16 @@ def __init__(self, dburi: str, echo: bool = False, internal: bool = False, jwt:
else:
self.ctx = contextlib.nullcontext()
with self.ctx:
self.engine = sa.create_engine(self.dburi, echo=echo)
# TODO: Make that go away.
logger.debug(f"Connecting to CrateDB: {dburi}")
self.connection = self.engine.connect()
if self.native:
self.native_connection = crate_client.connect(
[self.dburi_clean.replace("crate://", "https://")],
verify_ssl_cert=False,
)
logger.debug(f"[Native] Connecting to CrateDB: {self.dburi_clean}")
else:
self.engine = sa.create_engine(self.dburi_clean, echo=echo)
logger.debug(f"[SQLAlchemy] Connecting to CrateDB: {self.dburi_clean}")
self.connection = self.engine.connect()

@staticmethod
def quote_relation_name(ident: str) -> str:
Expand Down Expand Up @@ -137,33 +152,41 @@ def run_sql(
return None

def run_sql_real(self, sql: str, parameters: t.Mapping[str, str] = None, records: bool = False):
"""
Invoke an SQL statement and return results.
"""
results = []
for statement in sqlparse.split(sql):
if self.internal:
statement += self.internal_tag
# FIXME: Persistent self.connection risks leaks & thread-unsafety.
# https://github.com/crate/cratedb-toolkit/pull/81#discussion_r2071499204
with self.ctx:
result = self.connection.execute(sa.text(statement), parameters)
data: t.Any
if result.returns_rows:

if self.native:
# Make a native HTTP POST request
response = requests.post(
self.native_url + "/_sql",
verify=False,
json={"stmt": statement, "args": list(parameters.values()) if parameters else []},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
result = response.json()

if records:
rows = result.mappings().fetchall()
data = [dict(row.items()) for row in rows]
data = [dict(zip(result["cols"], row)) for row in result["rows"]]
else:
data = result.fetchall()
data = result["rows"]
else:
data = None
with self.ctx:
result = self.connection.execute(sa.text(statement), parameters)
if result.returns_rows:
if records:
rows = result.mappings().fetchall()
data = [dict(row.items()) for row in rows]
else:
data = result.fetchall()
else:
data = None

results.append(data)

# Backward-compatibility.
if len(results) == 1:
return results[0]
else:
return results
return results[0] if len(results) == 1 else results

def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise", where: str = ""):
"""
Expand Down
Loading