Skip to content

Commit a9eb7de

Browse files
committed
[SPARK-31809][SQL] Infer IsNotNull from join condition
1 parent c7d9bd2 commit a9eb7de

File tree

2 files changed

+73
-13
lines changed

2 files changed

+73
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,14 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
12151215
}
12161216
}
12171217

1218+
// Whether the result of this expression may be null. For example: CAST(strCol AS double)
1219+
// We will infer an IsNotNull expression for this expression to avoid skew join.
1220+
private def resultMayBeNull(e: Expression): Boolean = e match {
1221+
case Cast(child, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType)
1222+
case _: Coalesce => true
1223+
case _ => false
1224+
}
1225+
12181226
private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
12191227
_.containsAnyPattern(FILTER, JOIN)) {
12201228
case filter @ Filter(condition, child) =>
@@ -1227,25 +1235,43 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
12271235
}
12281236

12291237
case join @ Join(left, right, joinType, conditionOpt, _) =>
1238+
val leftKeys = new mutable.HashSet[Expression]
1239+
val rightKeys = new mutable.HashSet[Expression]
1240+
conditionOpt.foreach { condition =>
1241+
splitConjunctivePredicates(condition).foreach {
1242+
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty =>
1243+
case EqualTo(l, r) =>
1244+
if (resultMayBeNull(l)) {
1245+
if (canEvaluate(l, left)) leftKeys.add(l)
1246+
if (canEvaluate(l, right)) rightKeys.add(l)
1247+
}
1248+
if (resultMayBeNull(r)) {
1249+
if (canEvaluate(r, left)) leftKeys.add(r)
1250+
if (canEvaluate(r, right)) rightKeys.add(r)
1251+
}
1252+
case _ =>
1253+
}
1254+
}
1255+
12301256
joinType match {
12311257
// For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
12321258
// inner join, it just drops the right side in the final output.
12331259
case _: InnerLike | LeftSemi =>
12341260
val allConstraints = getAllConstraints(left, right, conditionOpt)
1235-
val newLeft = inferNewFilter(left, allConstraints)
1236-
val newRight = inferNewFilter(right, allConstraints)
1261+
val newLeft = inferNewFilter(left, allConstraints, leftKeys)
1262+
val newRight = inferNewFilter(right, allConstraints, rightKeys)
12371263
join.copy(left = newLeft, right = newRight)
12381264

12391265
// For right outer join, we can only infer additional filters for left side.
12401266
case RightOuter =>
12411267
val allConstraints = getAllConstraints(left, right, conditionOpt)
1242-
val newLeft = inferNewFilter(left, allConstraints)
1268+
val newLeft = inferNewFilter(left, allConstraints, leftKeys)
12431269
join.copy(left = newLeft)
12441270

12451271
// For left join, we can only infer additional filters for right side.
12461272
case LeftOuter | LeftAnti =>
12471273
val allConstraints = getAllConstraints(left, right, conditionOpt)
1248-
val newRight = inferNewFilter(right, allConstraints)
1274+
val newRight = inferNewFilter(right, allConstraints, rightKeys)
12491275
join.copy(right = newRight)
12501276

12511277
case _ => join
@@ -1261,9 +1287,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
12611287
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
12621288
}
12631289

1264-
private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
1290+
private def inferNewFilter(
1291+
plan: LogicalPlan,
1292+
constraints: ExpressionSet,
1293+
joinKeys: mutable.HashSet[Expression]): LogicalPlan = {
12651294
val newPredicates = constraints
12661295
.union(constructIsNotNullConstraints(constraints, plan.output))
1296+
.union(ExpressionSet(joinKeys.map(IsNotNull)))
12671297
.filter { c =>
12681298
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
12691299
} -- plan.constraints

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,15 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
271271
val originalLeft = testRelation1.subquery('left)
272272
val originalRight = testRelation2.where('b === 1L).subquery('right)
273273

274-
val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
275-
val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
276-
277274
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
278275
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
279-
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
276+
testConstraintsAfterJoin(
277+
originalLeft,
278+
originalRight,
279+
testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left),
280+
testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right),
281+
Inner,
282+
condition)
280283
}
281284

282285
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
@@ -285,7 +288,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
285288
originalLeft,
286289
originalRight,
287290
testRelation1.where(IsNotNull('a)).subquery('left),
288-
right,
291+
testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) &&
292+
'b === 1L).subquery('right),
289293
Inner,
290294
condition)
291295
}
@@ -302,16 +306,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
302306

303307
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
304308
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
305-
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
309+
testConstraintsAfterJoin(
310+
originalLeft,
311+
originalRight,
312+
testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left),
313+
testRelation2.where(IsNotNull('b)).subquery('right),
314+
Inner,
315+
condition)
306316
}
307317

308318
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
309319
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
310320
testConstraintsAfterJoin(
311321
originalLeft,
312322
originalRight,
313-
left,
314-
testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
323+
testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left),
324+
testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) &&
325+
'b.attr.cast(IntegerType) === 1).subquery('right),
315326
Inner,
316327
condition)
317328
}
@@ -361,4 +372,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
361372
val optimized = Optimize.execute(originalQuery)
362373
comparePlans(optimized, correctAnswer)
363374
}
375+
376+
test("SPARK-31809: Infer IsNotNull for join condition") {
377+
testConstraintsAfterJoin(
378+
testRelation.subquery('left),
379+
testRelation.subquery('right),
380+
testRelation.where(IsNotNull('a.cast(StringType).cast(DoubleType)) && IsNotNull('a))
381+
.subquery('left),
382+
testRelation.where(IsNotNull('c)).subquery('right),
383+
Inner,
384+
Some("left.a".attr.cast(StringType).cast(DoubleType) === "right.c".attr.cast(DoubleType)))
385+
386+
testConstraintsAfterJoin(
387+
testRelation.subquery('left),
388+
testRelation.subquery('right),
389+
testRelation.where(IsNotNull(Coalesce(Seq('a, 'b)))).subquery('left),
390+
testRelation.where(IsNotNull('c)).subquery('right),
391+
Inner,
392+
Some(Coalesce(Seq("left.a".attr, "left.b".attr)) === "right.c".attr))
393+
}
364394
}

0 commit comments

Comments
 (0)