Skip to content

Commit ed7cd94

Browse files
committed
Added VARIANT support to sqlalchemy_example and added a test for the VARIANT literal_processor
1 parent d0cce3f commit ed7cd94

File tree

3 files changed

+70
-7
lines changed

3 files changed

+70
-7
lines changed

sqlalchemy_example.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
from datetime import date, datetime, time, timedelta, timezone
1818
from decimal import Decimal
1919
from uuid import UUID
20+
import json
2021

2122
# By convention, backend-specific SQLA types are defined in uppercase
22-
# This dialect exposes Databricks SQL's TIMESTAMP and TINYINT types
23+
# This dialect exposes Databricks SQL's TIMESTAMP, TINYINT, and VARIANT types
2324
# as these are not covered by the generic, camelcase types shown below
24-
from databricks.sqlalchemy import TIMESTAMP, TINYINT
25+
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksVariant
2526

2627
# Beside the CamelCase types shown below, line comments reflect
2728
# the underlying Databricks SQL / Delta table type
@@ -82,6 +83,12 @@ class SampleObject(Base):
8283
datetime_col_ntz = Column(DateTime)
8384
time_col = Column(Time)
8485
uuid_col = Column(Uuid)
86+
variant_col = Column(DatabricksVariant)
87+
88+
Base.metadata.drop_all(engine)
89+
90+
# Output SQL is:
91+
# DROP TABLE pysql_sqlalchemy_example_table
8592

8693
# This generates a CREATE TABLE statement against the catalog and schema
8794
# specified in the connection string
@@ -100,6 +107,7 @@ class SampleObject(Base):
100107
# datetime_col_ntz TIMESTAMP_NTZ,
101108
# time_col STRING,
102109
# uuid_col STRING,
110+
# variant_col VARIANT,
103111
# PRIMARY KEY (bigint_col)
104112
# ) USING DELTA
105113

@@ -120,6 +128,23 @@ class SampleObject(Base):
120128
"datetime_col_ntz": datetime(1990, 12, 4, 6, 33, 41),
121129
"time_col": time(23, 59, 59),
122130
"uuid_col": UUID(int=255),
131+
"variant_col": {
132+
"name": "John Doe",
133+
"age": 30,
134+
"address": {
135+
"street": "123 Main St",
136+
"city": "San Francisco",
137+
"state": "CA",
138+
"zip": "94105"
139+
},
140+
"hobbies": ["reading", "hiking", "cooking"],
141+
"is_active": True,
142+
"metadata": {
143+
"created_at": "2024-01-15T10:30:00Z",
144+
"version": 1.2,
145+
"tags": ["premium", "verified"]
146+
}
147+
},
123148
}
124149
sa_obj = SampleObject(**sample_object)
125150

@@ -140,7 +165,8 @@ class SampleObject(Base):
140165
# datetime_col,
141166
# datetime_col_ntz,
142167
# time_col,
143-
# uuid_col
168+
# uuid_col,
169+
# variant_col
144170
# )
145171
# VALUES
146172
# (
@@ -154,7 +180,8 @@ class SampleObject(Base):
154180
# :datetime_col,
155181
# :datetime_col_ntz,
156182
# :time_col,
157-
# :uuid_col
183+
# :uuid_col,
184+
# PARSE_JSON(:variant_col)
158185
# )
159186

160187
# Here we build a SELECT query using ORM
@@ -165,6 +192,7 @@ class SampleObject(Base):
165192

166193
# Finally, we read out the input data and compare it to the output
167194
compare = {key: getattr(result, key) for key in sample_object.keys()}
195+
compare['variant_col'] = json.loads(compare['variant_col'])
168196
assert compare == sample_object
169197

170198
# Then we drop the demonstration table

src/databricks/sqlalchemy/_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def process(value):
446446
except (TypeError, ValueError) as e:
447447
raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
448448

449-
return f"PARSE_JSON('{process}')"
449+
return process
450450

451451
@compiles(DatabricksVariant, "databricks")
452452
def compile_variant(type_, compiler, **kw):

tests/test_local/e2e/test_complex_types.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def test_insert_variant_table_sqlalchemy(self):
257257
if compare[key] is not None:
258258
compare[key] = json.loads(compare[key])
259259

260-
assert self._recursive_compare(compare, sample_data)
260+
assert compare == sample_data
261261

262262
def test_variant_table_creation_pandas(self):
263263
table, sample_data = self.sample_variant_table()
@@ -280,4 +280,39 @@ def test_variant_table_creation_pandas(self):
280280
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
281281
if result_dict[key] is not None:
282282
result_dict[key] = json.loads(result_dict[key])
283-
assert self._recursive_compare(result_dict, sample_data)
283+
284+
assert result_dict == sample_data
285+
286+
def test_variant_literal_processor(self):
287+
table, sample_data = self.sample_variant_table()
288+
289+
with self.table_context(table) as engine:
290+
stmt = table.__table__.insert().values(**sample_data)
291+
292+
try:
293+
compiled = stmt.compile(
294+
dialect=engine.dialect,
295+
compile_kwargs={"literal_binds": True}
296+
)
297+
sql_str = str(compiled)
298+
299+
# Assert that JSON actually got inlined
300+
assert '{"key":"value","number":42}' in sql_str
301+
except NotImplementedError:
302+
raise
303+
304+
with engine.begin() as conn:
305+
conn.execute(stmt)
306+
307+
session = Session(engine)
308+
stmt_select = select(table).where(table.int_col == sample_data["int_col"])
309+
result = session.scalar(stmt_select)
310+
311+
compare = {key: getattr(result, key) for key in sample_data.keys()}
312+
313+
# Parse JSON values back to original Python objects
314+
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
315+
if compare[key] is not None:
316+
compare[key] = json.loads(compare[key])
317+
318+
assert compare == sample_data

0 commit comments

Comments (0)