Commit 7641fd0

Address comments
1 parent 7f87d25 commit 7641fd0

4 files changed: +33 -31 lines changed

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 5 deletions

@@ -1995,13 +1995,12 @@ def toPandas(self):
                 to_arrow_schema(self.schema)
             except Exception as e:
 
-                if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "false") \
+                if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \
                         .lower() == "true":
                     msg = (
                         "toPandas attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.enabled' is set to true; however, "
-                        "failed by the reason below:\n"
-                        " %s\n"
+                        "failed by the reason below:\n %s\n"
                         "Attempts non-optimization as "
                         "'spark.sql.execution.arrow.fallback.enabled' is set to "
                         "true." % _exception_message(e))
@@ -2011,8 +2010,7 @@ def toPandas(self):
                     msg = (
                         "toPandas attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.enabled' is set to true; however, "
-                        "failed by the reason below:\n"
-                        " %s\n"
+                        "failed by the reason below:\n %s\n"
                         "For fallback to non-optimization automatically, please set true to "
                         "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
                     raise RuntimeError(msg)
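For readers skimming the diff, the surrounding control flow in toPandas() is roughly the following. This is a minimal sketch, not pyspark source: convert_with_arrow, convert_without_arrow, and the dict-style conf are hypothetical stand-ins.

    import warnings

    def convert_with_arrow(df):
        # Hypothetical stand-in for the Arrow-optimized conversion path.
        raise NotImplementedError("pretend Arrow cannot handle this schema")

    def convert_without_arrow(df):
        # Hypothetical stand-in for the plain, non-optimized conversion path.
        return df

    def to_pandas(df, conf):
        # Same shape as toPandas() above: try Arrow first; on failure either
        # warn and fall back, or raise, depending on the fallback flag.
        if conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true":
            try:
                return convert_with_arrow(df)
            except Exception as e:
                # After this commit the fallback flag defaults to "true".
                if conf.get("spark.sql.execution.arrow.fallback.enabled",
                            "true").lower() == "true":
                    warnings.warn("Arrow optimization failed (%s); falling back "
                                  "to the non-optimized path." % e)
                else:
                    raise RuntimeError("Arrow optimization failed: %s" % e)
        return convert_without_arrow(df)

The only behavioral change in this file is the default read for the flag: an unset flag now means "warn and fall back" rather than "raise".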

python/pyspark/sql/session.py

Lines changed: 3 additions & 5 deletions

@@ -668,13 +668,12 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             except Exception as e:
                 from pyspark.util import _exception_message
 
-                if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "false") \
+                if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \
                         .lower() == "true":
                     msg = (
                         "createDataFrame attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.enabled' is set to true; however, "
-                        "failed by the reason below:\n"
-                        " %s\n"
+                        "failed by the reason below:\n %s\n"
                         "Attempts non-optimization as "
                         "'spark.sql.execution.arrow.fallback.enabled' is set to "
                         "true." % _exception_message(e))
@@ -683,8 +682,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
                     msg = (
                         "createDataFrame attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.enabled' is set to true; however, "
-                        "failed by the reason below:\n"
-                        " %s\n"
+                        "failed by the reason below:\n %s\n"
                         "For fallback to non-optimization automatically, please set true to "
                         "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e))
                     raise RuntimeError(msg)
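The same default change applies to createDataFrame. A usage sketch of both modes, assuming a running SparkSession named spark and a pyarrow build without MapType support (the same condition the tests in python/pyspark/sql/tests.py below rely on):

    import pandas as pd

    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    # New default: warn and fall back to the non-Arrow path on failure.
    spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")
    df = spark.createDataFrame(pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")

    # Strict mode: surface the Arrow failure instead of falling back.
    spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
    try:
        spark.createDataFrame(pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
    except RuntimeError as e:
        print("Arrow optimization failed:", e)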

python/pyspark/sql/tests.py

Lines changed: 25 additions & 18 deletions

@@ -197,6 +197,23 @@ def tearDownClass(cls):
         ReusedPySparkTestCase.tearDownClass()
         cls.spark.stop()
 
+    @contextmanager
+    def sql_conf(self, key, value):
+        """
+        A convenient context manager to test some configuration-specific logic.
+        This sets the configuration `key` to `value` and then restores it on exit.
+        """
+
+        orig_value = self.spark.conf.get(key, None)
+        self.spark.conf.set(key, value)
+        try:
+            yield
+        finally:
+            if orig_value is None:
+                self.spark.conf.unset(key)
+            else:
+                self.spark.conf.set(key, orig_value)
+
     def assertPandasEqual(self, expected, result):
         msg = ("DataFrames are not equal: " +
                "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
@@ -3460,6 +3477,8 @@ def setUpClass(cls):
 
         cls.spark.conf.set("spark.sql.session.timeZone", tz)
         cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+        # Disable fallback by default to easily detect the failures.
+        cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
         cls.schema = StructType([
             StructField("1_str_t", StringType(), True),
             StructField("2_int_t", IntegerType(), True),
@@ -3495,22 +3514,10 @@ def create_pandas_data_frame(self):
         data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
         return pd.DataFrame(data=data_dict)
 
-    @contextmanager
-    def arrow_fallback(self, enabled):
-        orig_value = self.spark.conf.get("spark.sql.execution.arrow.fallback.enabled", None)
-        self.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", enabled)
-        try:
-            yield
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.fallback.enabled")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", orig_value)
-
     def test_toPandas_fallback_enabled(self):
         import pandas as pd
 
-        with self.arrow_fallback(True):
+        with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", True):
             schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
             df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
             with QuietTest(self.sc):
@@ -3525,7 +3532,7 @@ def test_toPandas_fallback_enabled(self):
             self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_toPandas_fallback_disabled(self):
-        with self.arrow_fallback(False):
+        with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
             schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
             df = self.spark.createDataFrame([(None,)], schema=schema)
             with QuietTest(self.sc):
@@ -3650,7 +3657,7 @@ def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
         wrong_schema = StructType(list(reversed(self.schema)))
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+            with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"):
                 self.spark.createDataFrame(pdf, schema=wrong_schema)
 
     def test_createDataFrame_with_names(self):
@@ -3675,7 +3682,7 @@ def test_createDataFrame_column_name_encoding(self):
     def test_createDataFrame_with_single_data_type(self):
         import pandas as pd
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+            with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"):
                 self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
 
     def test_createDataFrame_does_not_modify_input(self):
@@ -3734,7 +3741,7 @@ def test_createDataFrame_fallback_enabled(self):
         import pandas as pd
 
         with QuietTest(self.sc):
-            with self.arrow_fallback(True):
+            with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", True):
                 with warnings.catch_warnings(record=True) as warns:
                     df = self.spark.createDataFrame(
                         pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
@@ -3750,7 +3757,7 @@ def test_createDataFrame_fallback_disabled(self):
         import pandas as pd
 
         with QuietTest(self.sc):
-            with self.arrow_fallback(False):
+            with self.sql_conf("spark.sql.execution.arrow.fallback.enabled", False):
                 with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                     self.spark.createDataFrame(
                         pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
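The sql_conf helper generalizes the removed arrow_fallback manager to any SQL configuration key. The same restore-on-exit pattern also works outside the test suite; a minimal standalone sketch, assuming a running SparkSession named spark:

    from contextlib import contextmanager

    @contextmanager
    def sql_conf(spark, key, value):
        # Set `key` to `value` for the duration of the block, then restore the
        # original value, or unset the key if it had no value before.
        orig_value = spark.conf.get(key, None)
        spark.conf.set(key, value)
        try:
            yield
        finally:
            if orig_value is None:
                spark.conf.unset(key)
            else:
                spark.conf.set(key, orig_value)

    # Example: run one conversion with fallback disabled, then restore it.
    # with sql_conf(spark, "spark.sql.execution.arrow.fallback.enabled", "false"):
    #     df.toPandas()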

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 2 additions & 3 deletions

@@ -1070,9 +1070,8 @@ object SQLConf {
 
   val ARROW_FALLBACK_ENABLE =
     buildConf("spark.sql.execution.arrow.fallback.enabled")
-      .doc("When true, the optimization by 'spark.sql.execution.arrow.enabled' " +
-        "could be disabled when it is unable to be used, and fallback to " +
-        "non-optimization.")
+      .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " +
+        "fallback automatically to non-optimized implementations if an error occurs.")
       .booleanConf
       .createWithDefault(true)
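The Scala-side default was already true; this hunk only rewrites the doc string, and the Python callers above now read the same "true" default instead of a hardcoded "false". A quick way to confirm the effective default from PySpark (assuming a running SparkSession named spark, with the flag not set explicitly):

    spark.conf.unset("spark.sql.execution.arrow.fallback.enabled")
    print(spark.conf.get("spark.sql.execution.arrow.fallback.enabled"))  # 'true'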
