diff --git a/data_diff/database.py b/data_diff/database.py
index 5e39e8a3..0a4d900e 100644
--- a/data_diff/database.py
+++ b/data_diff/database.py
@@ -1,3 +1,4 @@
+import math
 from functools import lru_cache
 from itertools import zip_longest
 import re
@@ -63,6 +64,7 @@ def import_presto():
 class ConnectError(Exception):
     pass
 
+
 class QueryError(Exception):
     pass
 
@@ -105,6 +107,26 @@ class Datetime(TemporalType):
     pass
 
 
+@dataclass
+class NumericType(ColType):
+    # 'precision' signifies how many fractional digits (after the dot) we want to compare
+    precision: int
+
+
+class Float(NumericType):
+    pass
+
+
+class Decimal(NumericType):
+    pass
+
+
+@dataclass
+class Integer(Decimal):
+    def __post_init__(self):
+        assert self.precision == 0
+
+
 @dataclass
 class UnknownColType(ColType):
     text: str
@@ -162,6 +184,19 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:
 
         Rounded up/down according to coltype.rounds
 
+    - Floats/Decimals are expected in the format
+        "I.P"
+
+        Where I is the integer part of the number (as many digits as necessary),
+        and must be at least one digit (0).
+        P is the fractional digits, the amount of which is specified by
+        coltype.precision. Trailing zeroes are added as padding when necessary.
+
+        Note: This precision is different from the one used by databases. For decimals,
+        it's the same as "numeric_scale", and for floats, which use binary precision,
+        it can be calculated as log10(2**p).
+
+
     """
     ...
@@ -212,23 +247,48 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
     def enable_interactive(self):
         self._interactive = True
 
-    def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        """Convert from binary precision, used by floats, to decimal precision."""
+        # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+        return math.floor(math.log(2**p, 10))
+
+    def _parse_type(
+        self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
+    ) -> ColType:
         """ """
 
         cls = self.DATETIME_TYPES.get(type_repr)
         if cls:
             return cls(
-                precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
+                precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
                 rounds=self.ROUNDS_ON_PREC_LOSS,
             )
 
+        cls = self.NUMERIC_TYPES.get(type_repr)
+        if cls:
+            if issubclass(cls, Integer):
+                # Some DBs have a constant numeric_scale and don't report it, so our
+                # schema query fills in a constant; that constant must be ignored for
+                # integers, whose scale is always 0.
+                return cls(precision=0)
+
+            elif issubclass(cls, Decimal):
+                return cls(precision=numeric_scale)
+
+            assert issubclass(cls, Float)
+            # assert numeric_scale is None
+            return cls(
+                precision=self._convert_db_precision_to_digits(
+                    numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
+                )
+            )
+
         return UnknownColType(type_repr)
 
     def select_table_schema(self, path: DbPath) -> str:
         schema, table = self._normalize_table_path(path)
 
         return (
-            "SELECT column_name, data_type, datetime_precision, numeric_precision FROM information_schema.columns "
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns "
             f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
         )
@@ -250,7 +310,9 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
         elif len(path) == 2:
             return path
 
-        raise ValueError(f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")
+        raise ValueError(
+            f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
+        )
 
     def parse_table_name(self, name: str) -> DbPath:
         return parse_table_name(name)
@@ -295,7 +357,8 @@ def close(self):
 _CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2
 CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1
 
-DEFAULT_PRECISION = 6
+DEFAULT_DATETIME_PRECISION = 6
+DEFAULT_NUMERIC_PRECISION = 24
 
 TIMESTAMP_PRECISION_POS = 20  # len("2022-06-03 12:24:35.") == 20
@@ -307,6 +370,13 @@ class Postgres(ThreadedDatabase):
         "timestamp": Timestamp,
         # "datetime": Datetime,
     }
+    NUMERIC_TYPES = {
+        "double precision": Float,
+        "real": Float,
+        "decimal": Decimal,
+        "integer": Integer,
+        "numeric": Decimal,
+    }
     ROUNDS_ON_PREC_LOSS = True
 
     default_schema = "public"
@@ -316,6 +386,10 @@ def __init__(self, host, port, user, password, *, database, thread_count, **kw):
 
         super().__init__(thread_count=thread_count)
 
+    def _convert_db_precision_to_digits(self, p: int) -> int:
+        # Subtracting 2 due to weird precision issues in Postgres
+        return super()._convert_db_precision_to_digits(p) - 2
+
     def create_connection(self):
         postgres = import_postgres()
         try:
@@ -351,6 +425,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
             timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
             return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
 
+        elif isinstance(coltype, NumericType):
+            value = f"{value}::decimal(38, {coltype.precision})"
+
         return self.to_string(f"{value}")
@@ -362,6 +439,11 @@ class Presto(Database):
         "timestamp": Timestamp,
         # "datetime": Datetime,
     }
+    NUMERIC_TYPES = {
+        "integer": Integer,
+        "real": Float,
+        "double": Float,
+    }
     ROUNDS_ON_PREC_LOSS = True
 
     def __init__(self, host, port, user, password, *, catalog, schema=None, **kw):
@@ -401,6 +483,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
                 f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
             )
 
+        elif isinstance(coltype, NumericType):
+            value = f"cast({value} as decimal(38,{coltype.precision}))"
+
         return self.to_string(value)
 
     def select_table_schema(self, path: DbPath) -> str:
@@ -412,7 +497,6 @@ def select_table_schema(self, path: DbPath) -> str:
         )
 
     def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
-        """ """
         regexps = {
             r"timestamp\((\d)\)": Timestamp,
             r"timestamp\((\d)\) with time zone": TimestampTZ,
@@ -422,8 +506,29 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
             if m:
                 datetime_precision = int(m.group(1))
                 return cls(
-                    precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, rounds=False
+                    precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
+                    rounds=False,
+                )
+
+        regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
+        for regexp, cls in regexps.items():
+            m = re.match(regexp + "$", type_repr)
+            if m:
+                prec, scale = map(int, m.groups())
+                return cls(scale)
+
+        cls = self.NUMERIC_TYPES.get(type_repr)
+        if cls:
+            if issubclass(cls, Integer):
+                assert numeric_precision is not None
+                return cls(0)
+
+            assert issubclass(cls, Float)
+            return cls(
+                precision=self._convert_db_precision_to_digits(
+                    numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
                 )
+            )
 
         return UnknownColType(type_repr)
@@ -433,6 +538,12 @@ class MySQL(ThreadedDatabase):
         "datetime": Datetime,
         "timestamp": Timestamp,
     }
+    NUMERIC_TYPES = {
+        "double": Float,
+        "float": Float,
+        "decimal": Decimal,
+        "int": Integer,
+    }
     ROUNDS_ON_PREC_LOSS = True
 
     def __init__(self, host, port, user, password, *, database, thread_count, **kw):
@@ -472,6 +583,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
             s = self.to_string(f"cast({value} as datetime(6))")
             return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
 
+        elif isinstance(coltype, NumericType):
+            value = f"cast({value} as decimal(38,{coltype.precision}))"
+
         return self.to_string(f"{value}")
@@ -513,16 +627,24 @@ def select_table_schema(self, path: DbPath) -> str:
         (table,) = path
 
         return (
-            f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision"
+            f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale"
             f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'"
         )
 
     def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
-        if isinstance(coltype, PrecisionType):
+        if isinstance(coltype, TemporalType):
             return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
+        elif isinstance(coltype, NumericType):
+            # e.g. "FM999.9990" for precision=4 (integer part shortened for clarity)
+            format_str = "FM" + "9" * (38 - coltype.precision)
+            if coltype.precision:
+                format_str += "0." + "9" * (coltype.precision - 1) + "0"
+            return f"to_char({value}, '{format_str}')"
 
         return self.to_string(f"{value}")
 
-    def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
+    def _parse_type(
+        self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
+    ) -> ColType:
         """ """
         regexps = {
             r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
@@ -532,14 +654,40 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
             m = re.match(regexp + "$", type_repr)
             if m:
                 datetime_precision = int(m.group(1))
-                return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
-                    rounds=self.ROUNDS_ON_PREC_LOSS
+                return cls(
+                    precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
+                    rounds=self.ROUNDS_ON_PREC_LOSS,
                 )
 
+        cls = {
+            "NUMBER": Decimal,
+            "FLOAT": Float,
+        }.get(type_repr, None)
+        if cls:
+            if issubclass(cls, Decimal):
+                assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale)
+                return cls(precision=numeric_scale)
+
+            assert issubclass(cls, Float)
+            return cls(
+                precision=self._convert_db_precision_to_digits(
+                    numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
+                )
+            )
+
         return UnknownColType(type_repr)
 
 
 class Redshift(Postgres):
+    NUMERIC_TYPES = {
+        **Postgres.NUMERIC_TYPES,
+        "double": Float,
+        "real": Float,
+    }
+
+    # def _convert_db_precision_to_digits(self, p: int) -> int:
+    #     return super()._convert_db_precision_to_digits(p // 2)
+
     def md5_to_int(self, s: str) -> str:
         return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
@@ -560,6 +708,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
             timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
             return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
 
+        elif isinstance(coltype, NumericType):
+            value = f"{value}::decimal(38,{coltype.precision})"
+
         return self.to_string(f"{value}")
@@ -595,6 +746,14 @@ class BigQuery(Database):
         "TIMESTAMP": Timestamp,
         "DATETIME": Datetime,
     }
+    NUMERIC_TYPES = {
+        "INT64": Integer,
+        "INT32": Integer,
+        "NUMERIC": Decimal,
+        "BIGNUMERIC": Decimal,
+        "FLOAT64": Float,
+        "FLOAT32": Float,
+    }
     ROUNDS_ON_PREC_LOSS = False  # Technically BigQuery doesn't allow implicit rounding or truncation
 
     def __init__(self, project, *, dataset, **kw):
@@ -640,12 +799,12 @@ def select_table_schema(self, path: DbPath) -> str:
         schema, table = self._normalize_table_path(path)
 
         return (
-            f"SELECT column_name, data_type, 6 as datetime_precision, 6 as numeric_precision FROM {schema}.INFORMATION_SCHEMA.COLUMNS "
+            f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS "
             f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
         )
 
     def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
-        if isinstance(coltype, PrecisionType):
+        if isinstance(coltype, TemporalType):
             if coltype.rounds:
                 timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
                 return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"
@@ -657,6 +816,12 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
             timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
             return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
 
+        elif isinstance(coltype, Integer):
+            pass
+
+        elif isinstance(coltype, NumericType):
+            # value = f"cast({value} as decimal)"
+            return f"format('%.{coltype.precision}f', ({value}))"
 
         return self.to_string(f"{value}")
@@ -671,6 +836,10 @@ class Snowflake(Database):
         "TIMESTAMP_LTZ": Timestamp,
         "TIMESTAMP_TZ": TimestampTZ,
     }
+    NUMERIC_TYPES = {
+        "NUMBER": Decimal,
+        "FLOAT": Float,
+    }
     ROUNDS_ON_PREC_LOSS = False
 
     def __init__(
@@ -729,7 +898,7 @@ def select_table_schema(self, path: DbPath) -> str:
         return super().select_table_schema((schema, table))
 
     def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
-        if isinstance(coltype, PrecisionType):
+        if isinstance(coltype, TemporalType):
             if coltype.rounds:
                 timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
             else:
@@ -737,6 +906,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
 
             return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
 
+        elif isinstance(coltype, NumericType):
+            value = f"cast({value} as decimal(38, {coltype.precision}))"
+
         return self.to_string(f"{value}")
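A note on the conversion logic above: `_convert_db_precision_to_digits` turns a float's binary (mantissa) precision into the number of decimal digits that can be compared reliably, and `normalize_value_by_type` renders every numeric as an "I.P" string. The following minimal sketch is illustrative only and is not part of the patch; the helper names simply mirror the new methods:

```python
import math

def convert_db_precision_to_digits(p: int) -> int:
    # floor(log10(2**p)): how many decimal digits can be represented
    # reliably with p bits of binary precision.
    return math.floor(math.log(2**p, 10))

# IEEE 754: single precision has a 24-bit significand, double has 53.
assert convert_db_precision_to_digits(24) == 7   # DEFAULT_NUMERIC_PRECISION=24 -> 7 digits
assert convert_db_precision_to_digits(53) == 15

# The "I.P" format from the normalize_value_by_type docstring: at least one
# integer digit, exactly `precision` fractional digits, zero-padded on the right.
def as_i_dot_p(value: float, precision: int) -> str:
    return f"{value:.{precision}f}"

assert as_i_dot_p(1 / 3, 5) == "0.33333"
assert as_i_dot_p(0.1, 5) == "0.10000"  # trailing zeroes added as padding
```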
"datetime": Datetime, "timestamp": Timestamp, } + NUMERIC_TYPES = { + "double": Float, + "float": Float, + "decimal": Decimal, + "int": Integer, + } ROUNDS_ON_PREC_LOSS = True def __init__(self, host, port, user, password, *, database, thread_count, **kw): @@ -472,6 +583,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: s = self.to_string(f"cast({value} as datetime(6))") return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, NumericType): + value = f"cast({value} as decimal(38,{coltype.precision}))" + return self.to_string(f"{value}") @@ -513,16 +627,24 @@ def select_table_schema(self, path: DbPath) -> str: (table,) = path return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision" + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" ) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): + if isinstance(coltype, TemporalType): return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + elif isinstance(coltype, NumericType): + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." + "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" return self.to_string(f"{value}") - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _parse_type( + self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None + ) -> ColType: """ """ regexps = { r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, @@ -532,14 +654,40 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr m = re.match(regexp + "$", type_repr) if m: datetime_precision = int(m.group(1)) - return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS, ) + cls = { + "NUMBER": Decimal, + "FLOAT": Float, + }.get(type_repr, None) + if cls: + if issubclass(cls, Decimal): + assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) + return cls(precision=numeric_scale) + + assert issubclass(cls, Float) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + return UnknownColType(type_repr) class Redshift(Postgres): + NUMERIC_TYPES = { + **Postgres.NUMERIC_TYPES, + "double": Float, + "real": Float, + } + + # def _convert_db_precision_to_digits(self, p: int) -> int: + # return super()._convert_db_precision_to_digits(p // 2) + def md5_to_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" @@ -560,6 +708,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, NumericType): + value = 
f"{value}::decimal(38,{coltype.precision})" + return self.to_string(f"{value}") @@ -595,6 +746,14 @@ class BigQuery(Database): "TIMESTAMP": Timestamp, "DATETIME": Datetime, } + NUMERIC_TYPES = { + "INT64": Integer, + "INT32": Integer, + "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + } ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation def __init__(self, project, *, dataset, **kw): @@ -640,12 +799,12 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - f"SELECT column_name, data_type, 6 as datetime_precision, 6 as numeric_precision FROM {schema}.INFORMATION_SCHEMA.COLUMNS " + f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): + if isinstance(coltype, TemporalType): if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" @@ -657,6 +816,12 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + elif isinstance(coltype, Integer): + pass + + elif isinstance(coltype, NumericType): + # value = f"cast({value} as decimal)" + return f"format('%.{coltype.precision}f', ({value}))" return self.to_string(f"{value}") @@ -671,6 +836,10 @@ class Snowflake(Database): "TIMESTAMP_LTZ": Timestamp, "TIMESTAMP_TZ": TimestampTZ, } + NUMERIC_TYPES = { + "NUMBER": Decimal, + "FLOAT": Float, + } ROUNDS_ON_PREC_LOSS = False def __init__( @@ -729,7 +898,7 @@ def select_table_schema(self, path: DbPath) -> str: return super().select_table_schema((schema, table)) def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): + if isinstance(coltype, TemporalType): if coltype.rounds: timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" else: @@ -737,6 +906,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + elif isinstance(coltype, NumericType): + value = f"cast({value} as decimal(38, {coltype.precision}))" + return self.to_string(f"{value}") diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 4087b49d..565d51a2 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -12,7 +12,7 @@ from runtype import dataclass from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max -from .database import Database, PrecisionType, ColType +from .database import Database, NumericType, PrecisionType, ColType logger = logging.getLogger("diff_tables") @@ -147,10 +147,10 @@ def with_schema(self) -> "TableSegment": schema = Schema_CaseSensitive(schema) else: if len({k.lower() for k in schema}) < len(schema): - logger.warn( + logger.warning( f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(list(schema))}' ) - logger.warn("We recommend to disable case-insensitivity (remove 
--any-case).") + logger.warning("We recommend to disable case-insensitivity (remove --any-case).") schema = Schema_CaseInsensitive(schema) return self.new(_schema=schema) @@ -241,7 +241,7 @@ def count_and_checksum(self) -> Tuple[int, int]: ) duration = time.time() - start if duration > RECOMMENDED_CHECKSUM_DURATION: - logger.warn( + logger.warning( f"Checksum is taking longer than expected ({duration:.2f}s). " "We recommend increasing --bisection-factor or decreasing --threads." ) @@ -364,11 +364,23 @@ def _validate_and_adjust_columns(self, table1, table2): lowest = min(col1, col2, key=attrgetter("precision")) if col1.precision != col2.precision: - logger.warn(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + elif isinstance(col1, NumericType): + if not isinstance(col2, NumericType): + raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + + table1._schema[c] = col1.replace(precision=lowest.precision) + table2._schema[c] = col2.replace(precision=lowest.precision) + def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded @@ -412,7 +424,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: - logger.warn( + logger.warning( "Uneven distribution of keys detected. (big gaps in the key column). " "For better performance, we recommend to increase the bisection-threshold." ) diff --git a/tests/common.py b/tests/common.py index 33281861..32c4c30b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -19,8 +19,12 @@ except ImportError: pass # No local settings +if TEST_BIGQUERY_CONN_STRING and TEST_SNOWFLAKE_CONN_STRING: + # TODO Fix this. Seems to have something to do with pyarrow + raise RuntimeError("Using BigQuery at the same time as Snowflake causes an error!!") + CONN_STRINGS = { - # db.BigQuery: TEST_BIGQUERY_CONN_STRING, # TODO BigQuery before/after Snowflake causes an error! 
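To make the new `NumericType` branch above concrete: when the two tables report different precisions for the same column, both schemas are adjusted down to the lower precision, so that both sides emit comparable "I.P" strings. A simplified sketch of that adjustment, with a hypothetical `Num` dataclass standing in for the real `ColType` hierarchy:

```python
from dataclasses import dataclass, replace
from operator import attrgetter
from typing import Tuple

@dataclass
class Num:
    precision: int  # fractional digits to compare, as in NumericType

def adjust_columns(col1: Num, col2: Num) -> Tuple[Num, Num]:
    # Same idea as _validate_and_adjust_columns: compare at the lower precision.
    lowest = min(col1, col2, key=attrgetter("precision"))
    return (
        replace(col1, precision=lowest.precision),
        replace(col2, precision=lowest.precision),
    )

# e.g. a double (~15 digits) diffed against numeric(6,3): compare at 3 digits.
a, b = adjust_columns(Num(15), Num(3))
assert (a.precision, b.precision) == (3, 3)
```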
diff --git a/tests/common.py b/tests/common.py
index 33281861..32c4c30b 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -19,8 +19,12 @@
 except ImportError:
     pass  # No local settings
 
+if TEST_BIGQUERY_CONN_STRING and TEST_SNOWFLAKE_CONN_STRING:
+    # TODO Fix this. Seems to have something to do with pyarrow
+    raise RuntimeError("Using BigQuery at the same time as Snowflake causes an error!!")
+
 CONN_STRINGS = {
-    # db.BigQuery: TEST_BIGQUERY_CONN_STRING,  # TODO BigQuery before/after Snowflake causes an error!
+    db.BigQuery: TEST_BIGQUERY_CONN_STRING,
     db.MySQL: TEST_MYSQL_CONN_STRING,
     db.Postgres: TEST_POSTGRES_CONN_STRING,
     db.Snowflake: TEST_SNOWFLAKE_CONN_STRING,
diff --git a/tests/test_api.py b/tests/test_api.py
index cd5b9c19..2a532edd 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -15,6 +15,7 @@ def setUpClass(cls):
         cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
 
     def setUp(self) -> None:
+        self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
         self.preql(
             r"""
             table test_api {
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index 6b95d310..1e5602f9 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -1,13 +1,20 @@
 from contextlib import suppress
 import unittest
 import time
+import logging
+from decimal import Decimal
+
+from parameterized import parameterized, parameterized_class
+import preql
+
 from data_diff import database as db
 from data_diff.diff_tables import TableDiffer, TableSegment
 from parameterized import parameterized, parameterized_class
 
 from .common import CONN_STRINGS
 import logging
 
-logging.getLogger("diff_tables").setLevel(logging.WARN)
+
+logging.getLogger("diff_tables").setLevel(logging.ERROR)
 logging.getLogger("database").setLevel(logging.WARN)
 
 CONNS = {k: db.connect_to_uri(v) for k, v in CONN_STRINGS.items()}
@@ -24,7 +31,24 @@
         "2022-05-01 15:10:03.003030",
         "2022-06-01 15:10:05.009900",
     ],
-    "float": [0.0, 0.1, 0.10, 10.0, 100.98],
+    "float": [
+        0.0,
+        0.1,
+        0.00188,
+        0.99999,
+        0.091919,
+        0.10,
+        10.0,
+        100.98,
+        0.001201923076923077,
+        1 / 3,
+        1 / 5,
+        1 / 109,
+        1 / 109489,
+        1 / 1094893892389,
+        1 / 10948938923893289,
+        3.141592653589793,
+    ],
 }
 
 DATABASE_TYPES = {
@@ -43,9 +67,10 @@
         ],
         # https://www.postgresql.org/docs/current/datatype-numeric.html
         "float": [
-            # "real",
-            # "double precision",
-            # "numeric(6,3)",
+            "real",
+            "float",
+            "double precision",
+            "numeric(6,3)",
         ],
     },
     db.MySQL: {
@@ -58,12 +83,19 @@
             # "bigint",  # 8 bytes
         ],
         # https://dev.mysql.com/doc/refman/8.0/en/datetime.html
-        "datetime_no_timezone": ["timestamp(6)", "timestamp(3)", "timestamp(0)", "timestamp", "datetime(6)"],
+        "datetime_no_timezone": [
+            "timestamp(6)",
+            "timestamp(3)",
+            "timestamp(0)",
+            "timestamp",
+            "datetime(6)",
+        ],
         # https://dev.mysql.com/doc/refman/8.0/en/numeric-types.html
         "float": [
-            # "float",
-            # "double",
-            # "numeric",
+            "float",
+            "double",
+            "numeric",
+            "numeric(65, 10)",
         ],
     },
     db.BigQuery: {
@@ -71,6 +103,11 @@
             "timestamp",
             # "datetime",
         ],
+        "float": [
+            "numeric",
+            "float64",
+            "bignumeric",
+        ],
     },
     db.Snowflake: {
         # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint
@@ -92,8 +129,8 @@
         ],
         # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric
         "float": [
-            # "float"
-            # "numeric",
+            "float",
+            "numeric",
         ],
     },
     db.Redshift: {
@@ -105,9 +142,9 @@
         ],
         # https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types
         "float": [
-            # "float4",
-            # "float8",
-            # "numeric",
+            "float4",
+            "float8",
+            "numeric",
         ],
     },
     db.Oracle: {
@@ -120,8 +157,8 @@
             "timestamp(9) with local time zone",
         ],
         "float": [
-            # "float",
-            # "numeric",
+            "float",
+            "numeric",
        ],
    },
    db.Presto: {
@@ -132,11 +169,18 @@
             # "int",  # 4 bytes
             # "bigint",  # 8 bytes
         ],
-        "datetime_no_timezone": ["timestamp(6)", "timestamp(3)", "timestamp(0)", "timestamp", "datetime(6)"],
+        "datetime_no_timezone": [
+            "timestamp(6)",
+            "timestamp(3)",
+            "timestamp(0)",
+            "timestamp",
+            "datetime(6)",
+        ],
         "float": [
-            # "float",
-            # "double",
-            # "numeric",
+            "real",
+            "double",
+            "decimal(10,2)",
+            "decimal(30,6)",
         ],
     },
 }
@@ -150,7 +194,10 @@
     # target_type: (int, bigint)
 }
 for source_db, source_type_categories in DATABASE_TYPES.items():
     for target_db, target_type_categories in DATABASE_TYPES.items():
-        for type_category, source_types in source_type_categories.items():  # int, datetime, ..
+        for (
+            type_category,
+            source_types,
+        ) in source_type_categories.items():  # int, datetime, ..
             for source_type in source_types:
                 for target_type in target_type_categories[type_category]:
                     if CONNS.get(source_db, False) and CONNS.get(target_db, False):
@@ -184,25 +231,38 @@ def _insert_to_table(conn, table, values):
     if isinstance(conn, db.Oracle):
         selects = []
         for j, sample in values:
-            selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" )
-        insertion_query += ' UNION ALL '.join(selects)
+            if isinstance(sample, (float, Decimal, int)):
+                value = str(sample)
+            else:
+                value = f"timestamp '{sample}'"
+            selects.append(f"SELECT {j}, {value} FROM dual")
+        insertion_query += " UNION ALL ".join(selects)
     else:
-        insertion_query += ' VALUES '
+        insertion_query += " VALUES "
         for j, sample in values:
-            insertion_query += f"({j}, '{sample}'),"
+            if isinstance(sample, (float, Decimal)):
+                value = str(sample)
+            else:
+                value = f"'{sample}'"
+            insertion_query += f"({j}, {value}),"
+
         insertion_query = insertion_query[0:-1]
 
     conn.query(insertion_query, None)
+
     if not isinstance(conn, db.BigQuery):
         conn.query("COMMIT", None)
 
+
 def _drop_table_if_exists(conn, table):
     with suppress(db.QueryError):
         if isinstance(conn, db.Oracle):
             conn.query(f"DROP TABLE {table}", None)
+            conn.query(f"DROP TABLE {table}", None)
         else:
             conn.query(f"DROP TABLE IF EXISTS {table}", None)
 
+
 class TestDiffCrossDatabaseTables(unittest.TestCase):
     @parameterized.expand(type_pairs, name_func=expand_params)
     def test_types(self, source_db, target_db, source_type, target_type, type_category):
@@ -214,8 +274,12 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
         self.connections = [self.src_conn, self.dst_conn]
         sample_values = TYPE_SAMPLES[type_category]
 
-        src_table_path = src_conn.parse_table_name("src")
-        dst_table_path = dst_conn.parse_table_name("dst")
+        # Table name length limit in MySQL is 64 characters
+        src_table_name = f"src_{self._testMethodName[:60]}"
+        dst_table_name = f"dst_{self._testMethodName[:60]}"
+
+        src_table_path = src_conn.parse_table_name(src_table_name)
+        dst_table_path = dst_conn.parse_table_name(dst_table_name)
 
         src_table = src_conn.quote(".".join(src_table_path))
         dst_table = dst_conn.quote(".".join(dst_table_path))
@@ -250,4 +314,3 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
 
         duration = time.time() - start
         # print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms")
-
diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py
index a457081d..84bb6b9a 100644
--- a/tests/test_diff_tables.py
+++ b/tests/test_diff_tables.py
@@ -33,8 +33,24 @@ def tearDownClass(cls):
         cls.connection.close()
 
+    # Fallback for test runners that don't support setUpClass/tearDownClass
+    def setUp(self) -> None:
+        if not hasattr(self, 'connection'):
+            self.setUpClass.__func__(self)
+            self.private_connection = True
+
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        if hasattr(self, 'private_connection'):
+            self.tearDownClass.__func__(self)
+
+        return super().tearDown()
+
+
 class TestDates(TestWithConnection):
     def setUp(self):
+        super().setUp()
         self.connection.query("DROP TABLE IF EXISTS a", None)
         self.connection.query("DROP TABLE IF EXISTS b", None)
         self.preql(
@@ -110,6 +126,7 @@ def test_offset(self):
 
 class TestDiffTables(TestWithConnection):
     def setUp(self):
+        super().setUp()
         self.connection.query("DROP TABLE IF EXISTS ratings_test", None)
         self.connection.query("DROP TABLE IF EXISTS ratings_test2", None)
         self.preql.load("./tests/setup.pql")
@@ -221,9 +238,9 @@ def test_diff_sorted_by_key(self):
 
 class TestTableSegment(TestWithConnection):
     def setUp(self) -> None:
+        super().setUp()
         self.table = TableSegment(self.connection, ("ratings_test",), "id", "timestamp")
         self.table2 = TableSegment(self.connection, ("ratings_test2",), "id", "timestamp")
-        return super().setUp()
 
     def test_table_segment(self):
         early = datetime.datetime(2021, 1, 1, 0, 0)
diff --git a/tests/test_normalize_fields.py b/tests/test_normalize_fields.py
index 2953f8ad..468d2667 100644
--- a/tests/test_normalize_fields.py
+++ b/tests/test_normalize_fields.py
@@ -5,7 +5,7 @@
 
 import preql
 
-from data_diff.database import BigQuery, MySQL, Snowflake, connect_to_uri, Oracle, DEFAULT_PRECISION
+from data_diff.database import BigQuery, MySQL, Snowflake, connect_to_uri, Oracle
 from data_diff.sql import Select
 from data_diff import database as db
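Finally, a worked example for the Oracle branch of `normalize_value_by_type` in the patch above: the `to_char` format string is built from Oracle's 'FM', '9' and '0' format elements so that output always has a leading integer digit and exactly `coltype.precision` fractional digits. A standalone sketch of that construction; the expected `to_char` output noted in the comment is an assumption based on Oracle's documented format semantics:

```python
def oracle_number_format(precision: int) -> str:
    # "FM" suppresses blank-padding; "9" digits are optional while "0" digits
    # are forced, so the result keeps at least one digit before the dot and
    # exactly `precision` (zero-padded) digits after it.
    format_str = "FM" + "9" * (38 - precision)
    if precision:
        format_str += "0." + "9" * (precision - 1) + "0"
    return format_str

assert oracle_number_format(3) == "FM" + "9" * 35 + "0.990"
# e.g. to_char(2.5, oracle_number_format(3)) should yield '2.500' in Oracle,
# matching the "I.P" normalization format used by the other databases.
```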