2020import decimal
2121import json
2222
23-
2423class TestComplexTypes (TestSetup ):
2524 def _parse_to_common_type (self , value ):
2625 """
@@ -244,38 +243,28 @@ def test_insert_variant_table_sqlalchemy(self):
244243 table , sample_data = self .sample_variant_table ()
245244
246245 with self .table_context (table ) as engine :
247- # Pre-serialize variant data for SQLAlchemy
248- variant_data = sample_data .copy ()
249- for key in ['variant_simple_col' , 'variant_nested_col' , 'variant_array_col' , 'variant_mixed_col' ]:
250- variant_data [key ] = None if sample_data [key ] is None else json .dumps (sample_data [key ])
251-
252- sa_obj = table (** variant_data )
246+
247+ sa_obj = table (** sample_data )
253248 session = Session (engine )
254249 session .add (sa_obj )
255250 session .commit ()
256251
257252 stmt = select (table ).where (table .int_col == 1 )
258-
259253 result = session .scalar (stmt )
260-
261254 compare = {key : getattr (result , key ) for key in sample_data .keys ()}
262255 # Parse JSON values back to original format for comparison
263256 for key in ['variant_simple_col' , 'variant_nested_col' , 'variant_array_col' , 'variant_mixed_col' ]:
264257 if compare [key ] is not None :
265258 compare [key ] = json .loads (compare [key ])
259+
266260 assert self ._recursive_compare (compare , sample_data )
267261
268262 def test_variant_table_creation_pandas (self ):
269263 table , sample_data = self .sample_variant_table ()
270264
271265 with self .table_context (table ) as engine :
272- # Pre-serialize variant data for pandas
273- variant_data = sample_data .copy ()
274- for key in ['variant_simple_col' , 'variant_nested_col' , 'variant_array_col' , 'variant_mixed_col' ]:
275- variant_data [key ] = None if sample_data [key ] is None else json .dumps (sample_data [key ])
276266
277- # Insert the data into the table
278- df = pd .DataFrame ([variant_data ])
267+ df = pd .DataFrame ([sample_data ])
279268 dtype_mapping = {
280269 "variant_simple_col" : DatabricksVariant ,
281270 "variant_nested_col" : DatabricksVariant ,
@@ -284,7 +273,6 @@ def test_variant_table_creation_pandas(self):
284273 }
285274 df .to_sql (table .__tablename__ , engine , if_exists = "append" , index = False , dtype = dtype_mapping )
286275
287- # Read the data from the table
288276 stmt = select (table )
289277 df_result = pd .read_sql (stmt , engine )
290278 result_dict = df_result .iloc [0 ].to_dict ()
0 commit comments