Commit f5f1e2b

Fix batching.
1 parent 3450313 commit f5f1e2b

1 file changed (+26, −15)

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 26 additions & 15 deletions
@@ -139,9 +139,14 @@ private[sql] case class PhysicalRDD(
   // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
   // never requires UnsafeRow as input.
   override protected def doProduce(ctx: CodegenContext): String = {
+    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
     val input = ctx.freshName("input")
+    val idx = ctx.freshName("batchIdx")
+    val batch = ctx.freshName("batch")
     // PhysicalRDD always just has one input
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+    ctx.addMutableState("int", idx, s"$idx = 0;")
 
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
@@ -156,27 +161,28 @@ private[sql] case class PhysicalRDD(
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
 
-    val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
 
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
-      | private void $scanBatches($columnarBatchClz batch) throws java.io.IOException {
-      |  int batchIdx = 0;
+      | private void $scanBatches() throws java.io.IOException {
       |  while (true) {
-      |     int numRows = batch.numRows();
-      |     if (batchIdx == 0) $numOutputRows.add(numRows);
+      |     int numRows = $batch.numRows();
+      |     if ($idx == 0) $numOutputRows.add(numRows);
       |
-      |     while (batchIdx < numRows) {
-      |       InternalRow $row = batch.getRow(batchIdx++);
+      |     while ($idx < numRows) {
+      |       InternalRow $row = $batch.getRow($idx++);
       |       ${columns.map(_.code).mkString("\n").trim}
       |       ${consume(ctx, columns).trim}
       |       if (shouldStop()) return;
       |     }
       |
-      |     if (!$input.hasNext()) break;
-      |     batch = ($columnarBatchClz)$input.next();
-      |     batchIdx = 0;
+      |     if (!$input.hasNext()) {
+      |       $batch = null;
+      |       break;
+      |     }
+      |     $batch = ($columnarBatchClz)$input.next();
+      |     $idx = 0;
       |   }
       | }""".stripMargin)

@@ -195,12 +201,17 @@ private[sql] case class PhysicalRDD(
       | }""".stripMargin)
 
     s"""
-      | if ($input.hasNext()) {
-      |   Object firstValue = $input.next();
-      |   if (firstValue instanceof $columnarBatchClz) {
-      |     $scanBatches(($columnarBatchClz)firstValue);
+      | if ($batch != null || $input.hasNext()) {
+      |   if ($batch == null) {
+      |     Object value = $input.next();
+      |     if (value instanceof $columnarBatchClz) {
+      |       $batch = ($columnarBatchClz)value;
+      |       $scanBatches();
+      |     } else {
+      |       $scanRows((InternalRow)value);
+      |     }
       |   } else {
-      |     $scanRows((InternalRow)firstValue);
+      |     $scanBatches();
       |   }
       | }
     """.stripMargin
