From e973e04c29472d85921be626cb0b3c7f5c84359b Mon Sep 17 00:00:00 2001 From: Simon Eskildsen Date: Thu, 16 Jun 2022 20:05:41 -0400 Subject: [PATCH] tests: database_types dual-use for benchmarks presto: fix test suite diff_tables: fix not double recursing tests: fix presto create indexes, etc., haven't tested all dbs more --- data_diff/databases/presto.py | 10 +- data_diff/diff_tables.py | 7 + data_diff/sql.py | 5 +- .../standalone/catalog/postgresql.properties | 1 + docker-compose.yml | 16 +- tests/common.py | 15 +- tests/test_database_types.py | 654 ++++++++++++++---- 7 files changed, 578 insertions(+), 130 deletions(-) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index a0f75010..ae86121e 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -50,7 +50,13 @@ def to_string(self, s: str): def _query(self, sql_code: str) -> list: "Uses the standard SQL cursor interface" - return _query_conn(self._conn, sql_code) + c = self._conn.cursor() + c.execute(sql_code) + if sql_code.lower().startswith("select"): + return c.fetchall() + # Required for the query to actually run 🤯 + if re.match(r"(insert|create|truncate|drop)", sql_code, re.IGNORECASE): + return c.fetchone() def close(self): self._conn.close() @@ -88,7 +94,7 @@ def _parse_type( datetime_precision = int(m.group(1)) return cls( precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, - rounds=False, + rounds=self.ROUNDS_ON_PREC_LOSS, ) number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal} diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 52f596f9..94f45710 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -403,6 +403,13 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun f"size: {table2.max_key-table1.min_key}" ) + # The entire segment wasn't below the threshold, but the next set of + # segments might be. In that case, it's useless to checksum them. 
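To make the short-circuit concrete, here is a minimal sketch of the arithmetic the added lines below perform; the key values are hypothetical and only the comparison against bisection_threshold matters:

# Hypothetical segment boundaries, for illustration only.
bisection_threshold = 10_000
table1_min_key, table1_max_key = 40_000, 43_000
table2_min_key, table2_max_key = 40_000, 42_500

# A segment spanning at most 3,000 ids can never hold more than 3,000 rows, so
# checksumming it first would only add a round-trip; bisect (and ultimately
# download) it directly instead.
max_rows_from_keys = max(table1_max_key - table1_min_key,
                         table2_max_key - table2_min_key)
assert max_rows_from_keys < bisection_threshold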
+ max_rows_from_keys = max(table1.max_key - table1.min_key, table2.max_key - table2.min_key) + if max_rows_from_keys < self.bisection_threshold: + yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max_rows_from_keys) + return + (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: diff --git a/data_diff/sql.py b/data_diff/sql.py index 6946c26f..840c96e0 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -115,7 +115,10 @@ class Checksum(Sql): def compile(self, c: Compiler): compiled_exprs = ", ".join(map(c.compile, self.exprs)) - expr = f"concat({compiled_exprs})" + expr = compiled_exprs + if len(self.exprs) > 1: + expr = f"concat({compiled_exprs})" + md5 = c.database.md5_to_int(expr) return f"sum({md5})" diff --git a/dev/presto-conf/standalone/catalog/postgresql.properties b/dev/presto-conf/standalone/catalog/postgresql.properties index b6bc3006..ef479967 100644 --- a/dev/presto-conf/standalone/catalog/postgresql.properties +++ b/dev/presto-conf/standalone/catalog/postgresql.properties @@ -2,3 +2,4 @@ connector.name=postgresql connection-url=jdbc:postgresql://postgres:5432/postgres connection-user=postgres connection-password=Password1 +allow-drop-table=true diff --git a/docker-compose.yml b/docker-compose.yml index 286ae81c..bb937c6b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,13 +4,23 @@ services: postgres: container_name: postgresql image: postgres:14.1-alpine + shm_size: 1g # work_mem: less tmp files # maintenance_work_mem: improve table-level op perf # max_wal_size: allow more time before merging to heap command: > - -c work_mem=1GB - -c maintenance_work_mem=1GB - -c max_wal_size=8GB + -c shared_buffers=16GB + -c effective_cache_size=48GB + -c maintenance_work_mem=2GB + -c checkpoint_completion_target=0.9 + -c default_statistics_target=100 + -c random_page_cost=1.1 + -c effective_io_concurrency=200 + -c work_mem=20971kB + -c max_worker_processes=14 + -c max_parallel_workers_per_gather=4 + -c max_parallel_workers=14 + -c max_parallel_maintenance_workers=4 restart: always volumes: - postgresql-data:/var/lib/postgresql/data:delegated diff --git a/tests/common.py b/tests/common.py index e0b35ac1..1547aca6 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,8 +2,21 @@ from data_diff import databases as db import logging +import os -logging.basicConfig(level=logging.INFO) +DEFAULT_N_SAMPLES = 50 +N_SAMPLES = int(os.environ.get('N_SAMPLES', DEFAULT_N_SAMPLES)) +BENCHMARK = os.environ.get('BENCHMARK', False) + +level = logging.WARN +if os.environ.get('DEBUG', False): + level = logging.DEBUG + +logging.basicConfig(level=level) +logging.getLogger("diff_tables").setLevel(level) +logging.getLogger("database").setLevel(level) +if BENCHMARK: + logging.getLogger("benchmark").setLevel(logging.DEBUG) TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql" TEST_POSTGRESQL_CONN_STRING: str = None diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 4515de29..2a464376 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,34 +1,125 @@ -from contextlib import suppress import unittest import time +import csv +import math +from google.cloud import bigquery +import re +import time +import datetime +from parameterized import parameterized +import rich.progress +from data_diff import databases as db +from data_diff.diff_tables import TableDiffer, TableSegment +from .common import ( + CONN_STRINGS, + 
BENCHMARK, + DEFAULT_N_SAMPLES, + N_SAMPLES, + str_to_checksum, +) import logging from decimal import Decimal -from parameterized import parameterized -from data_diff import databases as db -from data_diff.diff_tables import TableDiffer, TableSegment -from .common import CONN_STRINGS +class Faker: + pass -logging.getLogger("diff_tables").setLevel(logging.ERROR) -logging.getLogger("database").setLevel(logging.WARN) +class PaginatedTable: + # We can't query all the rows at once for large tables. It'll occupy too + # much memory. + RECORDS_PER_BATCH = 1000000 -CONNS = {k: db.connect_to_uri(v) for k, v in CONN_STRINGS.items()} + def __init__(self, table, conn): + self.table = table + self.conn = conn -CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None) + def __iter__(self): + self.last_id = 0 + self.values = [] + self.value_index = 0 + return self -TYPE_SAMPLES = { - "int": [127, -3, -9, 37, 15, 127], - "datetime_no_timezone": [ + def __next__(self) -> str: + if self.value_index == len(self.values): # end of current batch + query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}" + if isinstance(self.conn, db.Oracle): + query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY" + + self.values = self.conn.query(query, list) + if len(self.values) == 0: # we must be done! + raise StopIteration + self.last_id = self.values[-1][0] + self.value_index = 0 + + this_value = self.values[self.value_index] + self.value_index += 1 + return this_value + + +class DateTimeFaker: + MANUAL_FAKES = [ "2020-01-01 15:10:10", "2020-02-01 9:9:9", "2022-03-01 15:10:01.139", "2022-04-01 15:10:02.020409", "2022-05-01 15:10:03.003030", "2022-06-01 15:10:05.009900", - ], - "float": [ + ] + + def __init__(self, max): + self.max = max + + def __iter__(self): + self.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0) + self.i = 0 + return self + + def __len__(self): + return self.max + + def __next__(self) -> str: + if self.i < len(self.MANUAL_FAKES): + fake = self.MANUAL_FAKES[self.i] + self.i += 1 + return fake + elif self.i < self.max: + self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571) + self.i += 1 + return str(self.prev) + else: + raise StopIteration + + +class IntFaker: + MANUAL_FAKES = [127, -3, -9, 37, 15, 127] + + def __init__(self, max): + self.max = max + + def __iter__(self): + self.prev = -128 + self.i = 0 + return self + + def __len__(self): + return self.max + + def __next__(self) -> int: + if self.i < len(self.MANUAL_FAKES): + fake = self.MANUAL_FAKES[self.i] + self.i += 1 + return fake + elif self.i < self.max: + self.prev += 1 + self.i += 1 + return self.prev + else: + raise StopIteration + + +class FloatFaker: + MANUAL_FAKES = [ 0.0, 0.1, 0.00188, @@ -45,65 +136,112 @@ 1 / 1094893892389, 1 / 10948938923893289, 3.141592653589793, - ], + ] + + def __init__(self, max): + self.max = max + + def __iter__(self): + self.prev = -10.0001 + self.i = 0 + return self + + def __len__(self): + return self.max + + def __next__(self) -> float: + if self.i < len(self.MANUAL_FAKES): + fake = self.MANUAL_FAKES[self.i] + self.i += 1 + return fake + elif self.i < self.max: + self.prev += 0.00571 + self.i += 1 + return self.prev + else: + raise StopIteration + + +TYPE_SAMPLES = { + "int": IntFaker(N_SAMPLES), + "datetime_no_timezone": DateTimeFaker(N_SAMPLES), + "float": FloatFaker(N_SAMPLES), } +# This adds _benchmark after the test name so we can 
easily run them with -k +# benchmark. +BENCHMARK_TESTS = [ + # "test_types_postgres_int_to_postgres_int", + # "test_types_mysql_int_to_mysql_int", + # "test_types_postgres_int_to_mysql_int", + # "test_types_postgres_timestamp6_no_tz_to_mysql_timestamp", + # "test_types_postgres_timestamp6_no_tz_to_snowflake_timestamp9", + # "test_types_postgres_int_to_presto_int", + # "test_types_postgres_int_to_redshift_int", + # "test_types_postgres_int_to_snowflake_int", + # "test_types_postgres_int_to_bigquery_int", + "test_types_snowflake_int_to_snowflake_int", +] + DATABASE_TYPES = { db.PostgreSQL: { # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT "int": [ # "smallint", # 2 bytes - # "int", # 4 bytes - # "bigint", # 8 bytes + # "int", # 4 bytes + # "bigint", # 8 bytes ], # https://www.postgresql.org/docs/current/datatype-datetime.html "datetime_no_timezone": [ "timestamp(6) without time zone", - "timestamp(3) without time zone", - "timestamp(0) without time zone", + # "timestamp(3) without time zone", + # "timestamp(0) without time zone", ], # https://www.postgresql.org/docs/current/datatype-numeric.html "float": [ - "real", - "float", - "double precision", - "numeric(6,3)", + # "real", + # "float", + # "double precision", + # "numeric(6,3)", ], }, db.MySQL: { # https://dev.mysql.com/doc/refman/8.0/en/integer-types.html "int": [ - # "tinyint", # 1 byte - # "smallint", # 2 bytes - # "mediumint", # 3 bytes - # "int", # 4 bytes - # "bigint", # 8 bytes + # "tinyint", # 1 byte + # "smallint", # 2 bytes + # "mediumint", # 3 bytes + # "int", # 4 bytes + # "bigint", # 8 bytes ], # https://dev.mysql.com/doc/refman/8.0/en/datetime.html "datetime_no_timezone": [ - "timestamp(6)", - "timestamp(3)", - "timestamp(0)", + # "timestamp(6)", + # "timestamp(3)", + # "timestamp(0)", "timestamp", - "datetime(6)", + # "datetime(6)", ], # https://dev.mysql.com/doc/refman/8.0/en/numeric-types.html "float": [ - "float", - "double", - "numeric", - "numeric(65, 10)", + # "float", + # "double", + # "numeric", + # "numeric(65, 10)", ], }, db.BigQuery: { + "int": [ + # "int", + ], "datetime_no_timezone": [ "timestamp", # "datetime", ], "float": [ - "numeric", - "float64", - "bignumeric", + # "numeric", + # "float64", + # "bignumeric", ], }, db.Snowflake: { @@ -119,15 +257,15 @@ ], # https://docs.snowflake.com/en/sql-reference/data-types-datetime.html "datetime_no_timezone": [ - "timestamp(0)", - "timestamp(3)", - "timestamp(6)", + # "timestamp(0)", + # "timestamp(3)", + # "timestamp(6)", "timestamp(9)", ], # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric "float": [ - "float", - "numeric", + # "float", + # "numeric", ], }, db.Redshift: { @@ -139,9 +277,9 @@ ], # https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types "float": [ - "float4", - "float8", - "numeric", + # "float4", + # "float8", + # "numeric", ], }, db.Oracle: { @@ -150,12 +288,12 @@ ], "datetime_no_timezone": [ "timestamp with local time zone", - "timestamp(6) with local time zone", - "timestamp(9) with local time zone", + # "timestamp(6) with local time zone", + # "timestamp(9) with local time zone", ], "float": [ - "float", - "numeric", + # "float", + # "numeric", ], }, db.Presto: { @@ -163,26 +301,33 @@ # "tinyint", # 1 byte # "smallint", # 2 bytes # "mediumint", # 3 bytes - # "int", # 4 bytes - # "bigint", # 8 bytes - ], - "datetime_no_timezone": [ - "timestamp(6)", - "timestamp(3)", - "timestamp(0)", - "timestamp", - "datetime(6)", + # "int", 
# 4 bytes + # "bigint", # 8 bytes ], + "datetime_no_timezone": ["timestamp"], "float": [ - "real", - "double", - "decimal(10,2)", - "decimal(30,6)", + # "real", + # "double", + # "decimal(10,2)", + # "decimal(30,6)", ], }, } +def human_format(n): + millnames = ["", "K", "M", "B"] + n = float(n) + millidx = max( + 0, + min( + len(millnames) - 1, int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)) + ), + ) + + return "{:.0f}{}".format(n / 10 ** (3 * millidx), millnames[millidx]) + + type_pairs = [] for source_db, source_type_categories in DATABASE_TYPES.items(): for target_db, target_type_categories in DATABASE_TYPES.items(): @@ -192,7 +337,9 @@ ) in source_type_categories.items(): # int, datetime, .. for source_type in source_types: for target_type in target_type_categories[type_category]: - if CONNS.get(source_db, False) and CONNS.get(target_db, False): + if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get( + target_db, False + ): type_pairs.append( ( source_db, @@ -203,106 +350,367 @@ ) ) +# timestamp(9) +def sanitize(name): + name = name.lower() + name = re.sub(r"[\(\)]", "", name) # timestamp(9) -> timestamp9 + name = name.replace(r"without time zone", "no_tz") # too long for some DBs + return parameterized.to_safe_name(name) + + # Pass --verbose to test run to get a nice output. def expand_params(testcase_func, param_num, param): source_db, target_db, source_type, target_type, type_category = param.args source_db_type = source_db.__name__ target_db_type = target_db.__name__ - return "%s_%s_%s_to_%s_%s" % ( + name = "%s_%s_%s_to_%s_%s" % ( testcase_func.__name__, - source_db_type, - parameterized.to_safe_name(source_type), - target_db_type, - parameterized.to_safe_name(target_type), + sanitize(source_db_type), + sanitize(source_type), + sanitize(target_db_type), + sanitize(target_type), ) + if name in BENCHMARK_TESTS: + name += "_benchmark" -def _insert_to_table(conn, table, values): - insertion_query = f"INSERT INTO {table} (id, col) " - - if isinstance(conn, db.Oracle): - selects = [] - for j, sample in values: - if isinstance(sample, (float, Decimal, int)): - value = str(sample) - else: - value = f"timestamp '{sample}'" - selects.append(f"SELECT {j}, {value} FROM dual") - insertion_query += " UNION ALL ".join(selects) - else: - insertion_query += " VALUES " - for j, sample in values: - if isinstance(sample, (float, Decimal)): - value = str(sample) - else: - value = f"'{sample}'" - insertion_query += f"({j}, {value})," - - insertion_query = insertion_query[0:-1] - - conn.query(insertion_query, None) - - if not isinstance(conn, db.BigQuery): - conn.query("COMMIT", None) + return name def _drop_table_if_exists(conn, table): - with suppress(db.QueryError): + try: if isinstance(conn, db.Oracle): conn.query(f"DROP TABLE {table}", None) - conn.query(f"DROP TABLE {table}", None) else: conn.query(f"DROP TABLE IF EXISTS {table}", None) + except BaseException as err: + # Oracle's error for not existing + if str(err).startswith("ORA-00942"): + pass + else: + raise (err) class TestDiffCrossDatabaseTables(unittest.TestCase): + # https://docs.python.org/2/library/unittest.html#unittest.TestCase.maxDiff + # For showing assertEqual differences under a certain length. 
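As a quick sanity check on the naming helpers introduced above — this assumes the sanitize and human_format functions defined earlier in this file, and the inputs are only examples:

# Expected outputs, derived from the definitions above.
assert human_format(50) == "50"            # the default unit-test sample size
assert human_format(1_000_000) == "1M"     # a benchmark-sized run
assert sanitize("timestamp(6) without time zone") == "timestamp6_no_tz"
assert sanitize("TIMESTAMP(9)") == "timestamp9"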
+ maxDiff = 10000 + + def tearDown(self): + self.src_conn.close() + self.dst_conn.close() + self.src_thread_pool.close() + self.dst_thread_pool.close() + @parameterized.expand(type_pairs, name_func=expand_params) def test_types(self, source_db, target_db, source_type, target_type, type_category): - start = time.time() + # TODO: Rename to DDL conn. We use a single thread here to avoid the + # thread pool giving us a new thread for some reason, where the changes + # might not be applied to. + self.src_conn = src_conn = db.connect_to_uri(CONN_STRINGS[source_db], 1) + self.dst_conn = dst_conn = db.connect_to_uri(CONN_STRINGS[target_db], 1) - self.src_conn = src_conn = CONNS[source_db] - self.dst_conn = dst_conn = CONNS[target_db] + if source_db == db.MySQL: + src_conn.query("SET @@session.time_zone='+00:00'", None) + if target_db == db.MySQL: + dst_conn.query("SET @@session.time_zone='+00:00'", None) - self.connections = [self.src_conn, self.dst_conn] sample_values = TYPE_SAMPLES[type_category] # Limit in MySQL is 64 - src_table_name = f"src_{self._testMethodName[:60]}" - dst_table_name = f"dst_{self._testMethodName[:60]}" + # src_table_name = f"src_{self._testMethodName[:60]}" + # dst_table_name = f"dst_{self._testMethodName[:60]}" + # We need to include the database name because of e.g. Presto which has + # Postgres as a backing catalog and shouldn't re-use those. + src_table_name = f"src_{source_db.__name__.lower()}_{sanitize(source_type)}_{human_format(N_SAMPLES)}" src_table_path = src_conn.parse_table_name(src_table_name) - dst_table_path = dst_conn.parse_table_name(dst_table_name) src_table = src_conn.quote(".".join(src_table_path)) - dst_table = dst_conn.quote(".".join(dst_table_path)) - _drop_table_if_exists(src_conn, src_table) - src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None) - _insert_to_table(src_conn, src_table, enumerate(sample_values, 1)) + dst_table_name = ( + f"dst_{target_db.__name__.lower()}_{sanitize(target_type)}_{src_table_name}" + ) - values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list) + if len(dst_table_name) > 64: # length limits, cut from start + dst_table_name = dst_table_name[len(dst_table_name) - 60 :] + # Since we insert `src` to `dst`, the `dst` can be different for the + # same type. 😭 + dst_table_path = dst_conn.parse_table_name(dst_table_name) + dst_table = dst_conn.quote(".".join(dst_table_path)) - _drop_table_if_exists(dst_conn, dst_table) - dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None) - _insert_to_table(dst_conn, dst_table, values_in_source) + # For Benchmark we might be working with millions of rows; let's not + # recreate them. 
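For a concrete feel of the naming and truncation a few lines up — a purely illustrative, hypothetical Postgres-to-Postgres timestamp(6) pair at one million rows:

# Names as the scheme above would generate them (hypothetical run).
src_table_name = "src_postgresql_timestamp6_no_tz_1M"
dst_table_name = f"dst_postgresql_timestamp6_no_tz_{src_table_name}"

assert len(dst_table_name) == 66        # over the 64-char identifier limit
if len(dst_table_name) > 64:            # length limits, cut from the start
    dst_table_name = dst_table_name[len(dst_table_name) - 60:]
assert len(dst_table_name) == 60        # the trailing, most specific part survives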
+ if not BENCHMARK: + _drop_table_if_exists(src_conn, src_table) + _drop_table_if_exists(dst_conn, dst_table) - self.table = TableSegment(self.src_conn, src_table_path, "id", None, ("col",), case_sensitive=False) - self.table2 = TableSegment(self.dst_conn, dst_table_path, "id", None, ("col",), case_sensitive=False) + # 990785 + start = time.time() + already_seeded = self._create_table(src_conn, src_table, source_type) + if not already_seeded: + self._insert_to_table( + src_conn, src_table, source_type, enumerate(sample_values, 1) + ) + insertion_source_duration = time.time() - start - self.assertEqual(len(sample_values), self.table.count()) - self.assertEqual(len(sample_values), self.table2.count()) + start = time.time() + already_seeded = self._create_table(dst_conn, dst_table, target_type) + if not already_seeded: + values_in_source = PaginatedTable(src_table, src_conn) + self._insert_to_table(dst_conn, dst_table, target_type, values_in_source) + insertion_target_duration = time.time() - start + + self.src_thread_pool = src_thread_pool = db.connect_to_uri( + CONN_STRINGS[source_db], 8 + ) + self.dst_thread_pool = dst_thread_pool = db.connect_to_uri( + CONN_STRINGS[target_db], 8 + ) + + self.table = TableSegment( + src_thread_pool, src_table_path, "id", None, ("col",), case_sensitive=False + ) + self.table2 = TableSegment( + dst_thread_pool, dst_table_path, "id", None, ("col",), case_sensitive=False + ) + + start = time.time() + self.assertEqual(N_SAMPLES, self.table.count()) + count1_duration = time.time() - start - differ = TableDiffer(bisection_threshold=3, bisection_factor=2) # ensure we actually checksum + start = time.time() + self.assertEqual(N_SAMPLES, self.table2.count()) + count2_duration = time.time() - start + + # For large sample sizes (e.g. for benchmarks) we set the batch to ~10k, + # and try to keep the checksummed batches to a minimum of 250k records + # each, to minimize round-trips while also trying to pump some + # concurrency. + # + # For unit tests, we keep the values low to ensure we actually use + # checksums even for small tests. + ch_threshold = 10_000 if N_SAMPLES > DEFAULT_N_SAMPLES else 3 + ch_factor = ( + min(max(int(N_SAMPLES / 250_000), 2), 128) + if N_SAMPLES > DEFAULT_N_SAMPLES + else 2 + ) + ch_threads = 16 + + differ = TableDiffer( + bisection_threshold=ch_threshold, + bisection_factor=ch_factor, + max_threadpool_size=ch_threads, + ) + start = time.time() diff = list(differ.diff_tables(self.table, self.table2)) + checksum_duration = time.time() - start expected = [] self.assertEqual(expected, diff) self.assertEqual(0, differ.stats.get("rows_downloaded", 0)) - # Ensure that Python agrees with the checksum! - differ = TableDiffer(bisection_threshold=1000000000) + start = time.time() + + # Here, we force-download. + # + # By default, we set the threshold above the samples. If we have more + # samples than the default, we'll split it into a bunch of segments of + # ~100K records to pull at a time. This ensure it's threaded in a + # benchmarking context where BENCHMARK is set. 
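Worked through with a hypothetical benchmark-sized run, just to illustrate the arithmetic the next lines perform:

DEFAULT_N_SAMPLES = 50
N_SAMPLES = 1_000_000                          # hypothetical benchmark size

dl_factor = max(int(N_SAMPLES / 100_000), 2)   # -> 10 segments of ~100K rows
dl_threshold = int(N_SAMPLES / dl_factor) + 1  # -> 100_001, so every segment is downloaded
assert (dl_factor, dl_threshold) == (10, 100_001)

# For a plain unit test (N_SAMPLES == 50) the expressions below fall back to
# factor 2 and threshold 51, i.e. the whole table sits under the threshold and
# is simply downloaded.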
+ dl_factor = ( + max(int(N_SAMPLES / 100_000), 2) if N_SAMPLES > DEFAULT_N_SAMPLES else 2 + ) + dl_threshold = ( + int(N_SAMPLES / dl_factor) + 1 + if N_SAMPLES > DEFAULT_N_SAMPLES + else N_SAMPLES + 1 + ) + dl_threads = 16 + + differ = TableDiffer( + bisection_threshold=dl_threshold, + bisection_factor=dl_factor, + max_threadpool_size=dl_threads, + ) diff = list(differ.diff_tables(self.table, self.table2)) expected = [] self.assertEqual(expected, diff) self.assertEqual(len(sample_values), differ.stats.get("rows_downloaded", 0)) - duration = time.time() - start - # print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms") + download_duration = time.time() - start + + logging.getLogger("benchmark").debug( + f""" + test={self._testMethodName} + source_db={source_db.__name__} + target_db={target_db.__name__} + rows={N_SAMPLES} + rows_human={human_format(N_SAMPLES)} + + src_table={src_table} + target_table={dst_table} + source_type={repr(source_type)} + target_type={repr(target_type)} + + insertion_source={insertion_source_duration:.2f}s + insertion_target={insertion_target_duration:.2f}s + + count_source={count1_duration:.3f}s + count_target={count2_duration:.3f}s + checksum={checksum_duration:.3f}s + download={download_duration:.3f}s + + checksum_threads={ch_threads} + checksum_bisection_factor={ch_factor} + checksum_bisection_threshold={ch_threshold} + + download_threads={dl_threads} + download_bisection_factor={dl_factor} + download_bisection_threshold={dl_threshold} + """ + ) + + if BENCHMARK: + with open("benchmark.csv", "a") as file: + file.write( + f"{self._testMethodName}, {source_db.__name__} -> {target_db.__name__}, {N_SAMPLES}, {max(count1_duration, count2_duration):.3f}, {checksum_duration:.3f}, {download_duration:.3f}, {source_type}, {target_type}\n" + ) + + def _create_table(self, conn, table, type) -> bool: + if isinstance(conn, db.Oracle): + already_exists = ( + conn.query( + f"select count(*) from tab where tname='{table.upper()}'", int + ) + > 0 + ) + if not already_exists: + conn.query(f"CREATE TABLE {table}(id int, col {type})", None) + else: + conn.query(f"CREATE TABLE IF NOT EXISTS {table}(id int, col {type})", None) + + idx_name = f"idx_{table[1:-1]}" + + if isinstance(conn, db.MySQL) or isinstance(conn, db.Oracle): + max_suffix = len("_id_col") + if len(idx_name) + max_suffix > 64: # length limits, cut from start + idx_name = idx_name[65 - max_suffix - len(idx_name)] + + try: + conn.query( + f"CREATE UNIQUE INDEX {idx_name}_id_col ON {table}(id, col)", + None, + ) + conn.query( + f"CREATE UNIQUE INDEX {idx_name}_id ON {table}(id)", + None, + ) + except Exception as err: + if "Duplicate key name" in str(err): # mysql + pass + elif "such column list already indexed" in str(err): # oracle + pass + elif "name is already used" in str(err): # oracle + pass + else: + raise (err) + elif ( + not isinstance(conn, db.Snowflake) + and not isinstance(conn, db.Presto) + and not isinstance(conn, db.Redshift) + and not isinstance(conn, db.BigQuery) + ): + conn.query( + f"CREATE UNIQUE INDEX IF NOT EXISTS {idx_name}_id_col ON {table} (id, col)", + None, + ) + conn.query( + f"CREATE UNIQUE INDEX IF NOT EXISTS {idx_name}_id ON {table} (id)", + None, + ) + + # print(conn.query("SHOW TABLES FROM public", list)) + # print(conn.query("SELECT * FROM INFORMATION_SCHEMA.TABLES", list)) + existing_count = conn.query(f"SELECT COUNT(*) FROM {table}", int) + + # Ensure it's clean if it was 
partially instantiated. + # This should only be relevant for BENCHMARK. + if existing_count != N_SAMPLES and existing_count != 0: + _drop_table_if_exists(conn, table) + return self._create_table(conn, table, type) + + if not isinstance(conn, db.BigQuery): + conn.query("COMMIT", None) + + return existing_count == N_SAMPLES + + def _insert_to_table(self, conn, table, col_type, values): + default_insertion_query = f"INSERT INTO {table} (id, col) VALUES " + if isinstance(conn, db.Oracle): + default_insertion_query = f"INSERT INTO {table} (id, col)" + + insertion_query = default_insertion_query + + if BENCHMARK: + description = f"{type(conn).__name__}: {table}" + values = rich.progress.track( + values, total=N_SAMPLES, description=description + ) + + selects = [] + with open("_tmp.csv", "w+", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(["id", "col"]) + for id, sample in values: + if ( + isinstance(conn, db.Presto) or isinstance(conn, db.Oracle) + ) and col_type.startswith("timestamp"): + sample = f"timestamp '{sample}'" # must be cast... + elif isinstance(sample, int) or isinstance(sample, float): + pass # don't make string, some dbs need them to be raw + elif col_type == "timestamp" and isinstance(conn, db.BigQuery): + pass + else: + sample = f"'{sample}'" + + # TODO: create one per test; so we can parallelize + if isinstance(conn, db.Oracle): + selects.append(f"SELECT {id}, {sample} FROM dual") + else: + insertion_query += f"({id}, {sample})," + + # snowflake has some annoying limitations here + if id % 8000 == 0 and not isinstance(conn, db.BigQuery): + if isinstance(conn, db.Oracle): + insertion_query += " UNION ALL ".join(selects) + conn.query(insertion_query, None) + selects = [] + else: + conn.query(insertion_query[0:-1], None) + + insertion_query = default_insertion_query + else: + writer.writerow([id, sample]) + + if not isinstance(conn, db.BigQuery): + if ( + insertion_query != default_insertion_query + ): # didn't end at a clean divisor + conn.query(insertion_query[0:-1], None) + conn.query("COMMIT", None) + elif isinstance(conn, db.Oracle) and len(selects) > 0: + insertion_query += " UNION ALL ".join(selects) + conn.query(insertion_query, None) + conn.query("COMMIT", None) + else: + client = conn._client + job_config = bigquery.LoadJobConfig( + source_format=bigquery.SourceFormat.CSV, + skip_leading_rows=1, + autodetect=True, + ) + with open("_tmp.csv", "rb") as source_file: + job = client.load_table_from_file( + source_file, table[1:-1], job_config=job_config + ) + job.result()
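The two INSERT shapes built by _insert_to_table above are easier to see in isolation. A minimal sketch with hypothetical rows and table name, independent of any driver:

# Hypothetical batch of (id, col) rows.
rows = [(1, 127), (2, -3), (3, -9)]
table = "src_postgresql_int_50"

# Generic multi-row VALUES form; flushed every ~8000 rows because some
# databases (e.g. Snowflake) limit how much a single statement can carry.
values_sql = f"INSERT INTO {table} (id, col) VALUES "
values_sql += ", ".join(f"({id}, {col})" for id, col in rows)
# INSERT INTO src_postgresql_int_50 (id, col) VALUES (1, 127), (2, -3), (3, -9)

# Oracle's INSERT ... VALUES takes a single row, so the batch is expressed as
# a UNION ALL of SELECT ... FROM dual instead.
oracle_sql = f"INSERT INTO {table} (id, col) "
oracle_sql += " UNION ALL ".join(f"SELECT {id}, {col} FROM dual" for id, col in rows)
# INSERT INTO src_postgresql_int_50 (id, col) SELECT 1, 127 FROM dual UNION ALL ...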