diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 31cc26962ad93..37b576ec4df22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -211,28 +211,24 @@ case class BroadcastNestedLoopJoin( Iterator((matchedRows, includedBroadcastTuples)) } - val includedBroadcastTuples = streamedPlusMatches.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - streamedPlusMatches.map(_._2).reduce(_ ++ _) - } - - val rightOuterMatches: Seq[Row] = - if (joinType == RightOuter || joinType == FullOuter) { - broadcastedRelation.value.zipWithIndex.filter { - case (row, i) => !allIncludedBroadcastTuples.contains(i) - }.map { - // TODO: Use projection. - case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + if (joinType == RightOuter || joinType == FullOuter) { + val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val allIncludedBroadcastTuples = + includedBroadcastTuples.fold( + new scala.collection.mutable.BitSet(broadcastedRelation.value.size))(_ ++ _) + + val rightOuterMatches: Seq[Row] = + broadcastedRelation.value.zipWithIndex.collect { + case (row, i) if !allIncludedBroadcastTuples.contains(i) => + // TODO: Use projection. + buildRow(Vector.fill(left.output.size)(null) ++ row) } - } else { - Vector() - } - // TODO: Breaks lineage. - sc.union( - streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches)) + // TODO: Breaks lineage. + sc.union( + streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches)) + } else { + streamedPlusMatches.flatMap(_._1) + } } }