diff --git a/cratedb_toolkit/util/database.py b/cratedb_toolkit/util/database.py index b26bb7d2..8073d3e6 100644 --- a/cratedb_toolkit/util/database.py +++ b/cratedb_toolkit/util/database.py @@ -8,9 +8,11 @@ import typing as t from pathlib import Path +import requests import sqlalchemy as sa import sqlparse from boltons.urlutils import URL +from crate import client as crate_client from cratedb_sqlparse import sqlparse as sqlparse_cratedb from sqlalchemy.exc import ProgrammingError from sqlalchemy.sql.elements import AsBoolean @@ -52,9 +54,16 @@ def __init__(self, dburi: str, echo: bool = False, internal: bool = False, jwt: raise ValueError("Database URI must be specified") if dburi.startswith("crate://"): self.dburi = dburi + # Detect native override + self.native = "native=true" in dburi.lower() + self.native_url = dburi.replace("crate://", "https://").split("?")[0] + self.dburi_clean = dburi.replace("?native=True", "").replace("&native=True", "") else: address = DatabaseAddress.from_string(dburi) self.dburi = address.dburi + self.native = False + self.dburi_clean = self.dburi + self.internal = internal self.jwt = jwt self.ctx: contextlib.AbstractContextManager @@ -63,10 +72,16 @@ def __init__(self, dburi: str, echo: bool = False, internal: bool = False, jwt: else: self.ctx = contextlib.nullcontext() with self.ctx: - self.engine = sa.create_engine(self.dburi, echo=echo) - # TODO: Make that go away. - logger.debug(f"Connecting to CrateDB: {dburi}") - self.connection = self.engine.connect() + if self.native: + self.native_connection = crate_client.connect( + [self.dburi_clean.replace("crate://", "https://")], + verify_ssl_cert=False, + ) + logger.debug(f"[Native] Connecting to CrateDB: {self.dburi_clean}") + else: + self.engine = sa.create_engine(self.dburi_clean, echo=echo) + logger.debug(f"[SQLAlchemy] Connecting to CrateDB: {self.dburi_clean}") + self.connection = self.engine.connect() @staticmethod def quote_relation_name(ident: str) -> str: @@ -137,33 +152,41 @@ def run_sql( return None def run_sql_real(self, sql: str, parameters: t.Mapping[str, str] = None, records: bool = False): - """ - Invoke an SQL statement and return results. - """ results = [] for statement in sqlparse.split(sql): if self.internal: statement += self.internal_tag - # FIXME: Persistent self.connection risks leaks & thread-unsafety. - # https://github.com/crate/cratedb-toolkit/pull/81#discussion_r2071499204 - with self.ctx: - result = self.connection.execute(sa.text(statement), parameters) - data: t.Any - if result.returns_rows: + + if self.native: + # Make a native HTTP POST request + response = requests.post( + self.native_url + "/_sql", + verify=False, + json={"stmt": statement, "args": list(parameters.values()) if parameters else []}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + result = response.json() + if records: - rows = result.mappings().fetchall() - data = [dict(row.items()) for row in rows] + data = [dict(zip(result["cols"], row)) for row in result["rows"]] else: - data = result.fetchall() + data = result["rows"] else: - data = None + with self.ctx: + result = self.connection.execute(sa.text(statement), parameters) + if result.returns_rows: + if records: + rows = result.mappings().fetchall() + data = [dict(row.items()) for row in rows] + else: + data = result.fetchall() + else: + data = None + results.append(data) - # Backward-compatibility. - if len(results) == 1: - return results[0] - else: - return results + return results[0] if len(results) == 1 else results def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise", where: str = ""): """