Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion ingestion/examples/sample_data/tests/testCaseResults.json
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,26 @@
{
"name": "addedRows",
"value": "2234"
},
{
"name": "addedColumns",
"value": "2"
},
{
"name": "removedColumns",
"value": "1"
},
{
"name": "changedColumns",
"value": "1"
},
{
"name": "schemaTable1",
"value": "serviceType='BigQuery' fullyQualifiedTableName='ecommerce_db.shopify.dim_address' schema={'order_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'quantity': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'product_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'customer_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'status': String_VaryingAlphanum(_notes=[], collation=None), 'order_date': Date(_notes=[], precision=6, rounds=True), 'price': Decimal(_notes=[], precision=2)}"
},
{
"name": "schemaTable2",
"value": "serviceType='BigQuery' fullyQualifiedTableName='shopify.production_dim_address' schema={'user_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'cycle_name': String_VaryingAlphanum(_notes=[], collation=None), 'status': String_VaryingAlphanum(_notes=[], collation=None), 'order_date': Date(_notes=[], precision=6, rounds=True)}"
}
]
}
Expand Down Expand Up @@ -748,7 +768,7 @@
[
"+",
"facf92d7-05ea-43d2-ba2a-067d63dee60c",
"a8d30187-1409-4606-9259-322a4f6caf74",
"e02e1fac-b650-4db8-8c9d-5fa5edf5d86",
"Amber",
"Albert",
"3170 Warren Orchard Apt. 834",
Expand Down Expand Up @@ -796,6 +816,14 @@
{
"name": "changedColumns",
"value": "1"
},
{
"name": "schemaTable1",
"value": "serviceType='BigQuery' fullyQualifiedTableName='ecommerce_db.shopify.dim_address' schema={'order_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'quantity': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'product_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'customer_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'status': String_VaryingAlphanum(_notes=[], collation=None), 'order_date': Date(_notes=[], precision=6, rounds=True), 'price': Decimal(_notes=[], precision=2)}"
},
{
"name": "schemaTable2",
"value": "serviceType='BigQuery' fullyQualifiedTableName='shopify.production_dim_address' schema={'user_id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'id': Integer(_notes=[], precision=0, python_type=<class 'int'>), 'cycle_name': String_VaryingAlphanum(_notes=[], collation=None), 'status': String_VaryingAlphanum(_notes=[], collation=None), 'order_date': Date(_notes=[], precision=6, rounds=True)}"
}
]
}
Expand Down
4 changes: 4 additions & 0 deletions ingestion/examples/sample_data/tests/testSuites.json
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@
{
"name": "table2",
"value": "sample_data.ecommerce_db.shopify.production_dim_address"
},
{
"name": "keyColumns",
"value": "[\"address_id\"]"
}
],
"resolutions": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from pydantic import BaseModel
from sqlalchemy import Column as SAColumn
from sqlalchemy import literal, select
from sqlalchemy.engine import make_url
Expand Down Expand Up @@ -76,6 +77,26 @@
]


class SchemaDiffResult(BaseModel):
    """Description of one table's schema, attached to a column diff result.

    Instances are built in ``get_column_diff`` (one per compared table) and
    are rendered with ``str()`` into the ``schemaTable1`` / ``schemaTable2``
    test result values, so the field order below is part of the emitted text.
    """

    class Config:
        arbitrary_types_allowed = True

    # Name of the table's database service type
    # (populated from ``database_service_type.name``).
    serviceType: str
    # Path of the table (populated from the table parameter's ``path``).
    fullyQualifiedTableName: str
    # Column name -> {"type": <dataTypeDisplay>, "constraints": <constraint value>}.
    # NOTE(review): the field name ``schema`` collides with an attribute that
    # pydantic's BaseModel already defines -- confirm this does not warn or
    # fail on the pydantic version in use.
    # NOTE(review): values are declared as ``str``; presumably a column with no
    # display type or constraint would not validate -- verify against callers.
    schema: Dict[str, Dict[str, str]]


class ColumnDiffResult(BaseModel):
    """Column-level differences found between the two tables of a diff test.

    Returned by ``get_column_diff`` only when at least one of ``removed``,
    ``added`` or ``changed`` is non-empty; otherwise that method returns
    ``None``. The list lengths are reported as the ``removedColumns``,
    ``addedColumns`` and ``changedColumns`` test result values.
    """

    class Config:
        arbitrary_types_allowed = True

    # Column names reported as removed by ``get_changed_added_columns``.
    removed: List[str]
    # Column names reported as added by ``get_changed_added_columns``.
    added: List[str]
    # Column names returned by ``get_incomparable_columns`` (columns whose
    # types differ between the two tables).
    changed: List[str]
    # Schema description of the first table.
    schemaTable1: SchemaDiffResult
    # Schema description of the second table.
    schemaTable2: SchemaDiffResult


def build_sample_where_clause(
table: TableParameter, key_columns: List[str], salt: str, hex_nounce: str
) -> str:
Expand Down Expand Up @@ -229,12 +250,19 @@ def _run_dimensional_validation(self):
return []

def _run(self) -> TestCaseResult:
result = self.get_column_diff()
if result:
return result
column_diff: ColumnDiffResult = self.get_column_diff()
threshold = self.get_test_case_param_value(
self.test_case.parameterValues, "threshold", int, default=0
)
if column_diff:
# If there are column differences, we set extra_columns to the common columns for the diff
common_columns = list(
set(column_diff.schemaTable1.schema.keys())
& set(column_diff.schemaTable2.schema.keys())
)
self.runtime_params.extraColumns = common_columns
self.runtime_params.table1.extra_columns = common_columns
self.runtime_params.table2.extra_columns = common_columns
table_diff_iter = self.get_table_diff()

if not threshold or self.test_case.computePassedFailedRowCount:
Expand All @@ -255,6 +283,7 @@ def _run(self) -> TestCaseResult:
stats["updated"],
stats["exclusive_A"],
stats["exclusive_B"],
column_diff,
)
count = self._compute_row_count(self.runner, None) # type: ignore
test_case_result.passedRows = stats["unchanged"]
Expand All @@ -268,6 +297,7 @@ def _run(self) -> TestCaseResult:
return self.get_row_diff_test_case_result(
threshold,
self.calculate_diffs_with_limit(table_diff_iter, threshold),
column_diff,
)

def get_incomparable_columns(self) -> List[str]:
Expand Down Expand Up @@ -324,7 +354,11 @@ def get_incomparable_columns(self) -> List[str]:
continue
if col1_type != col2_type:
result.append(column)
return result
return (
result,
table1._schema,
table2._schema,
) # pylint: disable=protected-access

@staticmethod
def _get_column_python_type(column: SAColumn):
Expand Down Expand Up @@ -508,6 +542,7 @@ def get_row_diff_test_case_result(
changed: Optional[int] = None,
removed: Optional[int] = None,
added: Optional[int] = None,
column_diff: Optional[ColumnDiffResult] = None,
) -> TestCaseResult:
"""Build a test case result for a row diff test. If the number of differences is less than the threshold,
the test will pass, otherwise it will fail. The result will contain the number of added, removed, and changed
Expand All @@ -523,6 +558,34 @@ def get_row_diff_test_case_result(
Returns:
TestCaseResult: The result of the row diff test
"""
test_case_results = [
TestResultValue(name="removedRows", value=str(removed)),
TestResultValue(name="addedRows", value=str(added)),
TestResultValue(name="changedRows", value=str(changed)),
TestResultValue(name="diffCount", value=str(total_diffs)),
]

if column_diff:
test_case_results.extend(
[
TestResultValue(
name="removedColumns", value=str(len(column_diff.removed))
),
TestResultValue(
name="addedColumns", value=str(len(column_diff.added))
),
TestResultValue(
name="changedColumns", value=str(len(column_diff.changed))
),
TestResultValue(
name="schemaTable1", value=str(column_diff.schemaTable1)
),
TestResultValue(
name="schemaTable2", value=str(column_diff.schemaTable2)
),
]
)

return TestCaseResult(
timestamp=self.execution_date, # type: ignore
testCaseStatus=self.get_test_case_status(
Expand All @@ -531,12 +594,7 @@ def get_row_diff_test_case_result(
result=f"Found {total_diffs} different rows which is more than the threshold of {threshold}",
failedRows=total_diffs,
validateColumns=False,
testResultValue=[
TestResultValue(name="removedRows", value=str(removed)),
TestResultValue(name="addedRows", value=str(added)),
TestResultValue(name="changedRows", value=str(changed)),
TestResultValue(name="diffCount", value=str(total_diffs)),
],
testResultValue=test_case_results,
)

def _validate_dialects(self):
Expand All @@ -551,7 +609,7 @@ def _validate_dialects(self):
if dialect not in SUPPORTED_DIALECTS:
raise UnsupportedDialectError(name, dialect)

def get_column_diff(self) -> Optional[TestCaseResult]:
def get_column_diff(self) -> Optional[ColumnDiffResult]:
"""Get the column diff between the two tables. If there are no differences, return None."""
removed, added = self.get_changed_added_columns(
[
Expand All @@ -566,12 +624,34 @@ def get_column_diff(self) -> Optional[TestCaseResult]:
],
self.get_case_sensitive(),
)
changed = self.get_incomparable_columns()
changed, schema_table1, schema_table2 = self.get_incomparable_columns()
if removed or added or changed:
return self.column_validation_result(
removed,
added,
changed,
return ColumnDiffResult(
removed=removed,
added=added,
changed=changed,
schemaTable1=SchemaDiffResult(
serviceType=self.runtime_params.table1.database_service_type.name,
fullyQualifiedTableName=self.runtime_params.table1.path,
schema={
c.name.root: {
"type": c.dataTypeDisplay,
"constraints": c.constraint.value,
}
for c in self.runtime_params.table1.columns
},
),
schemaTable2=SchemaDiffResult(
serviceType=self.runtime_params.table2.database_service_type.name,
fullyQualifiedTableName=self.runtime_params.table2.path,
schema={
c.name.root: {
"type": c.dataTypeDisplay,
"constraints": c.constraint.value,
}
for c in self.runtime_params.table2.columns
},
),
)
return None

Expand All @@ -598,10 +678,10 @@ def get_changed_added_columns(
for column in left:
table2_column = right_columns_dict.get(column.name.root)
if table2_column is None:
removed.append(column.name.root)
added.append(column.name.root)
continue
del right_columns_dict[column.name.root]
added.extend(right_columns_dict.keys())
removed.extend(right_columns_dict.keys())
return removed, added

def column_validation_result(
Expand Down
Loading
Loading