
Commit 09cbf3d

ueshin authored and HyukjinKwon committed
[SPARK-22125][PYSPARK][SQL] Enable Arrow Stream format for vectorized UDF.
## What changes were proposed in this pull request?

Currently we use the Arrow File format to communicate with the Python worker when invoking a vectorized UDF, but we can use the Arrow Stream format instead. This PR replaces the Arrow File format with the Arrow Stream format.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #19349 from ueshin/issues/SPARK-22125.
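The practical difference between the two Arrow IPC formats: the stream format frames record batches one after another with no footer, so batches can be produced and consumed incrementally, while the file format appends a footer that indexes the batches and targets random access over a completed file. Below is a minimal sketch of that contrast using pyarrow; it is a standalone illustration, not Spark's internal code, and it uses today's `pa.ipc` helper names rather than the 2017-era writer classes.

```python
import io
import pyarrow as pa

# Illustrative sketch, not Spark internals: contrast the Arrow stream and
# file IPC formats for the same record batch.
batch = pa.RecordBatch.from_arrays([pa.array([1.0, 2.0, 3.0])], names=["v"])

# Stream format: batches are framed back to back with no footer, so a reader
# can start consuming before the writer has finished.
stream_sink = io.BytesIO()
stream_writer = pa.ipc.new_stream(stream_sink, batch.schema)
stream_writer.write_batch(batch)
stream_writer.close()

stream_reader = pa.ipc.open_stream(stream_sink.getvalue())
for b in stream_reader:
    print(b.to_pydict())

# File format: the same batches plus a footer indexing them, intended for
# random access to a complete file rather than incremental consumption.
file_sink = io.BytesIO()
file_writer = pa.ipc.new_file(file_sink, batch.schema)
file_writer.write_batch(batch)
file_writer.close()

file_reader = pa.ipc.open_file(file_sink.getvalue())
print(file_reader.num_record_batches, file_reader.get_batch(0).to_pydict())
```

A vectorized (pandas) UDF is the user-facing feature whose JVM-to-worker channel this change affects. A minimal usage sketch, assuming Spark 2.3+ with pyarrow installed; the `plus_one` function and the `v` column are invented for the example:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf

spark = SparkSession.builder.appName("pandas-udf-sketch").getOrCreate()

# Hypothetical vectorized UDF: `v` arrives as a pandas.Series, carried between
# the JVM and the Python worker as Arrow record batches.
@pandas_udf("double")
def plus_one(v):
    return v + 1

df = spark.range(10).select(col("id").cast("double").alias("v"))
df.select(plus_one(col("v"))).show()
```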
1 parent 12e740b commit 09cbf3d

File tree

9 files changed: +825, -372 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 321 deletions
@@ -48,7 +48,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {

   val bufferSize = conf.getInt("spark.buffer.size", 65536)
-  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
+  val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)

   override def getPartitions: Array[Partition] = firstParent.partitions

@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = PythonRunner(func, bufferSize, reuse_worker)
+    val runner = PythonRunner(func, bufferSize, reuseWorker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -83,318 +83,9 @@ private[spark] case class PythonFunction(
  */
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

-/**
- * Enumerate the type of command that will be sent to the Python worker
- */
-private[spark] object PythonEvalType {
-  val NON_UDF = 0
-  val SQL_BATCHED_UDF = 1
-  val SQL_PANDAS_UDF = 2
-}
-
-private[spark] object PythonRunner {
-  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
-    new PythonRunner(
-      Seq(ChainedPythonFunctions(Seq(func))),
-      bufferSize,
-      reuse_worker,
-      PythonEvalType.NON_UDF,
-      Array(Array(0)))
-  }
-}
-
-/**
- * A helper class to run Python mapPartition/UDFs in Spark.
- *
- * funcs is a list of independent Python functions, each one of them is a list of chained Python
- * functions (from bottom to top).
- */
-private[spark] class PythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
-    bufferSize: Int,
-    reuse_worker: Boolean,
-    evalType: Int,
-    argOffsets: Array[Array[Int]])
-  extends Logging {
-
-  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
-
-  // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.funcs.head.envVars
-  private val pythonExec = funcs.head.funcs.head.pythonExec
-  private val pythonVer = funcs.head.funcs.head.pythonVer
-
-  // TODO: support accumulator in multiple UDF
-  private val accumulator = funcs.head.funcs.head.accumulator
-
-  def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext): Iterator[Array[Byte]] = {
-    val startTime = System.currentTimeMillis
-    val env = SparkEnv.get
-    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
-    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
-    if (reuse_worker) {
-      envVars.put("SPARK_REUSE_WORKER", "1")
-    }
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
-    // Whether is the worker released into idle pool
-    @volatile var released = false
-
-    // Start a thread to feed the process input from our parent's iterator
-    val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)
-
-    context.addTaskCompletionListener { context =>
-      writerThread.shutdownOnTaskCompletion()
-      if (!reuse_worker || !released) {
-        try {
-          worker.close()
-        } catch {
-          case e: Exception =>
-            logWarning("Failed to close worker socket", e)
-        }
-      }
-    }
-
-    writerThread.start()
-    new MonitorThread(env, worker, context).start()
-
-    // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
-    val stdoutIterator = new Iterator[Array[Byte]] {
-      override def next(): Array[Byte] = {
-        val obj = _nextObj
-        if (hasNext) {
-          _nextObj = read()
-        }
-        obj
-      }
-
-      private def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              obj
-            case 0 => Array.empty[Byte]
-            case SpecialLengths.TIMING_DATA =>
-              // Timing data from worker
-              val bootTime = stream.readLong()
-              val initTime = stream.readLong()
-              val finishTime = stream.readLong()
-              val boot = bootTime - startTime
-              val init = initTime - bootTime
-              val finish = finishTime - initTime
-              val total = finishTime - startTime
-              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
-                init, finish))
-              val memoryBytesSpilled = stream.readLong()
-              val diskBytesSpilled = stream.readLong()
-              context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-              context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              // Signals that an exception has been thrown in python
-              val exLength = stream.readInt()
-              val obj = new Array[Byte](exLength)
-              stream.readFully(obj)
-              throw new PythonException(new String(obj, StandardCharsets.UTF_8),
-                writerThread.exception.getOrElse(null))
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              // We've finished the data section of the output, but we can still
-              // read some accumulator updates:
-              val numAccumulatorUpdates = stream.readInt()
-              (1 to numAccumulatorUpdates).foreach { _ =>
-                val updateLen = stream.readInt()
-                val update = new Array[Byte](updateLen)
-                stream.readFully(update)
-                accumulator.add(update)
-              }
-              // Check whether the worker is ready to be re-used.
-              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-                if (reuse_worker) {
-                  env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-                  released = true
-                }
-              }
-              null
-          }
-        } catch {
-
-          case e: Exception if context.isInterrupted =>
-            logDebug("Exception thrown after task interruption", e)
-            throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
-
-          case e: Exception if env.isStopped =>
-            logDebug("Exception thrown after context is stopped", e)
-            null // exit silently
-
-          case e: Exception if writerThread.exception.isDefined =>
-            logError("Python worker exited unexpectedly (crashed)", e)
-            logError("This may have been caused by a prior exception:", writerThread.exception.get)
-            throw writerThread.exception.get
-
-          case eof: EOFException =>
-            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
-        }
-      }
-
-      var _nextObj = read()
-
-      override def hasNext: Boolean = _nextObj != null
-    }
-    new InterruptibleIterator(context, stdoutIterator)
-  }
-
-  /**
-   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
-   * Python process.
-   */
-  class WriterThread(
-      env: SparkEnv,
-      worker: Socket,
-      inputIterator: Iterator[_],
-      partitionIndex: Int,
-      context: TaskContext)
-    extends Thread(s"stdout writer for $pythonExec") {
-
-    @volatile private var _exception: Exception = null
-
-    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
-    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
-
-    setDaemon(true)
-
-    /** Contains the exception thrown while writing the parent iterator to the Python process. */
-    def exception: Option[Exception] = Option(_exception)
-
-    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
-    def shutdownOnTaskCompletion() {
-      assert(context.isCompleted)
-      this.interrupt()
-    }
-
-    override def run(): Unit = Utils.logUncaughtExceptions {
-      try {
-        TaskContext.setTaskContext(context)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-        // Partition index
-        dataOut.writeInt(partitionIndex)
-        // Python version of driver
-        PythonRDD.writeUTF(pythonVer, dataOut)
-        // Write out the TaskContextInfo
-        dataOut.writeInt(context.stageId())
-        dataOut.writeInt(context.partitionId())
-        dataOut.writeInt(context.attemptNumber())
-        dataOut.writeLong(context.taskAttemptId())
-        // sparkFilesDir
-        PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
-        // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size)
-        for (include <- pythonIncludes) {
-          PythonRDD.writeUTF(include, dataOut)
-        }
-        // Broadcast variables
-        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.map(_.id).toSet
-        // number of different broadcasts
-        val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
-        dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(- bid - 1) // bid >= 0
-          oldBids.remove(bid)
-        }
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
-            // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
-          }
-        }
-        dataOut.flush()
-        // Serialized command:
-        dataOut.writeInt(evalType)
-        if (evalType != PythonEvalType.NON_UDF) {
-          dataOut.writeInt(funcs.length)
-          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
-            dataOut.writeInt(offsets.length)
-            offsets.foreach { offset =>
-              dataOut.writeInt(offset)
-            }
-            dataOut.writeInt(chained.funcs.length)
-            chained.funcs.foreach { f =>
-              dataOut.writeInt(f.command.length)
-              dataOut.write(f.command)
-            }
-          }
-        } else {
-          val command = funcs.head.funcs.head.command
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-        }
-        // Data values
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-        dataOut.flush()
-      } catch {
-        case e: Exception if context.isCompleted || context.isInterrupted =>
-          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-
-        case e: Exception =>
-          // We must avoid throwing exceptions here, because the thread uncaught exception handler
-          // will kill the whole executor (see org.apache.spark.executor.Executor).
-          _exception = e
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-      }
-    }
-  }
-
-  /**
-   * It is necessary to have a monitor thread for python workers if the user cancels with
-   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-   * threads can block indefinitely.
-   */
-  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
-    extends Thread(s"Worker Monitor for $pythonExec") {
-
-    setDaemon(true)
-
-    override def run() {
-      // Kill the worker if it is interrupted, checking until task completion.
-      // TODO: This has a race condition if interruption occurs, as completed may still become true.
-      while (!context.isInterrupted && !context.isCompleted) {
-        Thread.sleep(2000)
-      }
-      if (!context.isCompleted) {
-        try {
-          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
-        } catch {
-          case e: Exception =>
-            logError("Exception when trying to kill worker", e)
-        }
-      }
-    }
-  }
-}
-
 /** Thrown for exceptions in user Python code. */
-private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause)
+private[spark] class PythonException(msg: String, cause: Exception)
+  extends RuntimeException(msg, cause)

 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
@@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
   val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }

-private object SpecialLengths {
-  val END_OF_DATA_SECTION = -1
-  val PYTHON_EXCEPTION_THROWN = -2
-  val TIMING_DATA = -3
-  val END_OF_STREAM = -4
-  val NULL = -5
-}
-
 private[spark] object PythonRDD extends Logging {

   // remember the broadcasts sent to each worker
