@@ -1215,6 +1215,14 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
12151215 }
12161216 }
12171217
1218+ // Whether the result of this expression may be null. For example: CAST(strCol AS double)
1219+ // We will infer an IsNotNull expression for this expression to avoid skew join.
1220+ private def resultMayBeNull (e : Expression ): Boolean = e match {
1221+ case Cast (child, dataType, _, _) => ! Cast .canUpCast(child.dataType, dataType)
1222+ case _ : Coalesce => true
1223+ case _ => false
1224+ }
1225+
12181226 private def inferFilters (plan : LogicalPlan ): LogicalPlan = plan.transformWithPruning(
12191227 _.containsAnyPattern(FILTER , JOIN )) {
12201228 case filter @ Filter (condition, child) =>
@@ -1227,25 +1235,43 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
12271235 }
12281236
12291237 case join @ Join (left, right, joinType, conditionOpt, _) =>
1238+ val leftKeys = new mutable.HashSet [Expression ]
1239+ val rightKeys = new mutable.HashSet [Expression ]
1240+ conditionOpt.foreach { condition =>
1241+ splitConjunctivePredicates(condition).foreach {
1242+ case EqualTo (l, r) if l.references.isEmpty || r.references.isEmpty =>
1243+ case EqualTo (l, r) =>
1244+ if (resultMayBeNull(l)) {
1245+ if (canEvaluate(l, left)) leftKeys.add(l)
1246+ if (canEvaluate(l, right)) rightKeys.add(l)
1247+ }
1248+ if (resultMayBeNull(r)) {
1249+ if (canEvaluate(r, left)) leftKeys.add(r)
1250+ if (canEvaluate(r, right)) rightKeys.add(r)
1251+ }
1252+ case _ =>
1253+ }
1254+ }
1255+
12301256 joinType match {
12311257 // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
12321258 // inner join, it just drops the right side in the final output.
12331259 case _ : InnerLike | LeftSemi =>
12341260 val allConstraints = getAllConstraints(left, right, conditionOpt)
1235- val newLeft = inferNewFilter(left, allConstraints)
1236- val newRight = inferNewFilter(right, allConstraints)
1261+ val newLeft = inferNewFilter(left, allConstraints, leftKeys )
1262+ val newRight = inferNewFilter(right, allConstraints, rightKeys )
12371263 join.copy(left = newLeft, right = newRight)
12381264
12391265 // For right outer join, we can only infer additional filters for left side.
12401266 case RightOuter =>
12411267 val allConstraints = getAllConstraints(left, right, conditionOpt)
1242- val newLeft = inferNewFilter(left, allConstraints)
1268+ val newLeft = inferNewFilter(left, allConstraints, leftKeys )
12431269 join.copy(left = newLeft)
12441270
12451271 // For left join, we can only infer additional filters for right side.
12461272 case LeftOuter | LeftAnti =>
12471273 val allConstraints = getAllConstraints(left, right, conditionOpt)
1248- val newRight = inferNewFilter(right, allConstraints)
1274+ val newRight = inferNewFilter(right, allConstraints, rightKeys )
12491275 join.copy(right = newRight)
12501276
12511277 case _ => join
@@ -1261,9 +1287,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
12611287 baseConstraints.union(inferAdditionalConstraints(baseConstraints))
12621288 }
12631289
1264- private def inferNewFilter (plan : LogicalPlan , constraints : ExpressionSet ): LogicalPlan = {
1290+ private def inferNewFilter (
1291+ plan : LogicalPlan ,
1292+ constraints : ExpressionSet ,
1293+ joinKeys : mutable.HashSet [Expression ]): LogicalPlan = {
12651294 val newPredicates = constraints
12661295 .union(constructIsNotNullConstraints(constraints, plan.output))
1296+ .union(ExpressionSet (joinKeys.map(IsNotNull )))
12671297 .filter { c =>
12681298 c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
12691299 } -- plan.constraints
0 commit comments