diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b5f8d548..7c42e356 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,4 +1,4 @@
-name: CI
+name: CI-COVER-VERSIONS
 
 on:
   push:
@@ -46,8 +46,9 @@ jobs:
         env:
           DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}'
           DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}'
+          DATADIFF_TRINO_URI: '${{ secrets.DATADIFF_TRINO_URI }}'
           DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse'
           DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica'
         run: |
           chmod +x tests/waiting_for_stack_up.sh
-          ./tests/waiting_for_stack_up.sh && poetry run unittest-parallel -j 16
+          ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=0 poetry run unittest-parallel -j 16
diff --git a/.github/workflows/ci_full.yml b/.github/workflows/ci_full.yml
new file mode 100644
index 00000000..0de7da52
--- /dev/null
+++ b/.github/workflows/ci_full.yml
@@ -0,0 +1,51 @@
+name: CI-COVER-DATABASES
+
+on:
+  push:
+    paths:
+      - '**.py'
+      - '.github/workflows/**'
+      - '!dev/**'
+  pull_request:
+    branches: [ master ]
+
+  workflow_dispatch:
+
+jobs:
+  unit_tests:
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version:
+          - "3.10"
+
+    name: Check Python ${{ matrix.python-version }} on ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Build the stack
+        run: docker-compose up -d mysql postgres presto trino clickhouse vertica
+
+      - name: Install Poetry
+        run: pip install poetry
+
+      - name: Install package
+        run: "poetry install"
+
+      - name: Run unit tests
+        env:
+          DATADIFF_SNOWFLAKE_URI: '${{ secrets.DATADIFF_SNOWFLAKE_URI }}'
+          DATADIFF_PRESTO_URI: '${{ secrets.DATADIFF_PRESTO_URI }}'
+          DATADIFF_TRINO_URI: '${{ secrets.DATADIFF_TRINO_URI }}'
+          DATADIFF_CLICKHOUSE_URI: 'clickhouse://clickhouse:Password1@localhost:9000/clickhouse'
+          DATADIFF_VERTICA_URI: 'vertica://vertica:Password1@localhost:5433/vertica'
+        run: |
+          chmod +x tests/waiting_for_stack_up.sh
+          ./tests/waiting_for_stack_up.sh && TEST_ACROSS_ALL_DBS=full poetry run unittest-parallel -j 16
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py
index 0a5a2fa6..672c4e0b 100644
--- a/data_diff/databases/base.py
+++ b/data_diff/databases/base.py
@@ -267,6 +267,9 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list):
             if res is None:  # May happen due to sum() of 0 items
                 return None
             return int(res)
+        elif res_type is datetime:
+            res = _one(_one(res))
+            return res  # XXX parse timestamp?
         elif res_type is tuple:
             assert len(res) == 1, (sql_code, res)
             return res[0]
diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py
index 3d6f9d34..b5f2f577 100644
--- a/data_diff/databases/clickhouse.py
+++ b/data_diff/databases/clickhouse.py
@@ -55,7 +55,6 @@ class Dialect(BaseDialect):
         "DateTime64": Timestamp,
     }
 
-
     def normalize_number(self, value: str, coltype: FractionalType) -> str:
         # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
         # For example:
diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py
index b68b216a..64127e9a 100644
--- a/data_diff/databases/oracle.py
+++ b/data_diff/databases/oracle.py
@@ -127,9 +127,7 @@ def parse_type(
             precision = int(m.group(1))
             return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
 
-        return super().parse_type(
-            table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
-        )
+        return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
 
 
 class Oracle(ThreadedDatabase):
diff --git a/tests/common.py b/tests/common.py
index fccd5ddc..c1ae30a0 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -36,6 +36,7 @@
 N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))
 BENCHMARK = os.environ.get("BENCHMARK", False)
 N_THREADS = int(os.environ.get("N_THREADS", 1))
+TEST_ACROSS_ALL_DBS = os.environ.get("TEST_ACROSS_ALL_DBS", True)  # Should we run the full db<->db test suite?
 
 
 def get_git_revision_short_hash() -> str:
@@ -94,6 +95,10 @@ def _print_used_dbs():
     logging.info(f"Testing databases: {', '.join(used)}")
     if unused:
         logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}")
+    if str(TEST_ACROSS_ALL_DBS).strip().lower() not in ("", "0", "false", "no", "off"):  # env strings like "0" are truthy
+        logging.info(
+            f"Full tests enabled (every db<->db). May take very long when many dbs are involved. ={TEST_ACROSS_ALL_DBS}"
+        )
 
 
 _print_used_dbs()
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 1c131f54..b63b1c7f 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -37,7 +37,10 @@ def setUp(self) -> None:
 
         src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str})
         self.conn.query(src_table.create())
-        self.now = now = arrow.get(datetime.now())
+
+        self.conn.query("SET @@session.time_zone='+00:00'")
+        db_time = self.conn.query("select now()", datetime)
+        self.now = now = arrow.get(db_time)
 
         rows = [
             (now, "now"),
diff --git a/tests/test_database_types.py b/tests/test_database_types.py
index c8ed69aa..250b4537 100644
--- a/tests/test_database_types.py
+++ b/tests/test_database_types.py
@@ -25,6 +25,7 @@
     N_THREADS,
     BENCHMARK,
     GIT_REVISION,
+    TEST_ACROSS_ALL_DBS,
     get_conn,
     random_table_suffix,
 )
@@ -418,22 +419,49 @@ def __iter__(self):
     "uuid": UUID_Faker(N_SAMPLES),
 }
 
+
+def _get_test_db_pairs():
+    # TEST_ACROSS_ALL_DBS is either the default `True` or a raw environment string,
+    # so compare normalized text instead of calling int() (which raises on e.g. "true").
+    mode = str(TEST_ACROSS_ALL_DBS).strip().lower()
+    if mode == "full":
+        for source_db in DATABASE_TYPES:
+            for target_db in DATABASE_TYPES:
+                yield source_db, target_db
+    elif mode not in ("", "0", "false", "no", "off"):
+        for db_cls in DATABASE_TYPES:
+            yield db_cls, db.PostgreSQL
+            yield db.PostgreSQL, db_cls
+            yield db_cls, db.Snowflake
+            yield db.Snowflake, db_cls
+    else:
+        yield db.PostgreSQL, db.PostgreSQL
+
+
+def get_test_db_pairs():
+    # dict.fromkeys de-duplicates like a set but keeps insertion order, so the
+    # generated test parameters are ordered the same way on every run.
+    active_pairs = dict.fromkeys(
+        (db1, db2) for db1, db2 in _get_test_db_pairs() if db1 in CONN_STRINGS and db2 in CONN_STRINGS
+    )
+    for db1, db2 in active_pairs:
+        yield db1, DATABASE_TYPES[db1], db2, DATABASE_TYPES[db2]
+
+
 type_pairs = []
-for source_db, source_type_categories in DATABASE_TYPES.items():
-    for target_db, target_type_categories in DATABASE_TYPES.items():
-        if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False):
-            for type_category, source_types in source_type_categories.items():  # int, datetime, ..
-                for source_type in source_types:
-                    for target_type in target_type_categories[type_category]:
-                        type_pairs.append(
-                            (
-                                source_db,
-                                target_db,
-                                source_type,
-                                target_type,
-                                type_category,
-                            )
-                        )
+for source_db, source_type_categories, target_db, target_type_categories in get_test_db_pairs():
+    for type_category, source_types in source_type_categories.items():  # int, datetime, ..
+        for source_type in source_types:
+            for target_type in target_type_categories[type_category]:
+                type_pairs.append(
+                    (
+                        source_db,
+                        target_db,
+                        source_type,
+                        target_type,
+                        type_category,
+                    )
+                )
 
 
 def sanitize(name):