Skip to content

Commit 5ec04bc

Browse files
committed
Allows user to directly pass object for variant type.
1 parent d697cfc commit 5ec04bc

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

src/databricks/sqlalchemy/_types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from databricks.sql.utils import ParamEscaper
1111

1212
from sqlalchemy.sql import expression
13+
import json
1314

1415
def process_literal_param_hack(value: Any):
1516
"""This method is supposed to accept a Python type and return a string representation of that type.
@@ -420,7 +421,12 @@ def bind_processor(self, dialect):
420421
"""
421422

422423
def process(value):
423-
return value
424+
if value is None:
425+
return None
426+
try:
427+
return json.dumps(value, ensure_ascii=False, separators=(',', ':'))
428+
except (TypeError, ValueError) as e:
429+
raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
424430

425431
return process
426432

@@ -435,7 +441,10 @@ def literal_processor(self, dialect):
435441
def process(value):
436442
if value is None:
437443
return "NULL"
438-
return self.pe.escape_string(value)
444+
try:
445+
return self.pe.escape_string(json.dumps(value, ensure_ascii=False, separators=(',', ':')))
446+
except (TypeError, ValueError) as e:
447+
raise ValueError(f"Cannot serialize value {value} to JSON: {e}")
439448

440449
return f"PARSE_JSON('{process}')"
441450

tests/test_local/e2e/test_complex_types.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import decimal
2121
import json
2222

23-
2423
class TestComplexTypes(TestSetup):
2524
def _parse_to_common_type(self, value):
2625
"""
@@ -244,38 +243,28 @@ def test_insert_variant_table_sqlalchemy(self):
244243
table, sample_data = self.sample_variant_table()
245244

246245
with self.table_context(table) as engine:
247-
# Pre-serialize variant data for SQLAlchemy
248-
variant_data = sample_data.copy()
249-
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
250-
variant_data[key] = None if sample_data[key] is None else json.dumps(sample_data[key])
251-
252-
sa_obj = table(**variant_data)
246+
247+
sa_obj = table(**sample_data)
253248
session = Session(engine)
254249
session.add(sa_obj)
255250
session.commit()
256251

257252
stmt = select(table).where(table.int_col == 1)
258-
259253
result = session.scalar(stmt)
260-
261254
compare = {key: getattr(result, key) for key in sample_data.keys()}
262255
# Parse JSON values back to original format for comparison
263256
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
264257
if compare[key] is not None:
265258
compare[key] = json.loads(compare[key])
259+
266260
assert self._recursive_compare(compare, sample_data)
267261

268262
def test_variant_table_creation_pandas(self):
269263
table, sample_data = self.sample_variant_table()
270264

271265
with self.table_context(table) as engine:
272-
# Pre-serialize variant data for pandas
273-
variant_data = sample_data.copy()
274-
for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']:
275-
variant_data[key] = None if sample_data[key] is None else json.dumps(sample_data[key])
276266

277-
# Insert the data into the table
278-
df = pd.DataFrame([variant_data])
267+
df = pd.DataFrame([sample_data])
279268
dtype_mapping = {
280269
"variant_simple_col": DatabricksVariant,
281270
"variant_nested_col": DatabricksVariant,
@@ -284,7 +273,6 @@ def test_variant_table_creation_pandas(self):
284273
}
285274
df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping)
286275

287-
# Read the data from the table
288276
stmt = select(table)
289277
df_result = pd.read_sql(stmt, engine)
290278
result_dict = df_result.iloc[0].to_dict()

0 commit comments

Comments
 (0)