Merged
36 changes: 32 additions & 4 deletions sqlalchemy_example.py
@@ -17,11 +17,12 @@
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from uuid import UUID
import json

# By convention, backend-specific SQLA types are defined in uppercase
# This dialect exposes Databricks SQL's TIMESTAMP and TINYINT types
# This dialect exposes Databricks SQL's TIMESTAMP, TINYINT, and VARIANT types
# as these are not covered by the generic, camelcase types shown below
from databricks.sqlalchemy import TIMESTAMP, TINYINT
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksVariant

# Beside the CamelCase types shown below, line comments reflect
# the underlying Databricks SQL / Delta table type
@@ -82,6 +83,12 @@ class SampleObject(Base):
datetime_col_ntz = Column(DateTime)
time_col = Column(Time)
uuid_col = Column(Uuid)
variant_col = Column(DatabricksVariant)

Base.metadata.drop_all(engine)

# Output SQL is:
# DROP TABLE pysql_sqlalchemy_example_table

# This generates a CREATE TABLE statement against the catalog and schema
# specified in the connection string
@@ -100,6 +107,7 @@ class SampleObject(Base):
# datetime_col_ntz TIMESTAMP_NTZ,
# time_col STRING,
# uuid_col STRING,
# variant_col VARIANT,
# PRIMARY KEY (bigint_col)
# ) USING DELTA

@@ -120,6 +128,23 @@ class SampleObject(Base):
"datetime_col_ntz": datetime(1990, 12, 4, 6, 33, 41),
"time_col": time(23, 59, 59),
"uuid_col": UUID(int=255),
"variant_col": {
"name": "John Doe",
"age": 30,
"address": {
"street": "123 Main St",
"city": "San Francisco",
"state": "CA",
"zip": "94105"
},
"hobbies": ["reading", "hiking", "cooking"],
"is_active": True,
"metadata": {
"created_at": "2024-01-15T10:30:00Z",
"version": 1.2,
"tags": ["premium", "verified"]
}
},
}
sa_obj = SampleObject(**sample_object)

@@ -140,7 +165,8 @@ class SampleObject(Base):
# datetime_col,
# datetime_col_ntz,
# time_col,
# uuid_col
# uuid_col,
# variant_col
# )
# VALUES
# (
@@ -154,7 +180,8 @@ class SampleObject(Base):
# :datetime_col,
# :datetime_col_ntz,
# :time_col,
# :uuid_col
# :uuid_col,
# PARSE_JSON(:variant_col)
# )

# Here we build a SELECT query using ORM
@@ -165,6 +192,7 @@ class SampleObject(Base):

# Finally, we read out the input data and compare it to the output
compare = {key: getattr(result, key) for key in sample_object.keys()}
compare['variant_col'] = json.loads(compare['variant_col'])
assert compare == sample_object

# Then we drop the demonstration table
10 changes: 9 additions & 1 deletion src/databricks/sqlalchemy/__init__.py
@@ -5,6 +5,14 @@
TIMESTAMP_NTZ,
DatabricksArray,
DatabricksMap,
DatabricksVariant,
)

__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
__all__ = [
"TINYINT",
"TIMESTAMP",
"TIMESTAMP_NTZ",
"DatabricksArray",
"DatabricksMap",
"DatabricksVariant",
]
1 change: 1 addition & 0 deletions src/databricks/sqlalchemy/_parse.py
@@ -318,6 +318,7 @@ def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[st
"map": sqlalchemy.types.String,
"struct": sqlalchemy.types.String,
"uniontype": sqlalchemy.types.String,
"variant": type_overrides.DatabricksVariant,
"decimal": sqlalchemy.types.Numeric,
"timestamp": type_overrides.TIMESTAMP,
"timestamp_ntz": type_overrides.TIMESTAMP_NTZ,
60 changes: 60 additions & 0 deletions src/databricks/sqlalchemy/_types.py
@@ -9,6 +9,9 @@

from databricks.sql.utils import ParamEscaper

from sqlalchemy.sql import expression
import json


def process_literal_param_hack(value: Any):
"""This method is supposed to accept a Python type and return a string representation of that type.
@@ -397,3 +400,60 @@ def compile_databricks_map(type_, compiler, **kw):
key_type = compiler.process(type_.key_type, **kw)
value_type = compiler.process(type_.value_type, **kw)
return f"MAP<{key_type},{value_type}>"


class DatabricksVariant(UserDefinedType):
"""
A custom variant type for storing semi-structured data including STRUCT, ARRAY, MAP, and scalar types.
Note: VARIANT MAP types can only have STRING keys.

Examples:
DatabricksVariant() -> VARIANT

Usage:
Column('data', DatabricksVariant())
"""

cache_ok = True

def __init__(self):
self.pe = ParamEscaper()

def bind_processor(self, dialect):
"""Process values before sending to database."""

def process(value):
if value is None:
return None
try:
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
except (TypeError, ValueError) as e:
raise ValueError(f"Cannot serialize value {value} to JSON: {e}")

return process

def bind_expression(self, bindvalue):
"""Wrap with PARSE_JSON() in SQL"""
return expression.func.PARSE_JSON(bindvalue)

def literal_processor(self, dialect):
"""Process literal values for SQL generation.
For VARIANT columns, use PARSE_JSON() to properly insert data.
"""

def process(value):
if value is None:
return "NULL"
try:
return self.pe.escape_string(
json.dumps(value, ensure_ascii=False, separators=(",", ":"))
)
except (TypeError, ValueError) as e:
raise ValueError(f"Cannot serialize value {value} to JSON: {e}")

return process


@compiles(DatabricksVariant, "databricks")
def compile_variant(type_, compiler, **kw):
return "VARIANT"
113 changes: 110 additions & 3 deletions tests/test_local/e2e/test_complex_types.py
@@ -11,14 +11,14 @@
DateTime,
)
from collections.abc import Sequence
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant
from sqlalchemy.orm import DeclarativeBase, Session
from sqlalchemy import select
from datetime import date, datetime, time, timedelta, timezone
import pandas as pd
import numpy as np
import decimal

import json

class TestComplexTypes(TestSetup):
def _parse_to_common_type(self, value):
@@ -46,7 +46,7 @@ def _parse_to_common_type(self, value):
):
return tuple(value)
elif isinstance(value, dict):
return tuple(value.items())
return tuple(sorted(value.items()))
Collaborator:
Is the sorting needed? The response from the server is in the same order that we insert, right?

Collaborator Author:
This is needed because PARSE_JSON changes the order of the keys on insertion, so we compare after sorting (a minimal illustration follows below).

elif isinstance(value, np.generic):
return value.item()
elif isinstance(value, decimal.Decimal):
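
A minimal illustration of the reordering mentioned in the comment above (hypothetical values, not part of the diff):

# PARSE_JSON may reorder object keys, so two semantically equal dicts can come
# back with a different insertion order; sorting the items makes the comparison stable.
inserted = {"b": 2, "a": 1}
returned = {"a": 1, "b": 2}  # hypothetical key order after a VARIANT round trip

assert tuple(inserted.items()) != tuple(returned.items())
assert tuple(sorted(inserted.items())) == tuple(sorted(returned.items()))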
@@ -152,6 +152,35 @@ class MapTable(Base):

return MapTable, sample_data

def sample_variant_table(self) -> tuple[DeclarativeBase, dict]:
class Base(DeclarativeBase):
pass

class VariantTable(Base):
__tablename__ = "sqlalchemy_variant_table"

int_col = Column(Integer, primary_key=True)
variant_simple_col = Column(DatabricksVariant())
variant_nested_col = Column(DatabricksVariant())
variant_array_col = Column(DatabricksVariant())
variant_mixed_col = Column(DatabricksVariant())

sample_data = {
"int_col": 1,
"variant_simple_col": {"key": "value", "number": 42},
"variant_nested_col": {"user": {"name": "John", "age": 30}, "active": True},
"variant_array_col": [1, 2, 3, "hello", {"nested": "data"}],
"variant_mixed_col": {
"string": "test",
"number": 123,
"boolean": True,
"array": [1, 2, 3],
"object": {"nested": "value"}
}
}

return VariantTable, sample_data

def test_insert_array_table_sqlalchemy(self):
table, sample_data = self.sample_array_table()

@@ -209,3 +238,81 @@ def test_map_table_creation_pandas(self):
stmt = select(table)
df_result = pd.read_sql(stmt, engine)
assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)

def test_insert_variant_table_sqlalchemy(self):
table, sample_data = self.sample_variant_table()

with self.table_context(table) as engine:

sa_obj = table(**sample_data)
session = Session(engine)
session.add(sa_obj)
session.commit()

stmt = select(table).where(table.int_col == 1)
result = session.scalar(stmt)
compare = {key: getattr(result, key) for key in sample_data.keys()}
# Parse JSON values back to original format for comparison
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
if compare[key] is not None:
compare[key] = json.loads(compare[key])
Comment on lines +256 to +258
Collaborator:
Is this part even needed?

Collaborator Author:
Yes. The VARIANT value comes back as a string, so it's better to convert it from JSON first and then check that the output matches (see the small illustration after this test).

assert compare == sample_data

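A tiny illustration of the string round trip discussed in the comment above (the read-back value is hypothetical):

import json

returned = '{"key":"value","number":42}'  # what getattr(result, "variant_simple_col") yields
assert returned != {"key": "value", "number": 42}             # comparing str to dict never matches
assert json.loads(returned) == {"key": "value", "number": 42}
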
def test_variant_table_creation_pandas(self):
table, sample_data = self.sample_variant_table()

with self.table_context(table) as engine:

df = pd.DataFrame([sample_data])
dtype_mapping = {
"variant_simple_col": DatabricksVariant,
"variant_nested_col": DatabricksVariant,
"variant_array_col": DatabricksVariant,
"variant_mixed_col": DatabricksVariant
Comment on lines +268 to +272
Collaborator:
What if there are other types apart from variant, such as int or array, etc.? Does this dtype mapping need to be provided only for the variant columns or for all columns?

Collaborator Author:
It's better to provide the entire mapping for the complex-type columns. If we do not provide this mapping, the data is stored as a string for complex types. For general types like int, float, etc., we do not need to map explicitly (a short sketch follows this test).

}
df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping)

stmt = select(table)
df_result = pd.read_sql(stmt, engine)
result_dict = df_result.iloc[0].to_dict()
# Parse JSON values back to original format for comparison
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
if result_dict[key] is not None:
result_dict[key] = json.loads(result_dict[key])

assert result_dict == sample_data
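
A short sketch of the dtype behaviour described in the comment above (the connection string, table name, and column names are placeholders, not part of the diff):

import pandas as pd
from sqlalchemy import create_engine
from databricks.sqlalchemy import DatabricksVariant

# Placeholder connection string -- substitute real workspace values.
engine = create_engine(
    "databricks://token:<access_token>@<server_hostname>?"
    "http_path=<http_path>&catalog=<catalog>&schema=<schema>"
)

df = pd.DataFrame([{"id": 1, "payload": {"key": "value"}}])
df.to_sql(
    "variant_demo",                        # hypothetical table name
    engine,
    if_exists="append",
    index=False,
    dtype={"payload": DatabricksVariant},  # scalar columns like "id" need no explicit entry
)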

def test_variant_literal_processor(self):
table, sample_data = self.sample_variant_table()

with self.table_context(table) as engine:
stmt = table.__table__.insert().values(**sample_data)

try:
compiled = stmt.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True}
)
sql_str = str(compiled)

# Assert that JSON actually got inlined
assert '{"key":"value","number":42}' in sql_str
except NotImplementedError:
raise

with engine.begin() as conn:
conn.execute(stmt)

session = Session(engine)
stmt_select = select(table).where(table.int_col == sample_data["int_col"])
result = session.scalar(stmt_select)

compare = {key: getattr(result, key) for key in sample_data.keys()}

# Parse JSON values back to original Python objects
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
if compare[key] is not None:
compare[key] = json.loads(compare[key])

assert compare == sample_data
6 changes: 4 additions & 2 deletions tests/test_local/test_ddl.py
@@ -7,7 +7,7 @@
SetColumnComment,
SetTableComment,
)
from databricks.sqlalchemy import DatabricksArray, DatabricksMap
from databricks.sqlalchemy import DatabricksArray, DatabricksMap, DatabricksVariant


class DDLTestBase:
@@ -103,7 +103,8 @@ def metadata(self) -> MetaData:
metadata = MetaData()
col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
col2 = Column("map_string_string", DatabricksMap(String, String))
table = Table("complex_type", metadata, col1, col2)
col3 = Column("variant_col", DatabricksVariant())
table = Table("complex_type", metadata, col1, col2, col3)
return metadata

def test_create_table_with_complex_type(self, metadata):
@@ -112,3 +113,4 @@ def test_create_table_with_complex_type(self, metadata):

assert "array_array_string ARRAY<ARRAY<STRING>>" in output
assert "map_string_string MAP<STRING,STRING>" in output
assert "variant_col VARIANT" in output
4 changes: 3 additions & 1 deletion tests/test_local/test_types.py
@@ -4,7 +4,7 @@
import sqlalchemy

from databricks.sqlalchemy.base import DatabricksDialect
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ, DatabricksVariant


class DatabricksDataType(enum.Enum):
@@ -28,6 +28,7 @@ class DatabricksDataType(enum.Enum):
ARRAY = enum.auto()
MAP = enum.auto()
STRUCT = enum.auto()
VARIANT = enum.auto()


# Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
@@ -131,6 +132,7 @@ def test_numeric_renders_as_decimal_with_precision_and_scale(self):
TINYINT: DatabricksDataType.TINYINT,
TIMESTAMP: DatabricksDataType.TIMESTAMP,
TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ,
DatabricksVariant: DatabricksDataType.VARIANT,
}

