From 9b6129a2c61d32e67bb84dd14f4da39e3488b868 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 17:26:04 +0200 Subject: [PATCH 1/4] Now tests for unique key constraints (if possible) instead of always actively validating (+ tests) --- data_diff/databases/base.py | 15 ++++++++ data_diff/databases/bigquery.py | 5 ++- data_diff/databases/database_types.py | 11 ++++++ data_diff/databases/mysql.py | 1 + data_diff/databases/oracle.py | 1 + data_diff/databases/postgresql.py | 1 + data_diff/databases/snowflake.py | 5 ++- data_diff/joindiff_tables.py | 16 ++++++--- data_diff/queries/ast_classes.py | 8 +++-- tests/common.py | 1 + tests/test_joindiff.py | 49 +++++++++++++++++++++++++++ 11 files changed, 103 insertions(+), 10 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index f0a96f20..48a551b2 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -110,6 +110,7 @@ class Database(AbstractDatabase): TYPE_CLASSES: Dict[str, type] = {} default_schema: str = None SUPPORTS_ALPHANUMS = True + SUPPORTS_PRIMARY_KEY = False _interactive = False @@ -235,6 +236,20 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: assert len(d) == len(rows) return d + + def select_table_unique_columns(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + "SELECT column_name " + "FROM information_schema.key_column_usage " + f"WHERE table_name = '{table}' AND table_schema = '{schema}'" + ) + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + res = self.query(self.select_table_unique_columns(path), List[str]) + return list(res) + def _process_table_schema( self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None ): diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 9c500dd5..3d3720b6 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Union from .database_types import Timestamp, Datetime, Integer, Decimal, Float, Text, DbPath, FractionalType, TemporalType from .base import Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter @@ -78,6 +78,9 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 1de1d2fc..2a76ae05 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -253,6 +253,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: """ ... + @abstractmethod + def select_table_unique_columns(self, path: DbPath) -> str: + "Provide SQL for selecting the names of unique columns in the table" + ... + + @abstractmethod + def query_table_unique_columns(self, path: DbPath) -> List[str]: + """Query the table for its unique columns for table in 'path', and return {column} + """ + ... + @abstractmethod def _process_table_schema( self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index 3f9eb98c..f7023946 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -39,6 +39,7 @@ class MySQL(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_ALPHANUMS = False + SUPPORTS_PRIMARY_KEY = True def __init__(self, *, thread_count, **kw): self._args = kw diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 80647ba3..d59e5b55 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -38,6 +38,7 @@ class Oracle(ThreadedDatabase): "VARCHAR2": Text, } ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 72d26d07..02920c2b 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -46,6 +46,7 @@ class PostgreSQL(ThreadedDatabase): "uuid": Native_UUID, } ROUNDS_ON_PREC_LOSS = True + SUPPORTS_PRIMARY_KEY = True default_schema = "public" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 635ba8f4..afd52ba8 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, List import logging from .database_types import Timestamp, TimestampTZ, Decimal, Float, Text, FractionalType, TemporalType, DbPath @@ -95,3 +95,6 @@ def is_autocommit(self) -> bool: def explain_as_text(self, query: str) -> str: return f"EXPLAIN USING TEXT {query}" + + def query_table_unique_columns(self, path: DbPath) -> List[str]: + return [] diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index b630d66e..26789bc2 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -196,18 +196,24 @@ def _diff_segments( if not is_xa: yield "+", tuple(b_row) - def _test_duplicate_keys(self, table1, table2): + def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): logger.debug("Testing for duplicate keys") # Test duplicate keys for ts in [table1, table2]: + unique = ts.database.query_table_unique_columns(ts.table_path) + t = ts.make_select() key_columns = ts.key_columns - q = t.select(total=Count(), total_distinct=Count(Concat(this[key_columns]), distinct=True)) - total, total_distinct = ts.database.query(q, tuple) - if total != total_distinct: - raise ValueError("Duplicate primary keys") + unvalidated = list(set(key_columns) - set(unique)) + if unvalidated: + # Validate that there are no duplicate keys + self.stats['validated_unique_keys'] = self.stats.get('validated_unique_keys', []) + [unvalidated] + q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) + total, total_distinct = ts.database.query(q, tuple) + if total != total_distinct: + raise ValueError("Duplicate primary keys") def _test_null_keys(self, table1, table2): logger.debug("Testing for null keys") diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b5456b59..9b7fe63a 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -298,12 +298,12 @@ class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = field(default=None, repr=False) - def create(self, source_table: ITable = None, *, if_not_exists=False): + def create(self, source_table: ITable = None, *, if_not_exists=False, primary_keys=None): if source_table is None and not self.schema: raise ValueError("Either schema or source table needed to create table") if isinstance(source_table, TablePath): source_table = source_table.select() - return CreateTable(self, source_table, if_not_exists=if_not_exists) + return CreateTable(self, source_table, if_not_exists=if_not_exists, primary_keys=primary_keys) def drop(self, if_exists=False): return DropTable(self, if_exists=if_exists) @@ -634,6 +634,7 @@ class CreateTable(Statement): path: TablePath source_table: Expr = None if_not_exists: bool = False + primary_keys: List[str] = None def compile(self, c: Compiler) -> str: ne = "IF NOT EXISTS " if self.if_not_exists else "" @@ -641,7 +642,8 @@ def compile(self, c: Compiler) -> str: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) - return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})" + pks = ", PRIMARY KEY (%s)" % ', '.join(self.primary_keys) if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY else "" + return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" @dataclass diff --git a/tests/common.py b/tests/common.py index aad75074..cd974e34 100644 --- a/tests/common.py +++ b/tests/common.py @@ -149,6 +149,7 @@ def tearDown(self): def _parameterized_class_per_conn(test_databases): + test_databases = set(test_databases) names = [(cls.__name__, cls) for cls in CONN_STRINGS if cls in test_databases] return parameterized_class(("name", "db_cls"), names) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 03ca3d69..557b207f 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -273,3 +273,52 @@ def test_null_pks(self): x = self.differ.diff_tables(self.table, self.table2) self.assertRaises(ValueError, list, x) + + +@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY) +class TestPrimaryKeys(TestPerDatabase): + def setUp(self): + super().setUp() + + self.src_table = table( + self.table_src_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float}, + ) + self.dst_table = table( + self.table_dst_path, + schema={"id": int, "userid": int, "movieid": int, "rating": float}, + ) + + self.connection.query([ + self.src_table.create(primary_keys=['id']), + self.dst_table.create(primary_keys=['id', 'userid']), + commit + ] + ) + + self.differ = JoinDiffer() + + def test_unique_constraint(self): + self.connection.query( + [ + self.src_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]), + self.dst_table.insert_rows([[1, 1, 1, 9], [2, 2, 2, 9]]), + commit, + ] + ) + + # Test no active validation + table = TableSegment(self.connection, self.table_src_path, ("id",), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("id",), case_sensitive=False) + + res = list(self.differ.diff_tables(table, table2)) + assert not res + assert 'validated_unique_keys' not in self.differ.stats + + # Test active validation + table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False) + table2 = TableSegment(self.connection, self.table_dst_path, ("userid",), case_sensitive=False) + + res = list(self.differ.diff_tables(table, table2)) + assert not res + self.assertEqual( self.differ.stats['validated_unique_keys'], [['userid']] ) From 9ad272df0c738a74da94246003332ad65a44ad40 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 14 Oct 2022 17:26:16 +0200 Subject: [PATCH 2/4] black --- data_diff/databases/base.py | 1 - data_diff/databases/database_types.py | 3 +-- data_diff/joindiff_tables.py | 2 +- data_diff/queries/ast_classes.py | 6 +++++- tests/test_joindiff.py | 11 ++++------- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 48a551b2..18192152 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -236,7 +236,6 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]: assert len(d) == len(rows) return d - def select_table_unique_columns(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) diff --git a/data_diff/databases/database_types.py b/data_diff/databases/database_types.py index 2a76ae05..c02cc3e2 100644 --- a/data_diff/databases/database_types.py +++ b/data_diff/databases/database_types.py @@ -260,8 +260,7 @@ def select_table_unique_columns(self, path: DbPath) -> str: @abstractmethod def query_table_unique_columns(self, path: DbPath) -> List[str]: - """Query the table for its unique columns for table in 'path', and return {column} - """ + """Query the table for its unique columns for table in 'path', and return {column}""" ... @abstractmethod diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 26789bc2..863a8450 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -209,7 +209,7 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): unvalidated = list(set(key_columns) - set(unique)) if unvalidated: # Validate that there are no duplicate keys - self.stats['validated_unique_keys'] = self.stats.get('validated_unique_keys', []) + [unvalidated] + self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated] q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True)) total, total_distinct = ts.database.query(q, tuple) if total != total_distinct: diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 9b7fe63a..4b93efdf 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -642,7 +642,11 @@ def compile(self, c: Compiler) -> str: return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items()) - pks = ", PRIMARY KEY (%s)" % ', '.join(self.primary_keys) if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY else "" + pks = ( + ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys) + if self.primary_keys and c.database.SUPPORTS_PRIMARY_KEY + else "" + ) return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 557b207f..35e2c79f 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -289,11 +289,8 @@ def setUp(self): schema={"id": int, "userid": int, "movieid": int, "rating": float}, ) - self.connection.query([ - self.src_table.create(primary_keys=['id']), - self.dst_table.create(primary_keys=['id', 'userid']), - commit - ] + self.connection.query( + [self.src_table.create(primary_keys=["id"]), self.dst_table.create(primary_keys=["id", "userid"]), commit] ) self.differ = JoinDiffer() @@ -313,7 +310,7 @@ def test_unique_constraint(self): res = list(self.differ.diff_tables(table, table2)) assert not res - assert 'validated_unique_keys' not in self.differ.stats + assert "validated_unique_keys" not in self.differ.stats # Test active validation table = TableSegment(self.connection, self.table_src_path, ("userid",), case_sensitive=False) @@ -321,4 +318,4 @@ def test_unique_constraint(self): res = list(self.differ.diff_tables(table, table2)) assert not res - self.assertEqual( self.differ.stats['validated_unique_keys'], [['userid']] ) + self.assertEqual(self.differ.stats["validated_unique_keys"], [["userid"]]) From 56b45da4aafdbb4c8cea033d469bf51abc55ec4a Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Mon, 17 Oct 2022 13:31:25 +0200 Subject: [PATCH 3/4] Added Database.SUPPORTS_UNIQUE_CONSTAINT --- data_diff/databases/base.py | 3 +++ data_diff/databases/mysql.py | 1 + data_diff/databases/oracle.py | 1 + data_diff/databases/postgresql.py | 1 + data_diff/joindiff_tables.py | 2 +- 5 files changed, 7 insertions(+), 1 deletion(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 18192152..7d0ac864 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -111,6 +111,7 @@ class Database(AbstractDatabase): default_schema: str = None SUPPORTS_ALPHANUMS = True SUPPORTS_PRIMARY_KEY = False + SUPPORTS_UNIQUE_CONSTAINT = False _interactive = False @@ -246,6 +247,8 @@ def select_table_unique_columns(self, path: DbPath) -> str: ) def query_table_unique_columns(self, path: DbPath) -> List[str]: + if not self.SUPPORTS_UNIQUE_CONSTAINT: + raise NotImplementedError("This database doesn't support 'unique' constraints") res = self.query(self.select_table_unique_columns(path), List[str]) return list(res) diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index f7023946..e8e47b1b 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -40,6 +40,7 @@ class MySQL(ThreadedDatabase): ROUNDS_ON_PREC_LOSS = True SUPPORTS_ALPHANUMS = False SUPPORTS_PRIMARY_KEY = True + SUPPORTS_UNIQUE_CONSTAINT = True def __init__(self, *, thread_count, **kw): self._args = kw diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index d59e5b55..c135f71f 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -39,6 +39,7 @@ class Oracle(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True + SUPPORTS_UNIQUE_CONSTAINT = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 02920c2b..3181dab1 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -47,6 +47,7 @@ class PostgreSQL(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True + SUPPORTS_UNIQUE_CONSTAINT = True default_schema = "public" diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 863a8450..1107ba87 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -201,7 +201,7 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment): # Test duplicate keys for ts in [table1, table2]: - unique = ts.database.query_table_unique_columns(ts.table_path) + unique = ts.database.query_table_unique_columns(ts.table_path) if ts.database.SUPPORTS_UNIQUE_CONSTAINT else [] t = ts.make_select() key_columns = ts.key_columns From 3e82588d4f10815d612100f683e1f11249766ac7 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Wed, 19 Oct 2022 15:23:56 -0300 Subject: [PATCH 4/4] Fix for Oracle --- data_diff/databases/oracle.py | 1 - tests/test_joindiff.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index c135f71f..d59e5b55 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -39,7 +39,6 @@ class Oracle(ThreadedDatabase): } ROUNDS_ON_PREC_LOSS = True SUPPORTS_PRIMARY_KEY = True - SUPPORTS_UNIQUE_CONSTAINT = True def __init__(self, *, host, database, thread_count, **kw): self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index 35e2c79f..22ed217d 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -275,8 +275,8 @@ def test_null_pks(self): self.assertRaises(ValueError, list, x) -@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY) -class TestPrimaryKeys(TestPerDatabase): +@test_each_database_in_list(d for d in TEST_DATABASES if d.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT) +class TestUniqueConstraint(TestPerDatabase): def setUp(self): super().setUp()