Skip to content

Commit bb152cd

Browse files
cloud-fan authored and hvanhovell committed
[SPARK-18519][SQL] map type can not be used in EqualTo
## What changes were proposed in this pull request?

Technically, map type is not orderable, but in principle it can be used in equality comparison. However, due to a limitation of the current implementation, equality comparison on map type is not supported, so a map-typed expression can't be used as a join key or grouping key. This PR makes this limitation explicit, to avoid wrong results.

## How was this patch tested?

Updated tests.

Author: Wenchen Fan <[email protected]>

Closes #15956 from cloud-fan/map-type.
1 parent 933a654 commit bb152cd

File tree

4 files changed: 48 additions and 43 deletions (+48 −43 lines changed)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -183,21 +183,6 @@ trait CheckAnalysis extends PredicateHelper {
183183
s"join condition '${condition.sql}' " +
184184
s"of type ${condition.dataType.simpleString} is not a boolean.")
185185

186-
case j @ Join(_, _, _, Some(condition)) =>
187-
def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
188-
case p: Predicate =>
189-
p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
190-
case e if e.dataType.isInstanceOf[BinaryType] =>
191-
failAnalysis(s"binary type expression ${e.sql} cannot be used " +
192-
"in join conditions")
193-
case e if e.dataType.isInstanceOf[MapType] =>
194-
failAnalysis(s"map type expression ${e.sql} cannot be used " +
195-
"in join conditions")
196-
case _ => // OK
197-
}
198-
199-
checkValidJoinConditionExprs(condition)
200-
201186
case Aggregate(groupingExprs, aggregateExprs, child) =>
202187
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
203188
case aggExpr: AggregateExpression =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -412,6 +412,21 @@ case class EqualTo(left: Expression, right: Expression)
412412

413413
override def inputType: AbstractDataType = AnyDataType
414414

415+
override def checkInputDataTypes(): TypeCheckResult = {
416+
super.checkInputDataTypes() match {
417+
case TypeCheckResult.TypeCheckSuccess =>
418+
// TODO: although map type is not orderable, technically map type should be able to be used
419+
// in equality comparison, remove this type check once we support it.
420+
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
421+
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
422+
s"input type is ${left.dataType.catalogString}.")
423+
} else {
424+
TypeCheckResult.TypeCheckSuccess
425+
}
426+
case failure => failure
427+
}
428+
}
429+
415430
override def symbol: String = "="
416431

417432
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -440,6 +455,21 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
440455

441456
override def inputType: AbstractDataType = AnyDataType
442457

458+
override def checkInputDataTypes(): TypeCheckResult = {
459+
super.checkInputDataTypes() match {
460+
case TypeCheckResult.TypeCheckSuccess =>
461+
// TODO: although map type is not orderable, technically map type should be able to be used
462+
// in equality comparison, remove this type check once we support it.
463+
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
464+
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
465+
s"input type is ${left.dataType.catalogString}.")
466+
} else {
467+
TypeCheckResult.TypeCheckSuccess
468+
}
469+
case failure => failure
470+
}
471+
}
472+
443473
override def symbol: String = "<=>"
444474

445475
override def nullable: Boolean = false

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 16 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -465,34 +465,22 @@ class AnalysisErrorSuite extends AnalysisTest {
465465
"another aggregate function." :: Nil)
466466
}
467467

468-
test("Join can't work on binary and map types") {
469-
val plan =
470-
Join(
471-
LocalRelation(
472-
AttributeReference("a", BinaryType)(exprId = ExprId(2)),
473-
AttributeReference("b", IntegerType)(exprId = ExprId(1))),
474-
LocalRelation(
475-
AttributeReference("c", BinaryType)(exprId = ExprId(4)),
476-
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
477-
Cross,
478-
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
479-
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
480-
481-
assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
482-
483-
val plan2 =
484-
Join(
485-
LocalRelation(
486-
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
487-
AttributeReference("b", IntegerType)(exprId = ExprId(1))),
488-
LocalRelation(
489-
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
490-
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
491-
Cross,
492-
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
493-
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
494-
495-
assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
468+
test("Join can work on binary types but can't work on map types") {
469+
val left = LocalRelation('a.binary, 'b.map(StringType, StringType))
470+
val right = LocalRelation('c.binary, 'd.map(StringType, StringType))
471+
472+
val plan1 = left.join(
473+
right,
474+
joinType = Cross,
475+
condition = Some('a === 'c))
476+
477+
assertAnalysisSuccess(plan1)
478+
479+
val plan2 = left.join(
480+
right,
481+
joinType = Cross,
482+
condition = Some('b === 'd))
483+
assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
496484
}
497485

498486
test("PredicateSubQuery is used outside of a filter") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
111111
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
112112
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
113113

114+
assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
115+
assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
114116
assertError(LessThan('mapField, 'mapField),
115117
s"requires ${TypeCollection.Ordered.simpleString} type")
116118
assertError(LessThanOrEqual('mapField, 'mapField),

0 commit comments

Comments
 (0)