diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 7942b53a..79c46fc7 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,3 +1,4 @@ +import math from typing import Dict, Sequence import logging @@ -13,7 +14,7 @@ ColType, UnknownColType, ) -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, Database, import_helper, parse_table_name +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install databricks-sql-connector'") @@ -61,54 +62,57 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") + value = f"cast({value} as decimal(38, {coltype.precision}))" + if coltype.precision > 0: + value = f"format_number({value}, {coltype.precision})" + return f"replace({self.to_string(value)}, ',', '')" def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 1 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 1, 0) + # Subtracting 2 due to wierd precision issues + return max(super()._convert_db_precision_to_digits(p) - 2, 0) -class Databricks(Database): +class Databricks(ThreadedDatabase): dialect = Dialect() - def __init__( - self, - http_path: str, - access_token: str, - server_hostname: str, - catalog: str = "hive_metastore", - schema: str = "default", - **kwargs, - ): - databricks = import_databricks() - - self._conn = databricks.sql.connect( - server_hostname=server_hostname, http_path=http_path, access_token=access_token - ) - + def __init__(self, *, thread_count, **kw): logging.getLogger("databricks.sql").setLevel(logging.WARNING) - self.catalog = catalog - self.default_schema = schema - self.kwargs = kwargs + self._args = kw + self.default_schema = kw.get("schema", "hive_metastore") + super().__init__(thread_count=thread_count) - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) + def create_connection(self): + databricks = import_databricks() + + try: + return databricks.sql.connect( + server_hostname=self._args["server_hostname"], + http_path=self._args["http_path"], + access_token=self._args["access_token"], + catalog=self._args["catalog"], + ) + except databricks.sql.exc.Error as e: + raise ConnectionError(*e.args) from e def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html # So, to obtain information about schema, we should use another approach. + conn = self.create_connection() + schema, table = self._normalize_table_path(path) - with self._conn.cursor() as cursor: - cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table) - rows = cursor.fetchall() + with conn.cursor() as cursor: + cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table) + try: + rows = cursor.fetchall() + finally: + conn.close() if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r.COLUMN_NAME: r for r in rows} + d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} assert len(d) == len(rows) return d @@ -120,27 +124,26 @@ def _process_table_schema( resulted_rows = [] for row in rows: - row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME - type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType) + row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] + type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) if issubclass(type_cls, Integer): - row = (row.COLUMN_NAME, row_type, None, None, 0) + row = (row[0], row_type, None, None, 0) elif issubclass(type_cls, Float): - numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, None) + numeric_precision = math.ceil(row[2] / math.log(2, 10)) + row = (row[0], row_type, None, numeric_precision, None) elif issubclass(type_cls, Decimal): - # TYPE_NAME has a format DECIMAL(x,y) - items = row.TYPE_NAME[8:].rstrip(")").split(",") + items = row[1][8:].rstrip(")").split(",") numeric_precision, numeric_scale = int(items[0]), int(items[1]) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale) + row = (row[0], row_type, None, numeric_precision, numeric_scale) elif issubclass(type_cls, Timestamp): - row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None) + row = (row[0], row_type, row[2], None, None) else: - row = (row.COLUMN_NAME, row_type, None, None, None) + row = (row[0], row_type, None, None, None) resulted_rows.append(row) @@ -153,9 +156,6 @@ def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) - def close(self): - self._conn.close() - @property def is_autocommit(self) -> bool: return True