@@ -48,7 +48,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {

   val bufferSize = conf.getInt("spark.buffer.size", 65536)
-  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
+  val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)

   override def getPartitions: Array[Partition] = firstParent.partitions

@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = PythonRunner(func, bufferSize, reuse_worker)
+    val runner = PythonRunner(func, bufferSize, reuseWorker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -83,318 +83,9 @@ private[spark] case class PythonFunction(
  */
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

-/**
- * Enumerate the type of command that will be sent to the Python worker
- */
-private[spark] object PythonEvalType {
-  val NON_UDF = 0
-  val SQL_BATCHED_UDF = 1
-  val SQL_PANDAS_UDF = 2
-}
-
-private[spark] object PythonRunner {
-  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
-    new PythonRunner(
-      Seq(ChainedPythonFunctions(Seq(func))),
-      bufferSize,
-      reuse_worker,
-      PythonEvalType.NON_UDF,
-      Array(Array(0)))
-  }
-}
-
-/**
- * A helper class to run Python mapPartition/UDFs in Spark.
- *
- * funcs is a list of independent Python functions, each one of them is a list of chained Python
- * functions (from bottom to top).
- */
-private[spark] class PythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
-    bufferSize: Int,
-    reuse_worker: Boolean,
-    evalType: Int,
-    argOffsets: Array[Array[Int]])
-  extends Logging {
-
-  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
-
-  // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.funcs.head.envVars
-  private val pythonExec = funcs.head.funcs.head.pythonExec
-  private val pythonVer = funcs.head.funcs.head.pythonVer
-
-  // TODO: support accumulator in multiple UDF
-  private val accumulator = funcs.head.funcs.head.accumulator
-
-  def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext): Iterator[Array[Byte]] = {
-    val startTime = System.currentTimeMillis
-    val env = SparkEnv.get
-    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
-    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
-    if (reuse_worker) {
-      envVars.put("SPARK_REUSE_WORKER", "1")
-    }
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
-    // Whether is the worker released into idle pool
-    @volatile var released = false
-
-    // Start a thread to feed the process input from our parent's iterator
-    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
-
-    context.addTaskCompletionListener { context =>
-      writerThread.shutdownOnTaskCompletion()
-      if (!reuse_worker || !released) {
-        try {
-          worker.close()
-        } catch {
-          case e: Exception =>
-            logWarning("Failed to close worker socket", e)
-        }
-      }
-    }
-
-    writerThread.start()
-    new MonitorThread(env, worker, context).start()
-
-    // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
-    val stdoutIterator = new Iterator[Array[Byte]] {
-      override def next(): Array[Byte] = {
-        val obj = _nextObj
-        if (hasNext) {
-          _nextObj = read()
-        }
-        obj
-      }
-
-      private def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              obj
-            case 0 => Array.empty[Byte]
-            case SpecialLengths.TIMING_DATA =>
-              // Timing data from worker
-              val bootTime = stream.readLong()
-              val initTime = stream.readLong()
-              val finishTime = stream.readLong()
-              val boot = bootTime - startTime
-              val init = initTime - bootTime
-              val finish = finishTime - initTime
-              val total = finishTime - startTime
-              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
-                init, finish))
-              val memoryBytesSpilled = stream.readLong()
-              val diskBytesSpilled = stream.readLong()
-              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              // Signals that an exception has been thrown in python
-              val exLength = stream.readInt()
-              val obj = new Array[Byte](exLength)
-              stream.readFully(obj)
-              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
-                writerThread.exception.getOrElse(null))
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              // We've finished the data section of the output, but we can still
-              // read some accumulator updates:
-              val numAccumulatorUpdates = stream.readInt()
-              (1 to numAccumulatorUpdates).foreach { _ =>
-                val updateLen = stream.readInt()
-                val update = new Array[Byte](updateLen)
-                stream.readFully(update)
-                accumulator.add(update)
-              }
-              // Check whether the worker is ready to be re-used.
-              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-                if (reuse_worker) {
-                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-                  released = true
-                }
-              }
-              null
-          }
-        } catch {
-
-          case e: Exception if context.isInterrupted =>
-            logDebug("Exception thrown after task interruption", e)
-            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
-
-          case e: Exception if env.isStopped =>
-            logDebug("Exception thrown after context is stopped", e)
-            null  // exit silently
-
-          case e: Exception if writerThread.exception.isDefined =>
-            logError("Python worker exited unexpectedly (crashed)", e)
-            logError("This may have been caused by a prior exception:", writerThread.exception.get)
-            throw writerThread.exception.get
-
-          case eof: EOFException =>
-            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
-        }
-      }
-
-      var _nextObj = read()
-
-      override def hasNext: Boolean = _nextObj != null
-    }
-    new InterruptibleIterator(context, stdoutIterator)
-  }
-
-  /**
-   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
-   * Python process.
-   */
-  class WriterThread(
-      env: SparkEnv,
-      worker: Socket,
-      inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext)
-    extends Thread(s"stdout writer for $pythonExec") {
-
-    @volatile private var _exception: Exception = null
-
-    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
-    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
-
-    setDaemon(true)
-
-    /** Contains the exception thrown while writing the parent iterator to the Python process. */
-    def exception: Option[Exception] = Option(_exception)
-
-    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
-    def shutdownOnTaskCompletion() {
-      assert(context.isCompleted)
-      this.interrupt()
-    }
-
-    override def run(): Unit = Utils.logUncaughtExceptions {
-      try {
-        TaskContext.setTaskContext(context)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-        // Partition index
-        dataOut.writeInt(partitionIndex)
-        // Python version of driver
-        PythonRDD.writeUTF(pythonVer, dataOut)
-        // Write out the TaskContextInfo
-        dataOut.writeInt(context.stageId())
-        dataOut.writeInt(context.partitionId())
-        dataOut.writeInt(context.attemptNumber())
-        dataOut.writeLong(context.taskAttemptId())
-        // sparkFilesDir
-        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
-        // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size)
-        for (include <- pythonIncludes) {
-          PythonRDD.writeUTF(include, dataOut)
-        }
-        // Broadcast variables
-        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.map(_.id).toSet
-        // number of different broadcasts
-        val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
-        dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(-bid - 1)  // bid >= 0
-          oldBids.remove(bid)
-        }
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
-            // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
-          }
-        }
-        dataOut.flush()
-        // Serialized command:
-        dataOut.writeInt(evalType)
-        if (evalType != PythonEvalType.NON_UDF) {
-          dataOut.writeInt(funcs.length)
-          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
-            dataOut.writeInt(offsets.length)
-            offsets.foreach { offset =>
-              dataOut.writeInt(offset)
-            }
-            dataOut.writeInt(chained.funcs.length)
-            chained.funcs.foreach { f =>
-              dataOut.writeInt(f.command.length)
-              dataOut.write(f.command)
-            }
-          }
-        } else {
-          val command = funcs.head.funcs.head.command
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-        }
-        // Data values
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-        dataOut.flush()
-      } catch {
-        case e: Exception if context.isCompleted || context.isInterrupted =>
-          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-
-        case e: Exception =>
-          // We must avoid throwing exceptions here, because the thread uncaught exception handler
-          // will kill the whole executor (see org.apache.spark.executor.Executor).
-          _exception = e
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-      }
-    }
-  }
-
-  /**
-   * It is necessary to have a monitor thread for python workers if the user cancels with
-   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-   * threads can block indefinitely.
-   */
-  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
-    extends Thread(s"Worker Monitor for $pythonExec") {
-
-    setDaemon(true)
-
-    override def run() {
-      // Kill the worker if it is interrupted, checking until task completion.
-      // TODO: This has a race condition if interruption occurs, as completed may still become true.
-      while (!context.isInterrupted && !context.isCompleted) {
-        Thread.sleep(2000)
-      }
-      if (!context.isCompleted) {
-        try {
-          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
-        } catch {
-          case e: Exception =>
-            logError("Exception when trying to kill worker", e)
-        }
-      }
-    }
-  }
-}
-
 /** Thrown for exceptions in user Python code. */
-private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause)
+private[spark] class PythonException(msg: String, cause: Exception)
+  extends RuntimeException(msg, cause)

 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
@@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
   val asJavaPairRDD: JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }

-private object SpecialLengths {
-  val END_OF_DATA_SECTION = -1
-  val PYTHON_EXCEPTION_THROWN = -2
-  val TIMING_DATA = -3
-  val END_OF_STREAM = -4
-  val NULL = -5
-}
-
 private[spark] object PythonRDD extends Logging {

   // remember the broadcasts sent to each worker