Skip to content

Commit 03c7aac

Browse files
committed
Use ContextAwareIterator to stop consuming after the task ends.
1 parent ec1560a commit 03c7aac

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ trait EvalPythonExec extends UnaryExecNode {
8989

9090
inputRDD.mapPartitions { iter =>
9191
val context = TaskContext.get()
92+
val contextAwareIterator = new ContextAwareIterator(iter, context)
9293

9394
// The queue used to buffer input rows so we can drain it to
9495
// combine input with output from Python.
@@ -120,7 +121,7 @@ trait EvalPythonExec extends UnaryExecNode {
120121
}.toSeq)
121122

122123
// Add rows to queue to join later with the result.
123-
val projectedRowIter = iter.map { inputRow =>
124+
val projectedRowIter = contextAwareIterator.map { inputRow =>
124125
queue.add(inputRow.asInstanceOf[UnsafeRow])
125126
projection(inputRow)
126127
}
@@ -137,3 +138,18 @@ trait EvalPythonExec extends UnaryExecNode {
137138
}
138139
}
139140
}
141+
142+
/**
143+
* A TaskContext-aware iterator.
144+
*
145+
* As the Python evaluation consumes the parent iterator in a separate thread,
146+
* it could consume more data from the parent even after the task ends and the parent is closed.
147+
* Thus, we should use ContextAwareIterator to stop consuming after the task ends.
148+
*/
149+
class ContextAwareIterator[IN](iter: Iterator[IN], context: TaskContext) extends Iterator[IN] {
150+
151+
override def hasNext: Boolean =
152+
!context.isCompleted() && !context.isInterrupted() && iter.hasNext
153+
154+
override def next(): IN = iter.next()
155+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,17 @@ case class MapInPandasExec(
6161
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
6262
val outputTypes = child.schema
6363

64+
val context = TaskContext.get()
65+
val contextAwareIterator = new ContextAwareIterator(inputIter, context)
66+
6467
// Here we wrap it via another row so that the Python side understands it
6568
// as a DataFrame.
66-
val wrappedIter = inputIter.map(InternalRow(_))
69+
val wrappedIter = contextAwareIterator.map(InternalRow(_))
6770

6871
// DO NOT use iter.grouped(). See BatchIterator.
6972
val batchIter =
7073
if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter)
7174

72-
val context = TaskContext.get()
73-
7475
val columnarBatchIter = new ArrowPythonRunner(
7576
chainedFunc,
7677
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,

0 commit comments

Comments
 (0)