Skip to content

Commit fb47432

Browse files
committed
address comments
1 parent 7c6d2d5 commit fb47432

File tree

3 files changed

+50
-49
lines changed

3 files changed

+50
-49
lines changed

python/pyspark/sql/functions.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2713,25 +2713,6 @@ def from_csv(col, schema, options={}):
27132713
return Column(jc)
27142714

27152715

2716-
@since(3.0)
2717-
def _getActiveSession():
2718-
"""
2719-
Returns the active SparkSession for the current thread
2720-
This method is not intended for user to call directly.
2721-
It is only used for getActiveSession method in session.py
2722-
"""
2723-
from pyspark.sql import SparkSession
2724-
sc = SparkContext._active_spark_context
2725-
if sc is None:
2726-
return None
2727-
else:
2728-
if sc._jvm.SparkSession.getActiveSession().isDefined():
2729-
SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
2730-
return SparkSession._activeSession
2731-
else:
2732-
return None
2733-
2734-
27352716
# ---------------------------- User Defined Function ----------------------------------
27362717

27372718
class PandasUDFType(object):

python/pyspark/sql/session.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,16 @@ def getActiveSession(cls):
270270
>>> df.select("age").collect()
271271
[Row(age=1)]
272272
"""
273-
from pyspark.sql import functions
274-
return functions._getActiveSession()
273+
from pyspark import SparkContext
274+
sc = SparkContext._active_spark_context
275+
if sc is None:
276+
return None
277+
else:
278+
if sc._jvm.SparkSession.getActiveSession().isDefined():
279+
SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
280+
return SparkSession._activeSession
281+
else:
282+
return None
275283

276284
@property
277285
@since(2.0)

python/pyspark/sql/tests.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3926,7 +3926,7 @@ def test_default_and_active_session(self):
39263926
finally:
39273927
spark.stop()
39283928

3929-
def test_config_option_propagated_to_existing_SparkSession(self):
3929+
def test_config_option_propagated_to_existing_session(self):
39303930
session1 = SparkSession.builder \
39313931
.master("local") \
39323932
.config("spark-config1", "a") \
@@ -3968,38 +3968,50 @@ def test_create_new_session_if_old_session_stopped(self):
39683968
def test_active_session_with_None_and_not_None_context(self):
39693969
from pyspark.context import SparkContext
39703970
from pyspark.conf import SparkConf
3971-
sc = SparkContext._active_spark_context
3972-
self.assertEqual(sc, None)
3973-
activeSession = SparkSession.getActiveSession()
3974-
self.assertEqual(activeSession, None)
3975-
sparkConf = SparkConf()
3976-
sc = SparkContext.getOrCreate(sparkConf)
3977-
activeSession = sc._jvm.SparkSession.getActiveSession()
3978-
self.assertFalse(activeSession.isDefined())
3979-
session = SparkSession(sc)
3980-
activeSession = sc._jvm.SparkSession.getActiveSession()
3981-
self.assertTrue(activeSession.isDefined())
3982-
activeSession2 = SparkSession.getActiveSession()
3983-
self.assertNotEqual(activeSession2, None)
3971+
sc = None
3972+
session = None
3973+
try:
3974+
sc = SparkContext._active_spark_context
3975+
self.assertEqual(sc, None)
3976+
activeSession = SparkSession.getActiveSession()
3977+
self.assertEqual(activeSession, None)
3978+
sparkConf = SparkConf()
3979+
sc = SparkContext.getOrCreate(sparkConf)
3980+
activeSession = sc._jvm.SparkSession.getActiveSession()
3981+
self.assertFalse(activeSession.isDefined())
3982+
session = SparkSession(sc)
3983+
activeSession = sc._jvm.SparkSession.getActiveSession()
3984+
self.assertTrue(activeSession.isDefined())
3985+
activeSession2 = SparkSession.getActiveSession()
3986+
self.assertNotEqual(activeSession2, None)
3987+
finally:
3988+
if session is not None:
3989+
session.stop()
3990+
if sc is not None:
3991+
sc.stop()
39843992

39853993

39863994
class SparkSessionTests3(ReusedSQLTestCase):
39873995

39883996
def test_get_active_session_after_create_dataframe(self):
3989-
activeSession1 = SparkSession.getActiveSession()
3990-
session1 = self.spark
3991-
self.assertEqual(session1, activeSession1)
3992-
session2 = self.spark.newSession()
3993-
activeSession2 = SparkSession.getActiveSession()
3994-
self.assertEqual(session1, activeSession2)
3995-
self.assertNotEqual(session2, activeSession2)
3996-
session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
3997-
activeSession3 = SparkSession.getActiveSession()
3998-
self.assertEqual(session2, activeSession3)
3999-
session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
4000-
activeSession4 = SparkSession.getActiveSession()
4001-
self.assertEqual(session1, activeSession4)
4002-
session2.stop()
3997+
session2 = None
3998+
try:
3999+
activeSession1 = SparkSession.getActiveSession()
4000+
session1 = self.spark
4001+
self.assertEqual(session1, activeSession1)
4002+
session2 = self.spark.newSession()
4003+
activeSession2 = SparkSession.getActiveSession()
4004+
self.assertEqual(session1, activeSession2)
4005+
self.assertNotEqual(session2, activeSession2)
4006+
session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
4007+
activeSession3 = SparkSession.getActiveSession()
4008+
self.assertEqual(session2, activeSession3)
4009+
session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
4010+
activeSession4 = SparkSession.getActiveSession()
4011+
self.assertEqual(session1, activeSession4)
4012+
finally:
4013+
if session2 is not None:
4014+
session2.stop()
40034015

40044016

40054017
class UDFInitializationTests(unittest.TestCase):

0 commit comments

Comments
 (0)