From b0c8f606fde9767aeba174799e536075c60c0bcd Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 00:11:05 +0600 Subject: [PATCH 1/4] pass a catalog name to connection to find tables not only in hive_metastore --- data_diff/databases/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 7942b53a..0d993cdf 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -83,7 +83,7 @@ def __init__( databricks = import_databricks() self._conn = databricks.sql.connect( - server_hostname=server_hostname, http_path=http_path, access_token=access_token + server_hostname=server_hostname, http_path=http_path, access_token=access_token, catalog=catalog ) logging.getLogger("databricks.sql").setLevel(logging.WARNING) From b82b3ed1d0f3f8965ab60e6666b67463a9c70478 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 00:11:23 +0600 Subject: [PATCH 2/4] fix schema parsing --- data_diff/databases/databricks.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 0d993cdf..8327f90a 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -108,7 +108,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - d = {r.COLUMN_NAME: r for r in rows} + d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows} assert len(d) == len(rows) return d @@ -120,27 +120,26 @@ def _process_table_schema( resulted_rows = [] for row in rows: - row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME + row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType) if issubclass(type_cls, 
Integer): - row = (row.COLUMN_NAME, row_type, None, None, 0) + row = (row[0], row_type, None, None, 0) elif issubclass(type_cls, Float): - numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, None) + numeric_precision = self._convert_db_precision_to_digits(row[2]) + row = (row[0], row_type, None, numeric_precision, None) elif issubclass(type_cls, Decimal): - # TYPE_NAME has a format DECIMAL(x,y) - items = row.TYPE_NAME[8:].rstrip(")").split(",") + items = row[1][8:].rstrip(")").split(",") numeric_precision, numeric_scale = int(items[0]), int(items[1]) - row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale) + row = (row[0], row_type, None, numeric_precision, numeric_scale) elif issubclass(type_cls, Timestamp): - row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None) + row = (row[0], row_type, row[2], None, None) else: - row = (row.COLUMN_NAME, row_type, None, None, None) + row = (row[0], row_type, None, None, None) resulted_rows.append(row) From 1790a38fd7e8a597098d40a1327f0ece2bc4b754 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 17:04:33 +0600 Subject: [PATCH 3/4] support multithreading for databricks The databricks connector is not thread-safe so we should inherit ThreadedDatabase class --- data_diff/databases/databricks.py | 59 +++++++++++++++---------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 8327f90a..c38704b5 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -13,7 +13,7 @@ ColType, UnknownColType, ) -from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, Database, import_helper, parse_table_name +from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name @import_helper(text="You can install it using 'pip install 
databricks-sql-connector'") @@ -68,43 +68,45 @@ def _convert_db_precision_to_digits(self, p: int) -> int: return max(super()._convert_db_precision_to_digits(p) - 1, 0) -class Databricks(Database): +class Databricks(ThreadedDatabase): dialect = Dialect() - def __init__( - self, - http_path: str, - access_token: str, - server_hostname: str, - catalog: str = "hive_metastore", - schema: str = "default", - **kwargs, - ): - databricks = import_databricks() - - self._conn = databricks.sql.connect( - server_hostname=server_hostname, http_path=http_path, access_token=access_token, catalog=catalog - ) - + def __init__(self, *, thread_count, **kw): logging.getLogger("databricks.sql").setLevel(logging.WARNING) - self.catalog = catalog - self.default_schema = schema - self.kwargs = kwargs + self._args = kw + self.default_schema = kw.get('schema', 'hive_metastore') + super().__init__(thread_count=thread_count) - def _query(self, sql_code: str) -> list: - "Uses the standard SQL cursor interface" - return self._query_conn(self._conn, sql_code) + def create_connection(self): + databricks = import_databricks() + + try: + return databricks.sql.connect( + server_hostname=self._args['server_hostname'], + http_path=self._args['http_path'], + access_token=self._args['access_token'], + catalog=self._args['catalog'], + ) + except databricks.sql.exc.Error as e: + raise ConnectionError(*e.args) from e def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL. # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html # So, to obtain information about schema, we should use another approach. 
+ conn = self.create_connection() + schema, table = self._normalize_table_path(path) - with self._conn.cursor() as cursor: - cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table) - rows = cursor.fetchall() + with conn.cursor() as cursor: + cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table) + try: + rows = cursor.fetchall() + except: + rows = None + finally: + conn.close() if not rows: raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") @@ -121,7 +123,7 @@ def _process_table_schema( resulted_rows = [] for row in rows: row_type = "DECIMAL" if row[1].startswith("DECIMAL") else row[1] - type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType) + type_cls = self.dialect.TYPE_CLASSES.get(row_type, UnknownColType) if issubclass(type_cls, Integer): row = (row[0], row_type, None, None, 0) @@ -152,9 +154,6 @@ def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) return self._normalize_table_path(path) - def close(self): - self._conn.close() - @property def is_autocommit(self) -> bool: return True From 9c93229be75e652f90597ab11ef422815ad84370 Mon Sep 17 00:00:00 2001 From: Ilia Pinchuk Date: Wed, 2 Nov 2022 17:29:42 +0600 Subject: [PATCH 4/4] fix float value precision calculation --- data_diff/databases/databricks.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index c38704b5..79c46fc7 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,3 +1,4 @@ +import math from typing import Dict, Sequence import logging @@ -61,11 +62,14 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')" def normalize_number(self, value: str, coltype: NumericType) -> str: - return self.to_string(f"cast({value} as 
decimal(38, {coltype.precision}))") + value = f"cast({value} as decimal(38, {coltype.precision}))" + if coltype.precision > 0: + value = f"format_number({value}, {coltype.precision})" + return f"replace({self.to_string(value)}, ',', '')" def _convert_db_precision_to_digits(self, p: int) -> int: - # Subtracting 1 due to wierd precision issues - return max(super()._convert_db_precision_to_digits(p) - 1, 0) + # Subtracting 2 due to weird precision issues + return max(super()._convert_db_precision_to_digits(p) - 2, 0) class Databricks(ThreadedDatabase): @@ -75,7 +79,7 @@ def __init__(self, *, thread_count, **kw): logging.getLogger("databricks.sql").setLevel(logging.WARNING) self._args = kw - self.default_schema = kw.get('schema', 'hive_metastore') + self.default_schema = kw.get("schema", "hive_metastore") super().__init__(thread_count=thread_count) def create_connection(self): @@ -83,11 +87,11 @@ def create_connection(self): try: return databricks.sql.connect( - server_hostname=self._args['server_hostname'], - http_path=self._args['http_path'], - access_token=self._args['access_token'], - catalog=self._args['catalog'], - ) + server_hostname=self._args["server_hostname"], + http_path=self._args["http_path"], + access_token=self._args["access_token"], + catalog=self._args["catalog"], + ) except databricks.sql.exc.Error as e: raise ConnectionError(*e.args) from e @@ -100,11 +104,9 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: schema, table = self._normalize_table_path(path) with conn.cursor() as cursor: - cursor.columns(catalog_name=self._args['catalog'], schema_name=schema, table_name=table) + cursor.columns(catalog_name=self._args["catalog"], schema_name=schema, table_name=table) try: rows = cursor.fetchall() - except: - rows = None finally: conn.close() if not rows: @@ -129,7 +131,7 @@ def _process_table_schema( row = (row[0], row_type, None, None, 0) elif issubclass(type_cls, Float): - numeric_precision = 
self._convert_db_precision_to_digits(row[2]) + numeric_precision = math.ceil(row[2] / math.log(2, 10)) row = (row[0], row_type, None, numeric_precision, None) elif issubclass(type_cls, Decimal):