Commit a6f4460

[PECOBLR-666] Added support for Variant datatype in SQLAlchemy
2 parents a811411 + 446c496 commit a6f4460

7 files changed: +219, -11 lines changed
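In short, the commit adds a DatabricksVariant SQLAlchemy type that compiles to VARIANT, serializes Python values to JSON on the way in (wrapping the bind parameter in PARSE_JSON()), and maps reflected variant columns back to the new type. A minimal usage sketch, assuming a reachable warehouse; the connection URL and table name below are placeholders, not part of this commit. Note that VARIANT values read back as JSON strings:

import json

from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Session

from databricks.sqlalchemy import DatabricksVariant


class Base(DeclarativeBase):
    pass


class Event(Base):
    __tablename__ = "variant_quickstart"  # placeholder table name

    id = Column(Integer, primary_key=True)
    payload = Column(DatabricksVariant)  # compiles to VARIANT


# Placeholder URL; fill in your own host, token, and http_path.
engine = create_engine(
    "databricks://token:<token>@<host>?http_path=<path>&catalog=main&schema=default"
)
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Event(id=1, payload={"kind": "click", "tags": ["a", "b"]}))
    session.commit()

    row = session.scalar(select(Event).where(Event.id == 1))
    # VARIANT values come back from the server as JSON strings,
    # so decode before comparing or indexing into them.
    payload = json.loads(row.payload)
    assert payload["kind"] == "click"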

sqlalchemy_example.py

Lines changed: 32 additions & 4 deletions
@@ -17,11 +17,12 @@
 from datetime import date, datetime, time, timedelta, timezone
 from decimal import Decimal
 from uuid import UUID
+import json
 
 # By convention, backend-specific SQLA types are defined in uppercase
-# This dialect exposes Databricks SQL's TIMESTAMP and TINYINT types
+# This dialect exposes Databricks SQL's TIMESTAMP, TINYINT, and VARIANT types
 # as these are not covered by the generic, camelcase types shown below
-from databricks.sqlalchemy import TIMESTAMP, TINYINT
+from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksVariant
 
 # Beside the CamelCase types shown below, line comments reflect
 # the underlying Databricks SQL / Delta table type
@@ -82,6 +83,12 @@ class SampleObject(Base):
     datetime_col_ntz = Column(DateTime)
     time_col = Column(Time)
     uuid_col = Column(Uuid)
+    variant_col = Column(DatabricksVariant)
+
+Base.metadata.drop_all(engine)
+
+# Output SQL is:
+# DROP TABLE pysql_sqlalchemy_example_table
 
 # This generates a CREATE TABLE statement against the catalog and schema
 # specified in the connection string
@@ -100,6 +107,7 @@ class SampleObject(Base):
 # datetime_col_ntz TIMESTAMP_NTZ,
 # time_col STRING,
 # uuid_col STRING,
+# variant_col VARIANT,
 # PRIMARY KEY (bigint_col)
 # ) USING DELTA
 
@@ -120,6 +128,23 @@ class SampleObject(Base):
     "datetime_col_ntz": datetime(1990, 12, 4, 6, 33, 41),
     "time_col": time(23, 59, 59),
     "uuid_col": UUID(int=255),
+    "variant_col": {
+        "name": "John Doe",
+        "age": 30,
+        "address": {
+            "street": "123 Main St",
+            "city": "San Francisco",
+            "state": "CA",
+            "zip": "94105"
+        },
+        "hobbies": ["reading", "hiking", "cooking"],
+        "is_active": True,
+        "metadata": {
+            "created_at": "2024-01-15T10:30:00Z",
+            "version": 1.2,
+            "tags": ["premium", "verified"]
+        }
+    },
 }
 sa_obj = SampleObject(**sample_object)
 
@@ -140,7 +165,8 @@ class SampleObject(Base):
 # datetime_col,
 # datetime_col_ntz,
 # time_col,
-# uuid_col
+# uuid_col,
+# variant_col
 # )
 # VALUES
 # (
@@ -154,7 +180,8 @@ class SampleObject(Base):
 # :datetime_col,
 # :datetime_col_ntz,
 # :time_col,
-# :uuid_col
+# :uuid_col,
+# PARSE_JSON(:variant_col)
 # )
 
 # Here we build a SELECT query using ORM
@@ -165,6 +192,7 @@ class SampleObject(Base):
 
 # Finally, we read out the input data and compare it to the output
 compare = {key: getattr(result, key) for key in sample_object.keys()}
+compare['variant_col'] = json.loads(compare['variant_col'])
 assert compare == sample_object
 
 # Then we drop the demonstration table

src/databricks/sqlalchemy/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -5,6 +5,14 @@
     TIMESTAMP_NTZ,
     DatabricksArray,
     DatabricksMap,
+    DatabricksVariant,
 )
 
-__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
+__all__ = [
+    "TINYINT",
+    "TIMESTAMP",
+    "TIMESTAMP_NTZ",
+    "DatabricksArray",
+    "DatabricksMap",
+    "DatabricksVariant",
+]

src/databricks/sqlalchemy/_parse.py

Lines changed: 1 addition & 0 deletions
@@ -318,6 +318,7 @@ def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
     "map": sqlalchemy.types.String,
     "struct": sqlalchemy.types.String,
     "uniontype": sqlalchemy.types.String,
+    "variant": type_overrides.DatabricksVariant,
     "decimal": sqlalchemy.types.Numeric,
     "timestamp": type_overrides.TIMESTAMP,
     "timestamp_ntz": type_overrides.TIMESTAMP_NTZ,

src/databricks/sqlalchemy/_types.py

Lines changed: 60 additions & 0 deletions
@@ -9,6 +9,9 @@
 
 from databricks.sql.utils import ParamEscaper
 
+from sqlalchemy.sql import expression
+import json
+
 
 def process_literal_param_hack(value: Any):
     """This method is supposed to accept a Python type and return a string representation of that type.
@@ -397,3 +400,60 @@ def compile_databricks_map(type_, compiler, **kw):
     key_type = compiler.process(type_.key_type, **kw)
     value_type = compiler.process(type_.value_type, **kw)
     return f"MAP<{key_type},{value_type}>"
+
+
+class DatabricksVariant(UserDefinedType):
+    """
+    A custom variant type for storing semi-structured data including STRUCT, ARRAY, MAP, and scalar types.
+    Note: VARIANT MAP types can only have STRING keys.
+
+    Examples:
+        DatabricksVariant() -> VARIANT
+
+    Usage:
+        Column('data', DatabricksVariant())
+    """
+
+    cache_ok = True
+
+    def __init__(self):
+        self.pe = ParamEscaper()
+
+    def bind_processor(self, dialect):
+        """Process values before sending them to the database."""
+
+        def process(value):
+            if value is None:
+                return None
+            try:
+                return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
+            except (TypeError, ValueError) as e:
+                raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
+
+        return process
+
+    def bind_expression(self, bindvalue):
+        """Wrap the bind parameter with PARSE_JSON() in the emitted SQL."""
+        return expression.func.PARSE_JSON(bindvalue)
+
+    def literal_processor(self, dialect):
+        """Process literal values for SQL generation.
+
+        For VARIANT columns, use PARSE_JSON() to properly insert data.
+        """
+
+        def process(value):
+            if value is None:
+                return "NULL"
+            try:
+                return self.pe.escape_string(
+                    json.dumps(value, ensure_ascii=False, separators=(",", ":"))
+                )
+            except (TypeError, ValueError) as e:
+                raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
+
+        return process
+
+
+@compiles(DatabricksVariant, "databricks")
+def compile_variant(type_, compiler, **kw):
+    return "VARIANT"

tests/test_local/e2e/test_complex_types.py

Lines changed: 110 additions & 3 deletions
@@ -11,14 +11,14 @@
     DateTime,
 )
 from collections.abc import Sequence
-from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
+from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant
 from sqlalchemy.orm import DeclarativeBase, Session
 from sqlalchemy import select
 from datetime import date, datetime, time, timedelta, timezone
 import pandas as pd
 import numpy as np
 import decimal
-
+import json
 
 class TestComplexTypes(TestSetup):
     def _parse_to_common_type(self, value):
@@ -46,7 +46,7 @@ def _parse_to_common_type(self, value):
         ):
             return tuple(value)
         elif isinstance(value, dict):
-            return tuple(value.items())
+            return tuple(sorted(value.items()))
         elif isinstance(value, np.generic):
             return value.item()
         elif isinstance(value, decimal.Decimal):
@@ -152,6 +152,35 @@ class MapTable(Base):
 
         return MapTable, sample_data
 
+    def sample_variant_table(self) -> tuple[DeclarativeBase, dict]:
+        class Base(DeclarativeBase):
+            pass
+
+        class VariantTable(Base):
+            __tablename__ = "sqlalchemy_variant_table"
+
+            int_col = Column(Integer, primary_key=True)
+            variant_simple_col = Column(DatabricksVariant())
+            variant_nested_col = Column(DatabricksVariant())
+            variant_array_col = Column(DatabricksVariant())
+            variant_mixed_col = Column(DatabricksVariant())
+
+        sample_data = {
+            "int_col": 1,
+            "variant_simple_col": {"key": "value", "number": 42},
+            "variant_nested_col": {"user": {"name": "John", "age": 30}, "active": True},
+            "variant_array_col": [1, 2, 3, "hello", {"nested": "data"}],
+            "variant_mixed_col": {
+                "string": "test",
+                "number": 123,
+                "boolean": True,
+                "array": [1, 2, 3],
+                "object": {"nested": "value"}
+            }
+        }
+
+        return VariantTable, sample_data
+
     def test_insert_array_table_sqlalchemy(self):
         table, sample_data = self.sample_array_table()
 
@@ -209,3 +238,81 @@ def test_map_table_creation_pandas(self):
             stmt = select(table)
             df_result = pd.read_sql(stmt, engine)
             assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
+
+    def test_insert_variant_table_sqlalchemy(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+
+            sa_obj = table(**sample_data)
+            session = Session(engine)
+            session.add(sa_obj)
+            session.commit()
+
+            stmt = select(table).where(table.int_col == 1)
+            result = session.scalar(stmt)
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+            # Parse JSON values back to original format for comparison
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if compare[key] is not None:
+                    compare[key] = json.loads(compare[key])
+
+            assert compare == sample_data
+
+    def test_variant_table_creation_pandas(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+
+            df = pd.DataFrame([sample_data])
+            dtype_mapping = {
+                "variant_simple_col": DatabricksVariant,
+                "variant_nested_col": DatabricksVariant,
+                "variant_array_col": DatabricksVariant,
+                "variant_mixed_col": DatabricksVariant
+            }
+            df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping)
+
+            stmt = select(table)
+            df_result = pd.read_sql(stmt, engine)
+            result_dict = df_result.iloc[0].to_dict()
+            # Parse JSON values back to original format for comparison
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if result_dict[key] is not None:
+                    result_dict[key] = json.loads(result_dict[key])
+
+            assert result_dict == sample_data
+
+    def test_variant_literal_processor(self):
+        table, sample_data = self.sample_variant_table()
+
+        with self.table_context(table) as engine:
+            stmt = table.__table__.insert().values(**sample_data)
+
+            try:
+                compiled = stmt.compile(
+                    dialect=engine.dialect,
+                    compile_kwargs={"literal_binds": True}
+                )
+                sql_str = str(compiled)
+
+                # Assert that JSON actually got inlined
+                assert '{"key":"value","number":42}' in sql_str
+            except NotImplementedError:
+                raise
+
+            with engine.begin() as conn:
+                conn.execute(stmt)
+
+            session = Session(engine)
+            stmt_select = select(table).where(table.int_col == sample_data["int_col"])
+            result = session.scalar(stmt_select)
+
+            compare = {key: getattr(result, key) for key in sample_data.keys()}
+
+            # Parse JSON values back to original Python objects
+            for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
+                if compare[key] is not None:
+                    compare[key] = json.loads(compare[key])
+
+            assert compare == sample_data
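One subtle fix in the hunk above: _parse_to_common_type now sorts dict items before converting them to a tuple. Dict equality ignores key order, but tuple comparison does not, so round-tripped MAP/VARIANT results whose keys arrive in a different order could previously compare unequal. A small standalone illustration:

# Two dicts with the same content but different insertion order
a = {"x": 1, "y": 2}
b = {"y": 2, "x": 1}

assert a == b                                                 # dict equality ignores order
assert tuple(a.items()) != tuple(b.items())                   # the old comparison could fail
assert tuple(sorted(a.items())) == tuple(sorted(b.items()))   # the new one is order-insensitive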

tests/test_local/test_ddl.py

Lines changed: 4 additions & 2 deletions
@@ -7,7 +7,7 @@
     SetColumnComment,
     SetTableComment,
 )
-from databricks.sqlalchemy import DatabricksArray, DatabricksMap
+from databricks.sqlalchemy import DatabricksArray, DatabricksMap, DatabricksVariant
 
 
 class DDLTestBase:
@@ -103,7 +103,8 @@ def metadata(self) -> MetaData:
         metadata = MetaData()
         col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
         col2 = Column("map_string_string", DatabricksMap(String, String))
-        table = Table("complex_type", metadata, col1, col2)
+        col3 = Column("variant_col", DatabricksVariant())
+        table = Table("complex_type", metadata, col1, col2, col3)
         return metadata
 
     def test_create_table_with_complex_type(self, metadata):
@@ -112,3 +113,4 @@ def test_create_table_with_complex_type(self, metadata):
 
         assert "array_array_string ARRAY<ARRAY<STRING>>" in output
         assert "map_string_string MAP<STRING,STRING>" in output
+        assert "variant_col VARIANT" in output

tests/test_local/test_types.py

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,7 @@
 import sqlalchemy
 
 from databricks.sqlalchemy.base import DatabricksDialect
-from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
+from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ, DatabricksVariant
 
 
 class DatabricksDataType(enum.Enum):
@@ -28,6 +28,7 @@ class DatabricksDataType(enum.Enum):
     ARRAY = enum.auto()
     MAP = enum.auto()
     STRUCT = enum.auto()
+    VARIANT = enum.auto()
 
 
 # Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types.
@@ -131,6 +132,7 @@ def test_numeric_renders_as_decimal_with_precision_and_scale(self):
     TINYINT: DatabricksDataType.TINYINT,
     TIMESTAMP: DatabricksDataType.TIMESTAMP,
     TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ,
+    DatabricksVariant: DatabricksDataType.VARIANT,
 }
