
Commit 2b4099f

ByronHsu authored and HyukjinKwon committed
[SPARK-37137][PYTHON] Inline type hints for python/pyspark/conf.py
### What changes were proposed in this pull request?
Inline type hints for python/pyspark/conf.py.

### Why are the changes needed?
Currently the type hints live in the stub file python/pyspark/conf.pyi, which doesn't support type checking within function bodies. Inlining the hints into conf.py enables that checking.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #34411 from ByronHsu/SPARK-37137.

Authored-by: Byron <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
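For readers unfamiliar with the distinction, the sketch below (illustrative only, not code from this commit; the simplified class is a stand-in for the real `SparkConf`) contrasts stub-only annotations with inlined ones:

```python
# Stub-only typing (the old python/pyspark/conf.pyi approach): signatures are
# annotated in the .pyi file, but mypy never looks inside the .py function body.
#
#   # conf.pyi (deleted by this commit)
#   class SparkConf:
#       def set(self, key: str, value: str) -> "SparkConf": ...
#
# Inlined typing (what this commit does): the same annotations live directly in
# conf.py, so statements inside the body are type-checked as well.
class SparkConf:
    def __init__(self) -> None:
        self._conf: dict = {}

    def set(self, key: str, value: str) -> "SparkConf":
        self._conf[key] = str(value)  # now visible to the type checker
        return self
```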
1 parent 8f20398 commit 2b4099f

File tree

3 files changed: +44 −69 lines changed


python/pyspark/conf.py

Lines changed: 42 additions & 23 deletions
```diff
@@ -18,10 +18,12 @@
 __all__ = ['SparkConf']
 
 import sys
+from typing import Dict, List, Optional, Tuple, cast, overload
 
+from py4j.java_gateway import JVMView, JavaObject  # type: ignore[import]
 
-class SparkConf(object):
 
+class SparkConf(object):
     """
     Configuration for a Spark application. Used to set various Spark
     parameters as key-value pairs.
@@ -105,15 +107,19 @@ class SparkConf(object):
     spark.home=/path
     """
 
-    def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
+    _jconf: Optional[JavaObject]
+    _conf: Optional[Dict[str, str]]
+
+    def __init__(self, loadDefaults: bool = True, _jvm: Optional[JVMView] = None,
+                 _jconf: Optional[JavaObject] = None):
         """
         Create a new Spark configuration.
         """
         if _jconf:
             self._jconf = _jconf
         else:
             from pyspark.context import SparkContext
-            _jvm = _jvm or SparkContext._jvm
+            _jvm = _jvm or SparkContext._jvm  # type: ignore[attr-defined]
 
             if _jvm is not None:
                 # JVM is created, so create self._jconf directly through JVM
@@ -124,48 +130,58 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
             self._jconf = None
             self._conf = {}
 
-    def set(self, key, value):
+    def set(self, key: str, value: str) -> "SparkConf":
         """Set a configuration property."""
         # Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet.
         if self._jconf is not None:
             self._jconf.set(key, str(value))
         else:
+            assert self._conf is not None
             self._conf[key] = str(value)
         return self
 
-    def setIfMissing(self, key, value):
+    def setIfMissing(self, key: str, value: str) -> "SparkConf":
         """Set a configuration property, if not already set."""
         if self.get(key) is None:
             self.set(key, value)
         return self
 
-    def setMaster(self, value):
+    def setMaster(self, value: str) -> "SparkConf":
         """Set master URL to connect to."""
         self.set("spark.master", value)
         return self
 
-    def setAppName(self, value):
+    def setAppName(self, value: str) -> "SparkConf":
         """Set application name."""
         self.set("spark.app.name", value)
         return self
 
-    def setSparkHome(self, value):
+    def setSparkHome(self, value: str) -> "SparkConf":
         """Set path where Spark is installed on worker nodes."""
         self.set("spark.home", value)
         return self
 
-    def setExecutorEnv(self, key=None, value=None, pairs=None):
+    @overload
+    def setExecutorEnv(self, key: str, value: str) -> "SparkConf":
+        ...
+
+    @overload
+    def setExecutorEnv(self, *, pairs: List[Tuple[str, str]]) -> "SparkConf":
+        ...
+
+    def setExecutorEnv(self, key: Optional[str] = None, value: Optional[str] = None,
+                       pairs: Optional[List[Tuple[str, str]]] = None) -> "SparkConf":
         """Set an environment variable to be passed to executors."""
         if (key is not None and pairs is not None) or (key is None and pairs is None):
             raise RuntimeError("Either pass one key-value pair or a list of pairs")
         elif key is not None:
-            self.set("spark.executorEnv." + key, value)
+            self.set("spark.executorEnv.{}".format(key), cast(str, value))
         elif pairs is not None:
             for (k, v) in pairs:
-                self.set("spark.executorEnv." + k, v)
+                self.set("spark.executorEnv.{}".format(k), v)
         return self
 
-    def setAll(self, pairs):
+    def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf":
         """
         Set multiple parameters, passed as a list of key-value pairs.
 
@@ -178,49 +194,52 @@ def setAll(self, pairs):
             self.set(k, v)
         return self
 
-    def get(self, key, defaultValue=None):
+    def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]:
         """Get the configured value for some key, or return a default otherwise."""
-        if defaultValue is None:   # Py4J doesn't call the right get() if we pass None
+        if defaultValue is None:  # Py4J doesn't call the right get() if we pass None
             if self._jconf is not None:
                 if not self._jconf.contains(key):
                     return None
                 return self._jconf.get(key)
             else:
-                if key not in self._conf:
-                    return None
-                return self._conf[key]
+                assert self._conf is not None
+                return self._conf.get(key, None)
         else:
             if self._jconf is not None:
                 return self._jconf.get(key, defaultValue)
             else:
+                assert self._conf is not None
                 return self._conf.get(key, defaultValue)
 
-    def getAll(self):
+    def getAll(self) -> List[Tuple[str, str]]:
         """Get all values as a list of key-value pairs."""
         if self._jconf is not None:
-            return [(elem._1(), elem._2()) for elem in self._jconf.getAll()]
+            return [(elem._1(), elem._2()) for elem in cast(JavaObject, self._jconf).getAll()]
         else:
-            return self._conf.items()
+            assert self._conf is not None
+            return list(self._conf.items())
 
-    def contains(self, key):
+    def contains(self, key: str) -> bool:
         """Does this configuration contain a given key?"""
         if self._jconf is not None:
            return self._jconf.contains(key)
         else:
+            assert self._conf is not None
            return key in self._conf
 
-    def toDebugString(self):
+    def toDebugString(self) -> str:
         """
         Returns a printable version of the configuration, as a list of
         key=value pairs, one per line.
         """
         if self._jconf is not None:
             return self._jconf.toDebugString()
         else:
+            assert self._conf is not None
             return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items())
 
 
-def _test():
+def _test() -> None:
     import doctest
     (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
     if failure_count:
```
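A short usage sketch (not part of the diff) of what the inlined annotations buy at call sites; it assumes pyspark is installed with these hints, and it runs without a JVM since `SparkConf` falls back to its internal dict:

```python
from pyspark import SparkConf

conf = SparkConf(loadDefaults=False)

# Chained setters all resolve to `SparkConf`, so the whole chain is typed.
conf.set("spark.app.name", "demo").setMaster("local[2]")

# The @overload pair accepts either a single key/value pair ...
conf.setExecutorEnv("PYTHONHASHSEED", "0")
# ... or a keyword-only list of pairs; mixing the two is flagged by mypy.
conf.setExecutorEnv(pairs=[("VAR_A", "1"), ("VAR_B", "2")])

# get() is annotated Optional[str], so callers must handle the None case
# (or cast, as session.py does) before calling str methods on the result.
value = conf.get("spark.app.name")
if value is not None:
    print(value.upper())
```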

python/pyspark/conf.pyi

Lines changed: 0 additions & 44 deletions
This file was deleted.

python/pyspark/sql/session.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -610,7 +610,7 @@ def _create_shell_session() -> "SparkSession":
         try:
             # Try to access HiveConf, it will raise exception if Hive is not added
             conf = SparkConf()
-            if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
+            if cast(str, conf.get('spark.sql.catalogImplementation', 'hive')).lower() == 'hive':
                 (SparkContext._jvm  # type: ignore[attr-defined]
                     .org.apache.hadoop.hive.conf.HiveConf())
                 return SparkSession.builder\
@@ -619,7 +619,7 @@ def _create_shell_session() -> "SparkSession":
             else:
                 return SparkSession.builder.getOrCreate()
         except (py4j.protocol.Py4JError, TypeError):
-            if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
+            if cast(str, conf.get('spark.sql.catalogImplementation', '')).lower() == 'hive':
                 warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
                               "please make sure you build spark with hive")
```
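The `cast` above is needed because `SparkConf.get` is now annotated to return `Optional[str]`; a minimal, self-contained illustration (the helper below is hypothetical, not from session.py):

```python
from typing import Optional, cast


def catalog_is_hive(get_result: Optional[str]) -> bool:
    # Without the cast, mypy flags the direct call even though a default was
    # passed to conf.get():
    #   error: Item "None" of "Optional[str]" has no attribute "lower"
    # return get_result.lower() == 'hive'

    # The commit narrows the type explicitly, mirroring session.py:
    return cast(str, get_result).lower() == 'hive'


print(catalog_is_hive('Hive'))       # True
print(catalog_is_hive('in-memory'))  # False
```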
