[PECOBLR-666] Added support for Variant datatype in SQLAlchemy #42
Changes from all commits:

- d697cfc
- 5ec04bc
- d0cce3f
- ed7cd94
- 446c496
```diff
@@ -11,14 +11,14 @@
     DateTime,
 )
 from collections.abc import Sequence
-from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
+from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant
 from sqlalchemy.orm import DeclarativeBase, Session
 from sqlalchemy import select
 from datetime import date, datetime, time, timedelta, timezone
 import pandas as pd
 import numpy as np
 import decimal
+import json


 class TestComplexTypes(TestSetup):
     def _parse_to_common_type(self, value):
@@ -46,7 +46,7 @@ def _parse_to_common_type(self, value):
         ):
             return tuple(value)
         elif isinstance(value, dict):
-            return tuple(value.items())
+            return tuple(sorted(value.items()))
         elif isinstance(value, np.generic):
             return value.item()
         elif isinstance(value, decimal.Decimal):
@@ -152,6 +152,35 @@ class MapTable(Base):

         return MapTable, sample_data

+    def sample_variant_table(self) -> tuple[DeclarativeBase, dict]:
+        class Base(DeclarativeBase):
+            pass
+
+        class VariantTable(Base):
+            __tablename__ = "sqlalchemy_variant_table"
+
+            int_col = Column(Integer, primary_key=True)
+            variant_simple_col = Column(DatabricksVariant())
+            variant_nested_col = Column(DatabricksVariant())
+            variant_array_col = Column(DatabricksVariant())
+            variant_mixed_col = Column(DatabricksVariant())
+
+        sample_data = {
+            "int_col": 1,
+            "variant_simple_col": {"key": "value", "number": 42},
+            "variant_nested_col": {"user": {"name": "John", "age": 30}, "active": True},
+            "variant_array_col": [1, 2, 3, "hello", {"nested": "data"}],
+            "variant_mixed_col": {
+                "string": "test",
+                "number": 123,
+                "boolean": True,
+                "array": [1, 2, 3],
+                "object": {"nested": "value"}
+            }
+        }
+
+        return VariantTable, sample_data
+
     def test_insert_array_table_sqlalchemy(self):
         table, sample_data = self.sample_array_table()
@@ -209,3 +238,81 @@ def test_map_table_creation_pandas(self):
         stmt = select(table)
         df_result = pd.read_sql(stmt, engine)
         assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
+
+    def test_insert_variant_table_sqlalchemy(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+            sa_obj = table(**sample_data)
+            session = Session(engine)
+            session.add(sa_obj)
+            session.commit()
+
+            stmt = select(table).where(table.int_col == 1)
+            result = session.scalar(stmt)
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+            # Parse JSON values back to original format for comparison
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if compare[key] is not None:
+                    compare[key] = json.loads(compare[key])
```
**Review comment on lines +256 to +258** (marked resolved by msrathore-db):

> Is this part even needed?

**msrathore-db:**

> Yes. We get the value back as a string, so it's better to verify after converting it from JSON to check that the output matches.
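A minimal standalone sketch of that round trip; the returned string below is an assumed illustration of the server's JSON output, not captured output:

```python
import json

# The test inserts a Python dict into a variant column...
inserted = {"key": "value", "number": 42}

# ...and, per the reply above, reads it back as a JSON string.
returned = '{"key":"value","number":42}'  # assumed shape of the returned value

# json.loads restores a comparable Python object, so the assertion can
# compare like with like instead of dict vs. string.
assert json.loads(returned) == inserted
```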
```diff
+
+            assert compare == sample_data
+
+    def test_variant_table_creation_pandas(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+            df = pd.DataFrame([sample_data])
+            dtype_mapping = {
+                "variant_simple_col": DatabricksVariant,
+                "variant_nested_col": DatabricksVariant,
+                "variant_array_col": DatabricksVariant,
+                "variant_mixed_col": DatabricksVariant
+            }
```
**Review comment on lines +268 to +272:**

> What if there are other types apart from variant, such as int or array, etc.? Does this dtype mapping need to be provided for only the variant columns, or for all of them?

**msrathore-db:**

> It's better to provide the entire mapping. If we do not provide this mapping, the data for complex types is stored as a string. For general types like int and float, however, we do not need to map them explicitly.
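As an illustration of that reply, a sketch of such a mapping for a frame mixing one plain and one variant column; the helper name, table name, and column names are made up for this example:

```python
import pandas as pd
from sqlalchemy import Engine
from databricks.sqlalchemy import DatabricksVariant

def append_mixed_frame(engine: Engine) -> None:
    """Hypothetical helper: insert one row with a plain and a variant column."""
    df = pd.DataFrame([{"id": 1, "payload": {"nested": {"deep": True}}}])

    # "id" is a plain int and needs no explicit dtype entry; "payload" is a
    # complex value and, per the reply above, would be stored as a plain
    # string if it were left out of the mapping.
    df.to_sql(
        "hypothetical_variant_table",  # hypothetical table name
        engine,
        if_exists="append",
        index=False,
        dtype={"payload": DatabricksVariant},
    )
```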
```diff
+            df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping)
+
+            stmt = select(table)
+            df_result = pd.read_sql(stmt, engine)
+            result_dict = df_result.iloc[0].to_dict()
+            # Parse JSON values back to original format for comparison
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if result_dict[key] is not None:
+                    result_dict[key] = json.loads(result_dict[key])
+
+            assert result_dict == sample_data
+
+    def test_variant_literal_processor(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+            stmt = table.__table__.insert().values(**sample_data)
+
+            try:
+                compiled = stmt.compile(
+                    dialect=engine.dialect,
+                    compile_kwargs={"literal_binds": True}
+                )
+                sql_str = str(compiled)
+
+                # Assert that JSON actually got inlined
+                assert '{"key":"value","number":42}' in sql_str
+            except NotImplementedError:
+                raise
+
+            with engine.begin() as conn:
+                conn.execute(stmt)
+
+            session = Session(engine)
+            stmt_select = select(table).where(table.int_col == sample_data["int_col"])
+            result = session.scalar(stmt_select)
+
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+
+            # Parse JSON values back to original Python objects
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if compare[key] is not None:
+                    compare[key] = json.loads(compare[key])
+
+            assert compare == sample_data
```
**Review comment on the `tuple(sorted(value.items()))` change in `_parse_to_common_type`:**

> Is the sorting needed? The response from the server is in the same order that we insert, right?

**msrathore-db:**

> This is needed because `parse_json` changes the order of insertion, so we need to verify after sorting.
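A small self-contained illustration of why the sort matters in `_parse_to_common_type` (the reordered dict is a stand-in for what `parse_json` may return, since key order is not guaranteed):

```python
# Key order is not preserved by parse_json, so the value read back may be
# a key-reordered version of what was inserted.
inserted = {"key": "value", "number": 42}
returned = {"number": 42, "key": "value"}  # stand-in for a reordered response

# Normalizing dicts to tuples of items (as _parse_to_common_type does) is
# order-sensitive, so the raw tuples differ:
assert tuple(inserted.items()) != tuple(returned.items())

# Sorting the items first makes the normalized forms order-insensitive:
assert tuple(sorted(inserted.items())) == tuple(sorted(returned.items()))
```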