@@ -197,6 +197,23 @@ def tearDownClass(cls):
197197 ReusedPySparkTestCase .tearDownClass ()
198198 cls .spark .stop ()
199199
@contextmanager
def sql_conf(self, key, value):
    """
    Context manager that temporarily applies a SQL configuration value for a
    test and restores the previous state (including "unset") on exit.
    """
    previous = self.spark.conf.get(key, None)
    self.spark.conf.set(key, value)
    try:
        yield
    finally:
        if previous is not None:
            self.spark.conf.set(key, previous)
        else:
            self.spark.conf.unset(key)
200217 def assertPandasEqual (self , expected , result ):
201218 msg = ("DataFrames are not equal: " +
202219 "\n \n Expected:\n %s\n %s" % (expected , expected .dtypes ) +
@@ -3460,6 +3477,8 @@ def setUpClass(cls):
34603477
34613478 cls .spark .conf .set ("spark.sql.session.timeZone" , tz )
34623479 cls .spark .conf .set ("spark.sql.execution.arrow.enabled" , "true" )
3480+ # Disable fallback by default to easily detect the failures.
3481+ cls .spark .conf .set ("spark.sql.execution.arrow.fallback.enabled" , "false" )
34633482 cls .schema = StructType ([
34643483 StructField ("1_str_t" , StringType (), True ),
34653484 StructField ("2_int_t" , IntegerType (), True ),
@@ -3495,22 +3514,10 @@ def create_pandas_data_frame(self):
34953514 data_dict ["4_float_t" ] = np .float32 (data_dict ["4_float_t" ])
34963515 return pd .DataFrame (data = data_dict )
34973516
@contextmanager
def arrow_fallback(self, enabled):
    """
    Context manager that temporarily sets the Arrow fallback configuration
    and restores the prior value (or unsets it) afterwards.
    """
    conf_key = "spark.sql.execution.arrow.fallback.enabled"
    previous = self.spark.conf.get(conf_key, None)
    self.spark.conf.set(conf_key, enabled)
    try:
        yield
    finally:
        if previous is not None:
            self.spark.conf.set(conf_key, previous)
        else:
            self.spark.conf.unset(conf_key)
35103517 def test_toPandas_fallback_enabled (self ):
35113518 import pandas as pd
35123519
3513- with self .arrow_fallback ( True ):
3520+ with self .sql_conf ( "spark.sql.execution.arrow.fallback.enabled" , True ):
35143521 schema = StructType ([StructField ("map" , MapType (StringType (), IntegerType ()), True )])
35153522 df = self .spark .createDataFrame ([({u'a' : 1 },)], schema = schema )
35163523 with QuietTest (self .sc ):
@@ -3525,7 +3532,7 @@ def test_toPandas_fallback_enabled(self):
35253532 self .assertPandasEqual (pdf , pd .DataFrame ({u'map' : [{u'a' : 1 }]}))
35263533
35273534 def test_toPandas_fallback_disabled (self ):
3528- with self .arrow_fallback ( False ):
3535+ with self .sql_conf ( "spark.sql.execution.arrow.fallback.enabled" , False ):
35293536 schema = StructType ([StructField ("map" , MapType (StringType (), IntegerType ()), True )])
35303537 df = self .spark .createDataFrame ([(None ,)], schema = schema )
35313538 with QuietTest (self .sc ):
@@ -3650,7 +3657,7 @@ def test_createDataFrame_with_incorrect_schema(self):
36503657 pdf = self .create_pandas_data_frame ()
36513658 wrong_schema = StructType (list (reversed (self .schema )))
36523659 with QuietTest (self .sc ):
3653- with self .assertRaisesRegexp (TypeError , ".*field.*can.not.accept.*type " ):
3660+ with self .assertRaisesRegexp (RuntimeError , ".*No cast.*string.*timestamp.* " ):
36543661 self .spark .createDataFrame (pdf , schema = wrong_schema )
36553662
36563663 def test_createDataFrame_with_names (self ):
@@ -3675,7 +3682,7 @@ def test_createDataFrame_column_name_encoding(self):
def test_createDataFrame_with_single_data_type(self):
    # createDataFrame from a pandas DataFrame with a bare atomic schema
    # ("int") is expected to fail hard with a RuntimeError mentioning the
    # unsupported IntegerType (fallback is disabled in setUpClass).
    import pandas as pd
    pdf = pd.DataFrame({"a": [1]})
    with QuietTest(self.sc):
        with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
            self.spark.createDataFrame(pdf, schema="int")
36803687
36813688 def test_createDataFrame_does_not_modify_input (self ):
@@ -3734,7 +3741,7 @@ def test_createDataFrame_fallback_enabled(self):
37343741 import pandas as pd
37353742
37363743 with QuietTest (self .sc ):
3737- with self .arrow_fallback ( True ):
3744+ with self .sql_conf ( "spark.sql.execution.arrow.fallback.enabled" , True ):
37383745 with warnings .catch_warnings (record = True ) as warns :
37393746 df = self .spark .createDataFrame (
37403747 pd .DataFrame ([[{u'a' : 1 }]]), "a: map<string, int>" )
def test_createDataFrame_fallback_disabled(self):
    # With Arrow fallback explicitly turned off, an Arrow-unsupported column
    # type (map) must raise instead of silently taking the non-Arrow path.
    import pandas as pd

    pdf = pd.DataFrame([[{u'a': 1}]])
    with QuietTest(self.sc):
        with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
            with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                self.spark.createDataFrame(pdf, "a: map<string, int>")
0 commit comments