From 3c45c5c132f91f32878dd52245a1beb55eca05e7 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 21 Sep 2017 14:33:01 +0900 Subject: [PATCH 01/15] Extract PythonRunner from PythonRDD.scala file. --- .../apache/spark/api/python/PythonRDD.scala | 321 +--------------- .../spark/api/python/PythonRunner.scala | 347 ++++++++++++++++++ 2 files changed, 349 insertions(+), 319 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 86d0405c678a..a98bc9c9df04 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -83,318 +83,9 @@ private[spark] case class PythonFunction( */ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) -/** - * Enumerate the type of command that will be sent to the Python worker - */ -private[spark] object PythonEvalType { - val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 -} - -private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { - new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), - bufferSize, - reuse_worker, - PythonEvalType.NON_UDF, - Array(Array(0))) - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). - */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuse_worker: Boolean, - evalType: Int, - argOffsets: Array[Array[Int]]) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - // All the Python functions should have the same exec, version and envvars. 
- private val envVars = funcs.head.funcs.head.envVars - private val pythonExec = funcs.head.funcs.head.pythonExec - private val pythonVer = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - private val accumulator = funcs.head.funcs.head.accumulator - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuse_worker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool - @volatile var released = false - - // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener { context => - writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new MonitorThread(env, worker, context).start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - val stdoutIterator = new Iterator[Array[Byte]] { - override def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - accumulator.add(update) - } - // Check whether the worker is ready to be re-used. 
- if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - var _nextObj = read() - - override def hasNext: Boolean = _nextObj != null - } - new InterruptibleIterator(context, stdoutIterator) - } - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Exception = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) - - /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. 
*/ - def shutdownOnTaskCompletion() { - assert(context.isCompleted) - this.interrupt() - } - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Write out the TaskContextInfo - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size - dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(evalType) - if (evalType != PythonEvalType.NON_UDF) { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } - } else { - val command = funcs.head.funcs.head.command - dataOut.writeInt(command.length) - dataOut.write(command) - } - // Data values - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). - _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - setDaemon(true) - - override def run() { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. 
- while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } -} - /** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause) +private[spark] class PythonException(msg: String, cause: Exception) + extends RuntimeException(msg, cause) /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. @@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { - val END_OF_DATA_SECTION = -1 - val PYTHON_EXCEPTION_THROWN = -2 - val TIMING_DATA = -3 - val END_OF_STREAM = -4 - val NULL = -5 -} - private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala new file mode 100644 index 000000000000..0e792d0db8d1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.util._ + + +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_PANDAS_UDF = 2 +} + +private[spark] object PythonRunner { + def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { + new PythonRunner( + Seq(ChainedPythonFunctions(Seq(func))), + bufferSize, + reuse_worker, + PythonEvalType.NON_UDF, + Array(Array(0))) + } +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). 
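+ *
+ * argOffsets holds, for each function, the offsets of the input columns used as its arguments.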
+ */ +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuse_worker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + // All the Python functions should have the same exec, version and envvars. + private val envVars = funcs.head.funcs.head.envVars + private val pythonExec = funcs.head.funcs.head.pythonExec + private val pythonVer = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + private val accumulator = funcs.head.funcs.head.accumulator + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool + @volatile var released = false + + // Start a thread to feed the process input from our parent's iterator + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener { context => + writerThread.shutdownOnTaskCompletion() + if (!reuse_worker || !released) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new MonitorThread(env, worker, context).start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + val stdoutIterator = new Iterator[Array[Byte]] { + override def next(): Array[Byte] = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + private def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + case SpecialLengths.END_OF_DATA_SECTION => + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val 
updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuse_worker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released = true + } + } + null + } + } catch { + + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null // exit silently + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + + var _nextObj = read() + + override def hasNext: Boolean = _nextObj != null + } + new InterruptibleIterator(context, stdoutIterator) + } + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. 
*/ + def shutdownOnTaskCompletion() { + assert(context.isCompleted) + this.interrupt() + } + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + dataOut.flush() + // Serialized command: + dataOut.writeInt(evalType) + if (evalType != PythonEvalType.NON_UDF) { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } else { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) + } + // Data values + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case e: Exception if context.isCompleted || context.isInterrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). + _exception = e + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. + */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. 
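+      // Poll every two seconds rather than blocking, so this thread exits soon after the task finishes.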
+ while (!context.isInterrupted && !context.isCompleted) { + Thread.sleep(2000) + } + if (!context.isCompleted) { + try { + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } +} + +private[spark] object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 + val END_OF_STREAM = -4 + val NULL = -5 +} From 1cd832cb796bdb7e330e56a953274ed577dc8876 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 21 Sep 2017 15:33:43 +0900 Subject: [PATCH 02/15] Extract writer thread. --- .../apache/spark/api/python/PythonRDD.scala | 4 +- .../spark/api/python/PythonRunner.scala | 107 +++++++++++------- .../python/ArrowEvalPythonExec.scala | 2 +- .../python/BatchEvalPythonExec.scala | 2 +- .../execution/python/PythonUDFRunner.scala | 66 +++++++++++ 5 files changed, 133 insertions(+), 48 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a98bc9c9df04..319bd6ebee26 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -48,7 +48,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions: Array[Partition] = firstParent.partitions @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuse_worker) + val runner = PythonCommandRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 0e792d0db8d1..5aee89d1c659 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -37,27 +37,16 @@ private[spark] object PythonEvalType { val SQL_PANDAS_UDF = 2 } -private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { - new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), - bufferSize, - reuse_worker, - PythonEvalType.NON_UDF, - Array(Array(0))) - } -} - /** * A helper class to run Python mapPartition/UDFs in Spark. * * funcs is a list of independent Python functions, each one of them is a list of chained Python * functions (from bottom to top). 
*/ -private[spark] class PythonRunner( +private[spark] abstract class PythonRunner( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, - reuse_worker: Boolean, + reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) extends Logging { @@ -65,12 +54,12 @@ private[spark] class PythonRunner( require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") // All the Python functions should have the same exec, version and envvars. - private val envVars = funcs.head.funcs.head.envVars - private val pythonExec = funcs.head.funcs.head.pythonExec - private val pythonVer = funcs.head.funcs.head.pythonVer + protected val envVars = funcs.head.funcs.head.envVars + protected val pythonExec = funcs.head.funcs.head.pythonExec + protected val pythonVer = funcs.head.funcs.head.pythonVer // TODO: support accumulator in multiple UDF - private val accumulator = funcs.head.funcs.head.accumulator + protected val accumulator = funcs.head.funcs.head.accumulator def compute( inputIterator: Iterator[_], @@ -80,7 +69,7 @@ private[spark] class PythonRunner( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuse_worker) { + if (reuseWorker) { envVars.put("SPARK_REUSE_WORKER", "1") } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) @@ -88,11 +77,11 @@ private[spark] class PythonRunner( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { + if (!reuseWorker || !released) { try { worker.close() } catch { @@ -162,7 +151,7 @@ private[spark] class PythonRunner( } // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { + if (reuseWorker) { env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) released = true } @@ -196,11 +185,18 @@ private[spark] class PythonRunner( new InterruptibleIterator(context, stdoutIterator) } + def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): WriterThread + /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ - class WriterThread( + abstract class WriterThread( env: SparkEnv, worker: Socket, inputIterator: Iterator[_], @@ -224,6 +220,9 @@ private[spark] class PythonRunner( this.interrupt() } + def writeCommand(dataOut: DataOutputStream): Unit + def writeIteratorToStream(dataOut: DataOutputStream): Unit + override def run(): Unit = Utils.logUncaughtExceptions { try { TaskContext.setTaskContext(context) @@ -266,29 +265,11 @@ private[spark] class PythonRunner( } } dataOut.flush() - // Serialized command: + dataOut.writeInt(evalType) - if (evalType != PythonEvalType.NON_UDF) { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } - } else { - val command = funcs.head.funcs.head.command - dataOut.writeInt(command.length) - dataOut.write(command) - } - // Data values - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + writeCommand(dataOut) + writeIteratorToStream(dataOut) + dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { @@ -338,6 +319,44 @@ private[spark] class PythonRunner( } } +private[spark] object PythonCommandRunner { + + def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonCommandRunner = { + new PythonCommandRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + } +} + +/** + * A helper class to run Python mapPartition in Spark. + */ +private[spark] class PythonCommandRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean) + extends PythonRunner(funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + + override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + override def writeCommand(dataOut: DataOutputStream): Unit = { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) + } + + override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } +} + private[spark] object SpecialLengths { val END_OF_DATA_SECTION = -1 val PYTHON_EXCEPTION_THROWN = -2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 5e72cd255873..d3d04b1ddc95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -43,7 +43,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable) // Output iterator for results from Python. 
- val outputIterator = new PythonRunner( + val outputIterator = new PythonUDFRunner( funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 2978eac50554..5bdf816000c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -68,7 +68,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. - val outputIterator = new PythonRunner( + val outputIterator = new PythonUDFRunner( funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala new file mode 100644 index 000000000000..b22b0d2622d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ + +import org.apache.spark._ +import org.apache.spark.api.python._ + +/** + * A helper class to run Python UDFs in Spark. 
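+ *
+ * The serialized UDFs and their argument offsets are written to the worker before the input data.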
+ */ +class PythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends PythonRunner(funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + override def writeCommand(dataOut: DataOutputStream): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } + + override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } +} From 919811d9ffacb8218acf7148b2f0918b255c4f3a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 21 Sep 2017 16:23:55 +0900 Subject: [PATCH 03/15] Extract reader iterator. --- .../spark/api/python/PythonRunner.scala | 239 +++++++++++------- .../execution/python/PythonUDFRunner.scala | 38 ++- 2 files changed, 187 insertions(+), 90 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5aee89d1c659..c70393fa712b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -20,6 +20,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ @@ -74,14 +75,14 @@ private[spark] abstract class PythonRunner( } val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool - @volatile var released = false + val released = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - if (!reuseWorker || !released) { + if (!reuseWorker || !released.get) { try { worker.close() } catch { @@ -96,102 +97,28 @@ private[spark] abstract class PythonRunner( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - val stdoutIterator = new Iterator[Array[Byte]] { - override def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - 
initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - accumulator.add(update) - } - // Check whether the worker is ready to be re-used. - if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuseWorker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - var _nextObj = read() - - override def hasNext: Boolean = _nextObj != null - } + val stdoutIterator = newReaderIterator( + stream, writerThread, startTime, env, worker, released, context) new InterruptibleIterator(context, stdoutIterator) } - def newWriterThread( + protected def newWriterThread( env: SparkEnv, worker: Socket, inputIterator: Iterator[_], partitionIndex: Int, context: TaskContext): WriterThread + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] + /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
@@ -290,6 +217,105 @@ private[spark] abstract class PythonRunner( } } + abstract class ReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext) + extends Iterator[Array[Byte]] { + + private var nextObj: Array[Byte] = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): Array[Byte] = { + if (hasNext) { + val obj = nextObj + nextObj = null + obj + } else { + Iterator.empty.next() + } + } + + protected def read(): Array[Byte] + + protected def handleTimingData(): Unit = { + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + } + + protected def handlePythonException(): PythonException = { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + } + + protected def handleEndOfDataSection(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released.set(true) + } + } + eos = true + } + + protected val handleException: PartialFunction[Throwable, Array[Byte]] = { + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null // exit silently + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + /** * It is necessary to have a monitor thread for python workers if the user cancels with * interrupts disabled. 
In that case we will need to explicitly kill the worker, otherwise the @@ -335,7 +361,7 @@ private[spark] class PythonCommandRunner( reuseWorker: Boolean) extends PythonRunner(funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { - override def newWriterThread( + protected override def newWriterThread( env: SparkEnv, worker: Socket, inputIterator: Iterator[_], @@ -355,6 +381,41 @@ private[spark] class PythonCommandRunner( } } } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } } private[spark] object SpecialLengths { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index b22b0d2622d6..5319d3cfcc6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python import java.io._ import java.net._ +import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ import org.apache.spark.api.python._ @@ -34,7 +35,7 @@ class PythonUDFRunner( argOffsets: Array[Array[Int]]) extends PythonRunner(funcs, bufferSize, reuseWorker, evalType, argOffsets) { - override def newWriterThread( + protected override def newWriterThread( env: SparkEnv, worker: Socket, inputIterator: Iterator[_], @@ -63,4 +64,39 @@ class PythonUDFRunner( } } } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } } From b2fed104ee00f5bf8235e21b01f89c98ec9400fc Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 21 Sep 2017 18:00:42 +0900 Subject: [PATCH 04/15] Introduce ArrowStreamPythonUDFRunner. 
--- .../spark/api/python/PythonRunner.scala | 31 +-- .../python/ArrowStreamPythonUDFRunner.scala | 197 ++++++++++++++++++ .../execution/python/PythonUDFRunner.scala | 5 +- 3 files changed, 216 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index c70393fa712b..4fd58bc99caf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -44,7 +44,7 @@ private[spark] object PythonEvalType { * funcs is a list of independent Python functions, each one of them is a list of chained Python * functions (from bottom to top). */ -private[spark] abstract class PythonRunner( +private[spark] abstract class PythonRunner[IN, OUT]( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, reuseWorker: Boolean, @@ -63,9 +63,9 @@ private[spark] abstract class PythonRunner( protected val accumulator = funcs.head.funcs.head.accumulator def compute( - inputIterator: Iterator[_], + inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext): Iterator[Array[Byte]] = { + context: TaskContext): Iterator[OUT] = { val startTime = System.currentTimeMillis val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -106,7 +106,7 @@ private[spark] abstract class PythonRunner( protected def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[_], + inputIterator: Iterator[IN], partitionIndex: Int, context: TaskContext): WriterThread @@ -117,7 +117,7 @@ private[spark] abstract class PythonRunner( env: SparkEnv, worker: Socket, released: AtomicBoolean, - context: TaskContext): Iterator[Array[Byte]] + context: TaskContext): Iterator[OUT] /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the @@ -126,7 +126,7 @@ private[spark] abstract class PythonRunner( abstract class WriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[_], + inputIterator: Iterator[IN], partitionIndex: Int, context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @@ -225,9 +225,9 @@ private[spark] abstract class PythonRunner( worker: Socket, released: AtomicBoolean, context: TaskContext) - extends Iterator[Array[Byte]] { + extends Iterator[OUT] { - private var nextObj: Array[Byte] = _ + private var nextObj: OUT = _ private var eos = false override def hasNext: Boolean = nextObj != null || { @@ -239,17 +239,17 @@ private[spark] abstract class PythonRunner( } } - override def next(): Array[Byte] = { + override def next(): OUT = { if (hasNext) { val obj = nextObj - nextObj = null + nextObj = null.asInstanceOf[OUT] obj } else { Iterator.empty.next() } } - protected def read(): Array[Byte] + protected def read(): OUT protected def handleTimingData(): Unit = { // Timing data from worker @@ -297,14 +297,14 @@ private[spark] abstract class PythonRunner( eos = true } - protected val handleException: PartialFunction[Throwable, Array[Byte]] = { + protected val handleException: PartialFunction[Throwable, OUT] = { case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) case e: Exception if env.isStopped => logDebug("Exception thrown after context 
is stopped", e) - null // exit silently + null.asInstanceOf[OUT] // exit silently case e: Exception if writerThread.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) @@ -359,12 +359,13 @@ private[spark] class PythonCommandRunner( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, reuseWorker: Boolean) - extends PythonRunner(funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + extends PythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[_], + inputIterator: Iterator[Array[Byte]], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala new file mode 100644 index 000000000000..d19e14f091cc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. 
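+ * Input rows are written as Arrow record batches of up to `batchSize` rows (unbounded when
+ * `batchSize` <= 0), and results are read back as `ColumnarBatch`es.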
+ */ +class ArrowStreamPythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + batchSize: Int, + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType) + extends PythonRunner[InternalRow, ColumnarBatch]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[InternalRow], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + override def writeCommand(dataOut: DataOutputStream): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } + + override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + + var closed = false + + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() + } + } + + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + Utils.tryWithSafeFinally { + while (inputIterator.hasNext) { + var rowCount = 0 + while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) { + val row = inputIterator.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + } { + writer.end() + root.close() + allocator.close() + closed = true + } + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + private var closed = false + + context.addTaskCompletionListener { _ => + // todo: we need something like `read.end()`, which release all the resources, but leave + // the input stream open. `reader.close` will close the socket and we can't reuse worker. + // So here we simply not close the reader, which is problematic. 
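+        // Instead, release only the Arrow root and allocator if the iterator was not fully consumed.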
+ if (!closed) { + if (root != null) { + root.close() + } + allocator.close() + } + } + + override def hasNext: Boolean = super.hasNext || { + if (root != null) { + root.close() + } + allocator.close() + closed = true + false + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + batch.setNumRows(root.getRowCount) + batch + } else { + read() + } + } else { + stream.readInt() match { + case 0 => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 5319d3cfcc6c..9e46a9b1e5d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -33,12 +33,13 @@ class PythonUDFRunner( reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) - extends PythonRunner(funcs, bufferSize, reuseWorker, evalType, argOffsets) { + extends PythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def newWriterThread( env: SparkEnv, worker: Socket, - inputIterator: Iterator[_], + inputIterator: Iterator[Array[Byte]], partitionIndex: Int, context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { From 937292d0a2a2145be3dbc6314cf0da1b41e71b6e Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 22 Sep 2017 20:03:07 +0900 Subject: [PATCH 05/15] Add ArrowStreamPandasSerializer. --- python/pyspark/serializers.py | 68 ++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7c1fbadcb82b..1755d30c60e8 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -211,6 +211,26 @@ def __repr__(self): return "ArrowSerializer" +def _create_batch(series): + import pyarrow as pa + # Make input conform to [(series1, type1), (series2, type2), ...] 
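+    # A bare series, or a single (series, type) pair, is first wrapped in a one-element list.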
+ if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + # If a nullable integer series has been promoted to floating point with NaNs, need to cast + # NOTE: this is not necessary with Arrow >= 0.7 + def cast_series(s, t): + if t is None or s.dtype == t.to_pandas_dtype(): + return s + else: + return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) + + arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + + class ArrowPandasSerializer(ArrowSerializer): """ Serializes Pandas.Series as Arrow data. @@ -221,23 +241,7 @@ def dumps(self, series): Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ - import pyarrow as pa - # Make input conform to [(series1, type1), (series2, type2), ...] - if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - # If a nullable integer series has been promoted to floating point with NaNs, need to cast - # NOTE: this is not necessary with Arrow >= 0.7 - def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): - return s - else: - return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] - batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + batch = _create_batch(series) return super(ArrowPandasSerializer, self).dumps(batch) def loads(self, obj): @@ -251,6 +255,36 @@ def __repr__(self): return "ArrowPandasSerializer" +class ArrowStreamPandasSerializer(Serializer): + """ + (De)serializes a vectorized(Apache Arrow) stream. + """ + + def load_stream(self, stream): + import pyarrow as pa + reader = pa.open_stream(stream) + for batch in reader: + table = pa.Table.from_batches([batch]) + yield [c.to_pandas() for c in table.itercolumns()] + + def dump_stream(self, iterator, stream): + import pyarrow as pa + writer = None + try: + for series in iterator: + batch = _create_batch(series) + if writer is None: + write_int(0, stream) + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() + + def __repr__(self): + return "ArrowStreamPandasSerializer" + + class BatchedSerializer(Serializer): """ From 80167219abf98b8c019df3582a8c2b3ec6697753 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 22 Sep 2017 20:14:08 +0900 Subject: [PATCH 06/15] Introduce ArrowStreamEvalPythonExec. 
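This adds a physical plan that pairs with the ArrowStreamPandasSerializer from the previous commit: ArrowStreamEvalPythonExec drives ArrowStreamPythonUDFRunner, which writes the input rows to the Python worker as Arrow record batches over the streaming format and reads the results back as ColumnarBatches whose rows are handed to the parent operator.

For reference, a minimal sketch of the streaming round-trip the worker side performs, using only the pyarrow calls already present in serializers.py (illustrative only, not part of this patch; assumes pyarrow and pandas are installed):

    import io
    import pyarrow as pa

    # Write one record batch in the Arrow streaming format, as dump_stream does.
    sink = io.BytesIO()
    batch = pa.RecordBatch.from_arrays([pa.array([1.0, 2.0, 3.0])], ["_0"])
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()

    # Read it back and turn each column into a pandas.Series, as load_stream does.
    reader = pa.open_stream(io.BytesIO(sink.getvalue()))
    for b in reader:
        table = pa.Table.from_batches([b])
        series = [c.to_pandas() for c in table.itercolumns()]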
--- .../execution/vectorized/ColumnarBatch.java | 2 + .../python/ArrowStreamEvalPythonExec.scala | 76 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e782756a3e78..a469880cb0c1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -462,6 +462,8 @@ public int numValidRows() { return numRows - numRowsFiltered; } + public StructType schema() { return schema; } + /** * Returns the max capacity (in number of rows) for this batch. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala new file mode 100644 index 000000000000..418ea48ca279 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.StructType + +/** + * A physical plan that evaluates a [[PythonUDF]], but exchange data with Python worker via Arrow + * stream. 
+ */ +case class ArrowStreamEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) + extends EvalPythonExec(udfs, output, child) { + + protected override def evaluate( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + argOffsets: Array[Array[Int]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + + val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex + .map { case (attr, i) => attr.withName(s"_$i") }) + + val columnarBatchIter = new ArrowStreamPythonUDFRunner( + funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(iter, context.partitionId(), context) + + new Iterator[InternalRow] { + + var currentIter = if (columnarBatchIter.hasNext) { + val batch = columnarBatchIter.next() + assert(schemaOut.equals(batch.schema), + s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + batch.rowIterator.asScala + } else { + Iterator.empty + } + + override def hasNext: Boolean = currentIter.hasNext || { + if (columnarBatchIter.hasNext) { + currentIter = columnarBatchIter.next().rowIterator.asScala + hasNext + } else { + false + } + } + + override def next(): InternalRow = currentIter.next() + } + } +} From e62d619e13f63af5af2f386c0d7ab554ad3c6336 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 22 Sep 2017 20:36:11 +0900 Subject: [PATCH 07/15] Enable vectorized UDF via Arrow stream protocol. --- .../org/apache/spark/api/python/PythonRunner.scala | 1 + python/pyspark/serializers.py | 1 + python/pyspark/sql/tests.py | 10 ++++++++++ python/pyspark/worker.py | 10 ++++++---- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 9 +++++++++ .../spark/sql/execution/vectorized/ColumnarBatch.java | 3 +++ .../apache/spark/sql/execution/QueryExecution.scala | 2 +- .../execution/python/ArrowStreamEvalPythonExec.scala | 2 +- .../spark/sql/execution/python/ExtractPythonUDFs.scala | 9 +++++++-- 9 files changed, 39 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 4fd58bc99caf..a9c5f2669a9e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,6 +36,7 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_UDF_STREAM = 3 } /** diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 1755d30c60e8..d3651f4b6d01 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -85,6 +85,7 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_UDF_STREAM = 3 class Serializer(object): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b3af42c47ad..03636ac31005 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3376,6 +3376,16 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class ArrowStreamVectorizedUDFTests(VectorizedUDFTests): + + @classmethod + def setUpClass(cls): + VectorizedUDFTests.setUpClass() + cls.spark.conf.set("spark.sql.execution.arrow.stream.enable", "true") + 
+ if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fd917c400c87..08cf47e1acea 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowPandasSerializer + BatchedSerializer, ArrowPandasSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import toArrowType from pyspark import shuffle @@ -98,10 +98,10 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_UDF: - return arg_offsets, wrap_pandas_udf(row_func, return_type) - else: + if eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + return arg_offsets, wrap_pandas_udf(row_func, return_type) def read_udfs(pickleSer, infile, eval_type): @@ -124,6 +124,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_PANDAS_UDF: ser = ArrowPandasSerializer() + elif eval_type == PythonEvalType.SQL_PANDAS_UDF_STREAM: + ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d00c67248753..dc8707b0b21e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -925,6 +925,13 @@ object SQLConf { .intConf .createWithDefault(10000) + val ARROW_EXECUTION_STREAM_ENABLE = + buildConf("spark.sql.execution.arrow.stream.enable") + .internal() + .doc("When using Apache Arrow, use Arrow stream protocol if possible.") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1203,6 +1210,8 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def arrowStreamEnable: Boolean = getConf(ARROW_EXECUTION_STREAM_ENABLE) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index a469880cb0c1..bc546c7c425b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -462,6 +462,9 @@ public int numValidRows() { return numRows - numRowsFiltered; } + /** + * Returns the schema that makes up this batch. 
+ */ public StructType schema() { return schema; } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4accf54a1823..e3dc63f07b3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -102,7 +102,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( - python.ExtractPythonUDFs, + python.ExtractPythonUDFs(sparkSession.sessionState.conf), PlanSubqueries(sparkSession), new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala index 418ea48ca279..07678144e99b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala @@ -47,7 +47,7 @@ case class ArrowStreamEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute val columnarBatchIter = new ArrowStreamPythonUDFRunner( funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_UDF_STREAM, argOffsets, schema) .compute(iter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fec456d86dbe..3780fd37e6b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.internal.SQLConf /** @@ -90,7 +91,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. 
*/ -object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { +case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined @@ -141,7 +142,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val evaluation = validUdfs.partition(_.vectorized) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + if (conf.arrowStreamEnable) { + ArrowStreamEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + } else { + ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) + } case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) case _ => From 14aa3b641fd0c7f3a6feb6869508703b113b6ce6 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 26 Sep 2017 15:04:42 +0900 Subject: [PATCH 08/15] Revert rename. --- .../org/apache/spark/api/python/PythonRDD.scala | 2 +- .../org/apache/spark/api/python/PythonRunner.scala | 12 ++++++------ .../sql/execution/python/ArrowEvalPythonExec.scala | 2 +- .../python/ArrowStreamPythonUDFRunner.scala | 2 +- .../sql/execution/python/BatchEvalPythonExec.scala | 2 +- .../spark/sql/execution/python/PythonUDFRunner.scala | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 319bd6ebee26..f6293c0dc509 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonCommandRunner(func, bufferSize, reuseWorker) + val runner = PythonRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index a9c5f2669a9e..ddf37ecf1415 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -45,7 +45,7 @@ private[spark] object PythonEvalType { * funcs is a list of independent Python functions, each one of them is a list of chained Python * functions (from bottom to top). */ -private[spark] abstract class PythonRunner[IN, OUT]( +private[spark] abstract class BasePythonRunner[IN, OUT]( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, reuseWorker: Boolean, @@ -346,21 +346,21 @@ private[spark] abstract class PythonRunner[IN, OUT]( } } -private[spark] object PythonCommandRunner { +private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonCommandRunner = { - new PythonCommandRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) } } /** * A helper class to run Python mapPartition in Spark. 
*/ -private[spark] class PythonCommandRunner( +private[spark] class PythonRunner( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, reuseWorker: Boolean) - extends PythonRunner[Array[Byte], Array[Byte]]( + extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { protected override def newWriterThread( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index d3d04b1ddc95..6c9ebd26299d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala index d19e14f091cc..bf124fdea801 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala @@ -45,7 +45,7 @@ class ArrowStreamPythonUDFRunner( evalType: Int, argOffsets: Array[Array[Int]], schema: StructType) - extends PythonRunner[InternalRow, ColumnarBatch]( + extends BasePythonRunner[InternalRow, ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def newWriterThread( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 5bdf816000c2..26ee25f633ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 9e46a9b1e5d5..2bb18deb07e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -33,7 +33,7 @@ class PythonUDFRunner( reuseWorker: Boolean, evalType: Int, argOffsets: Array[Array[Int]]) - extends PythonRunner[Array[Byte], Array[Byte]]( + extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { protected override def 
newWriterThread( From 4a23c521e1ee312bca1cd2d824a633f6c41da1e2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 10:26:48 +0900 Subject: [PATCH 09/15] Add a new entry in `SpecialLenths`. --- .../main/scala/org/apache/spark/api/python/PythonRunner.scala | 1 + python/pyspark/serializers.py | 3 ++- .../sql/execution/python/ArrowStreamPythonUDFRunner.scala | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ddf37ecf1415..e45d38f9cd61 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -426,4 +426,5 @@ private[spark] object SpecialLengths { val TIMING_DATA = -3 val END_OF_STREAM = -4 val NULL = -5 + val START_ARROW_STREAM = -6 } diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d3651f4b6d01..d3981120da32 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -79,6 +79,7 @@ class SpecialLengths(object): TIMING_DATA = -3 END_OF_STREAM = -4 NULL = -5 + START_ARROW_STREAM = -6 class PythonEvalType(object): @@ -275,7 +276,7 @@ def dump_stream(self, iterator, stream): for series in iterator: batch = _create_batch(series) if writer is None: - write_int(0, stream) + write_int(SpecialLengths.START_ARROW_STREAM, stream) writer = pa.RecordBatchStreamWriter(stream, batch.schema) writer.write_batch(batch) finally: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala index bf124fdea801..138f09cc76b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala @@ -172,7 +172,7 @@ class ArrowStreamPythonUDFRunner( } } else { stream.readInt() match { - case 0 => + case SpecialLengths.START_ARROW_STREAM => reader = new ArrowStreamReader(stream, allocator) root = reader.getVectorSchemaRoot() schema = ArrowUtils.fromArrowSchema(root.getSchema()) From 7f6e43f7b291b0e5f60914bcbf6a1a7dee92820f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 10:58:12 +0900 Subject: [PATCH 10/15] Replace Arrow file format with Arrow stream format instead of having a new conf. 
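With this change the SQL_PANDAS_UDF eval type always goes through the Arrow streaming path (ArrowPythonRunner on the JVM side, ArrowStreamPandasSerializer on the worker side), so the temporary spark.sql.execution.arrow.stream.enable flag, the SQL_PANDAS_UDF_STREAM eval type and ArrowStreamEvalPythonExec are removed again.

From the user's point of view nothing changes; a vectorized UDF such as the sketch below is simply evaluated through the stream-based runner (illustrative only; assumes an active SparkSession bound to `spark` and the pandas_udf API added by the vectorized UDF work):

    from pyspark.sql.functions import col, pandas_udf
    from pyspark.sql.types import LongType

    # A vectorized UDF: receives a pandas.Series per batch, returns a pandas.Series.
    plus_one = pandas_udf(lambda v: v + 1, LongType())

    df = spark.range(8)
    df.select(plus_one(col('id'))).show()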
--- .../spark/api/python/PythonRunner.scala | 1 - python/pyspark/serializers.py | 5 +- python/pyspark/sql/tests.py | 10 --- python/pyspark/worker.py | 10 +-- .../apache/spark/sql/internal/SQLConf.scala | 9 --- .../spark/sql/execution/QueryExecution.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 52 ++++++++----- ...DFRunner.scala => ArrowPythonRunner.scala} | 2 +- .../python/ArrowStreamEvalPythonExec.scala | 76 ------------------- .../execution/python/ExtractPythonUDFs.scala | 9 +-- 10 files changed, 42 insertions(+), 134 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/python/{ArrowStreamPythonUDFRunner.scala => ArrowPythonRunner.scala} (99%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index e45d38f9cd61..d44407133a97 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,7 +36,6 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 - val SQL_PANDAS_UDF_STREAM = 3 } /** diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d3981120da32..f09899ce5293 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -86,7 +86,6 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 - SQL_PANDAS_UDF_STREAM = 3 class Serializer(object): @@ -235,7 +234,7 @@ def cast_series(s, t): class ArrowPandasSerializer(ArrowSerializer): """ - Serializes Pandas.Series as Arrow data. + Serializes Pandas.Series as Arrow data with Arrow file format. """ def dumps(self, series): @@ -259,7 +258,7 @@ def __repr__(self): class ArrowStreamPandasSerializer(Serializer): """ - (De)serializes a vectorized(Apache Arrow) stream. + Serializes Pandas.Series as Arrow data with Arrow streaming format. 
""" def load_stream(self, stream): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 03636ac31005..1b3af42c47ad 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3376,16 +3376,6 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) - -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class ArrowStreamVectorizedUDFTests(VectorizedUDFTests): - - @classmethod - def setUpClass(cls): - VectorizedUDFTests.setUpClass() - cls.spark.conf.set("spark.sql.execution.arrow.stream.enable", "true") - - if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 08cf47e1acea..4e24789cf010 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowPandasSerializer, ArrowStreamPandasSerializer + BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import toArrowType from pyspark import shuffle @@ -98,10 +98,10 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(row_func, return_type) - else: + if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + else: + return arg_offsets, wrap_udf(row_func, return_type) def read_udfs(pickleSer, infile, eval_type): @@ -123,8 +123,6 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) if eval_type == PythonEvalType.SQL_PANDAS_UDF: - ser = ArrowPandasSerializer() - elif eval_type == PythonEvalType.SQL_PANDAS_UDF_STREAM: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc8707b0b21e..d00c67248753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -925,13 +925,6 @@ object SQLConf { .intConf .createWithDefault(10000) - val ARROW_EXECUTION_STREAM_ENABLE = - buildConf("spark.sql.execution.arrow.stream.enable") - .internal() - .doc("When using Apache Arrow, use Arrow stream protocol if possible.") - .booleanConf - .createWithDefault(false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1210,8 +1203,6 @@ class SQLConf extends Serializable with Logging { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) - def arrowStreamEnable: Boolean = getConf(ARROW_EXECUTION_STREAM_ENABLE) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index e3dc63f07b3d..4accf54a1823 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -102,7 +102,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( - python.ExtractPythonUDFs(sparkSession.sessionState.conf), + python.ExtractPythonUDFs, PlanSubqueries(sparkSession), new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 6c9ebd26299d..f7e8cbe41612 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.python +import scala.collection.JavaConverters._ + import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.types.StructType /** @@ -39,25 +40,36 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val inputIterator = ArrowConverters.toPayloadIterator( - iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable) - - // Output iterator for results from Python. 
- val outputIterator = new PythonUDFRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets) - .compute(inputIterator, context.partitionId(), context) - - val outputRowIterator = ArrowConverters.fromPayloadIterator( - outputIterator.map(new ArrowPayload(_)), context) - - // Verify that the output schema is correct - if (outputRowIterator.hasNext) { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) - assert(schemaOut.equals(outputRowIterator.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}") - } - outputRowIterator + val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex + .map { case (attr, i) => attr.withName(s"_$i") }) + + val columnarBatchIter = new ArrowPythonRunner( + funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(iter, context.partitionId(), context) + + new Iterator[InternalRow] { + + var currentIter = if (columnarBatchIter.hasNext) { + val batch = columnarBatchIter.next() + assert(schemaOut.equals(batch.schema), + s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + batch.rowIterator.asScala + } else { + Iterator.empty + } + + override def hasNext: Boolean = currentIter.hasNext || { + if (columnarBatchIter.hasNext) { + currentIter = columnarBatchIter.next().rowIterator.asScala + hasNext + } else { + false + } + } + + override def next(): InternalRow = currentIter.next() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 138f09cc76b8..efe986301044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.Utils /** * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. */ -class ArrowStreamPythonUDFRunner( +class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], batchSize: Int, bufferSize: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala deleted file mode 100644 index 07678144e99b..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.python - -import scala.collection.JavaConverters._ - -import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.StructType - -/** - * A physical plan that evaluates a [[PythonUDF]], but exchange data with Python worker via Arrow - * stream. - */ -case class ArrowStreamEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) - extends EvalPythonExec(udfs, output, child) { - - protected override def evaluate( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuseWorker: Boolean, - argOffsets: Array[Array[Int]], - iter: Iterator[InternalRow], - schema: StructType, - context: TaskContext): Iterator[InternalRow] = { - - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) - - val columnarBatchIter = new ArrowStreamPythonUDFRunner( - funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF_STREAM, argOffsets, schema) - .compute(iter, context.partitionId(), context) - - new Iterator[InternalRow] { - - var currentIter = if (columnarBatchIter.hasNext) { - val batch = columnarBatchIter.next() - assert(schemaOut.equals(batch.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") - batch.rowIterator.asScala - } else { - Iterator.empty - } - - override def hasNext: Boolean = currentIter.hasNext || { - if (columnarBatchIter.hasNext) { - currentIter = columnarBatchIter.next().rowIterator.asScala - hasNext - } else { - false - } - } - - override def next(): InternalRow = currentIter.next() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 3780fd37e6b6..fec456d86dbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{FilterExec, SparkPlan} -import org.apache.spark.sql.internal.SQLConf /** @@ -91,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. 
*/ -case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with PredicateHelper { +object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined @@ -142,11 +141,7 @@ case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with Predica val evaluation = validUdfs.partition(_.vectorized) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => - if (conf.arrowStreamEnable) { - ArrowStreamEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) - } else { - ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) - } + ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) case _ => From dd6eaa39723059ffec5fbfc333dda84308342587 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 14:14:04 +0900 Subject: [PATCH 11/15] Remove ArrowPandasSerializer. --- python/pyspark/serializers.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index f09899ce5293..711625dbffeb 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -232,36 +232,15 @@ def cast_series(s, t): return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) -class ArrowPandasSerializer(ArrowSerializer): - """ - Serializes Pandas.Series as Arrow data with Arrow file format. - """ - - def dumps(self, series): - """ - Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or - a list of series accompanied by an optional pyarrow type to coerce the data to. - """ - batch = _create_batch(series) - return super(ArrowPandasSerializer, self).dumps(batch) - - def loads(self, obj): - """ - Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series. - """ - table = super(ArrowPandasSerializer, self).loads(obj) - return [c.to_pandas() for c in table.itercolumns()] - - def __repr__(self): - return "ArrowPandasSerializer" - - class ArrowStreamPandasSerializer(Serializer): """ Serializes Pandas.Series as Arrow data with Arrow streaming format. """ def load_stream(self, stream): + """ + Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. + """ import pyarrow as pa reader = pa.open_stream(stream) for batch in reader: @@ -269,6 +248,10 @@ def load_stream(self, stream): yield [c.to_pandas() for c in table.itercolumns()] def dump_stream(self, iterator, stream): + """ + Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or + a list of series accompanied by an optional pyarrow type to coerce the data to. + """ import pyarrow as pa writer = None try: From 83bb3c209e33c752f50556ba15ae0ea6d29b98ec Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 14:41:07 +0900 Subject: [PATCH 12/15] Extract duplicate code to utility method. 
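PythonUDFRunner and ArrowPythonRunner wrote exactly the same command section, so the logic moves into a shared utility method and both writeCommand implementations call it.

The layout it produces on the worker's stdin is: the number of UDFs, then for each UDF its argument offsets followed by the chained, pickled function commands. A rough Python sketch of the matching reader (illustrative only; `infile` and `read_int` stand in for the helpers pyspark/worker.py already uses, and the pickled payloads are left opaque):

    def read_udf_command_section(infile, read_int):
        udfs = []
        for _ in range(read_int(infile)):            # number of UDFs
            num_arg = read_int(infile)
            arg_offsets = [read_int(infile) for _ in range(num_arg)]
            chained = []
            for _ in range(read_int(infile)):        # chained functions
                length = read_int(infile)
                chained.append(infile.read(length))  # pickled command bytes
            udfs.append((arg_offsets, chained))
        return udfs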
--- .../execution/python/ArrowPythonRunner.scala | 13 +------ .../execution/python/PythonUDFRunner.scala | 34 ++++++++++++------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index efe986301044..98738bf5c06c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -57,18 +57,7 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { override def writeCommand(dataOut: DataOutputStream): Unit = { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } + PythonUDFRunner.writeUDF(dataOut, funcs, argOffsets) } override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 2bb18deb07e1..e636edc497ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -45,18 +45,7 @@ class PythonUDFRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { override def writeCommand(dataOut: DataOutputStream): Unit = { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } + PythonUDFRunner.writeUDF(dataOut, funcs, argOffsets) } override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { @@ -101,3 +90,24 @@ class PythonUDFRunner( } } } + +object PythonUDFRunner { + + def writeUDF( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } +} From aa3fa70981731202a14e0a76be01c04deb5de0f4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 15:26:50 +0900 Subject: [PATCH 13/15] Add comments. 
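Document the WriterThread/ReaderIterator contract in BasePythonRunner: writeCommand emits the command section, writeIteratorToStream emits the input data, and read() returns the next output object while handling the trailing control sections once the data is exhausted.

As a rough illustration of that read() contract (not Spark code; the handle_* callbacks are placeholders), the framing on the stream looks like:

    from pyspark.serializers import SpecialLengths, read_int

    def read_data_section(stream, handle_timing, handle_exception, handle_end):
        # Non-negative lengths carry data; negative "special lengths" carry
        # control frames; END_OF_DATA_SECTION terminates the data section.
        while True:
            length = read_int(stream)
            if length >= 0:
                yield stream.read(length)
            elif length == SpecialLengths.TIMING_DATA:
                handle_timing(stream)
            elif length == SpecialLengths.PYTHON_EXCEPTION_THROWN:
                raise handle_exception(stream)
            elif length == SpecialLengths.END_OF_DATA_SECTION:
                handle_end(stream)
                return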
--- .../org/apache/spark/api/python/PythonRunner.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index d44407133a97..8771b255e471 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -147,7 +147,14 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( this.interrupt() } + /** + * Writes a command section to the stream connected to the Python worker. + */ def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. + */ def writeIteratorToStream(dataOut: DataOutputStream): Unit override def run(): Unit = Utils.logUncaughtExceptions { @@ -249,6 +256,11 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } } + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. + */ protected def read(): OUT protected def handleTimingData(): Unit = { From 416bd10d5e62c01d3d34dad92bcabe9586cecde1 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 15:57:11 +0900 Subject: [PATCH 14/15] Add a comment to describe calling `read()` again. --- .../apache/spark/sql/execution/python/ArrowPythonRunner.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 98738bf5c06c..c1c746fb3bdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -157,6 +157,7 @@ class ArrowPythonRunner( batch.setNumRows(root.getRowCount) batch } else { + // Reach end of stream. Call `read()` again to read control data. read() } } else { From 7cd78b2aa44e830e0b8b466d1ca80e54359a3c3c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Sep 2017 18:15:41 +0900 Subject: [PATCH 15/15] Minor cleanups. --- .../spark/api/python/PythonRunner.scala | 10 ++++----- python/pyspark/serializers.py | 20 ++++++++--------- .../execution/python/ArrowPythonRunner.scala | 22 +++++++------------ .../execution/python/PythonUDFRunner.scala | 8 +++---- 4 files changed, 27 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 8771b255e471..3688a149443c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -80,7 +80,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Start a thread to feed the process input from our parent's iterator val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - context.addTaskCompletionListener { context => + context.addTaskCompletionListener { _ => writerThread.shutdownOnTaskCompletion() if (!reuseWorker || !released.get) { try { @@ -150,12 +150,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( /** * Writes a command section to the stream connected to the Python worker. 
*/ - def writeCommand(dataOut: DataOutputStream): Unit + protected def writeCommand(dataOut: DataOutputStream): Unit /** * Writes input data to the stream connected to the Python worker. */ - def writeIteratorToStream(dataOut: DataOutputStream): Unit + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit override def run(): Unit = Utils.logUncaughtExceptions { try { @@ -382,13 +382,13 @@ private[spark] class PythonRunner( context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { - override def writeCommand(dataOut: DataOutputStream): Unit = { + protected override def writeCommand(dataOut: DataOutputStream): Unit = { val command = funcs.head.funcs.head.command dataOut.writeInt(command.length) dataOut.write(command) } - override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 711625dbffeb..db77b7e150b2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -237,16 +237,6 @@ class ArrowStreamPandasSerializer(Serializer): Serializes Pandas.Series as Arrow data with Arrow streaming format. """ - def load_stream(self, stream): - """ - Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. - """ - import pyarrow as pa - reader = pa.open_stream(stream) - for batch in reader: - table = pa.Table.from_batches([batch]) - yield [c.to_pandas() for c in table.itercolumns()] - def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Serieses and serialize. Input is a single series or @@ -265,6 +255,16 @@ def dump_stream(self, iterator, stream): if writer is not None: writer.close() + def load_stream(self, stream): + """ + Deserialize ArrowRecordBatchs to an Arrow table and return as a list of pandas.Series. 
+ """ + import pyarrow as pa + reader = pa.open_stream(stream) + for batch in reader: + table = pa.Table.from_batches([batch]) + yield [c.to_pandas() for c in table.itercolumns()] + def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index c1c746fb3bdc..bbad9d6b631f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -56,11 +56,11 @@ class ArrowPythonRunner( context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { - override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDF(dataOut, funcs, argOffsets) + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } - override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val arrowSchema = ArrowUtils.toArrowSchema(schema) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) @@ -123,8 +123,8 @@ class ArrowPythonRunner( private var closed = false context.addTaskCompletionListener { _ => - // todo: we need something like `read.end()`, which release all the resources, but leave - // the input stream open. `reader.close` will close the socket and we can't reuse worker. + // todo: we need something like `reader.end()`, which release all the resources, but leave + // the input stream open. `reader.close()` will close the socket and we can't reuse worker. // So here we simply not close the reader, which is problematic. if (!closed) { if (root != null) { @@ -134,15 +134,6 @@ class ArrowPythonRunner( } } - override def hasNext: Boolean = super.hasNext || { - if (root != null) { - root.close() - } - allocator.close() - closed = true - false - } - private var batchLoaded = true protected override def read(): ColumnarBatch = { @@ -157,6 +148,9 @@ class ArrowPythonRunner( batch.setNumRows(root.getRowCount) batch } else { + root.close() + allocator.close() + closed = true // Reach end of stream. Call `read()` again to read control data. 
read() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index e636edc497ce..e28def1c4b42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -44,11 +44,11 @@ class PythonUDFRunner( context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { - override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDF(dataOut, funcs, argOffsets) + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } - override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } @@ -93,7 +93,7 @@ class PythonUDFRunner( object PythonUDFRunner { - def writeUDF( + def writeUDFs( dataOut: DataOutputStream, funcs: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]]): Unit = {