
Commit 05ce50a

Require elem type
1 parent dd69aa6 commit 05ce50a

11 files changed, +41 -33 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 5 additions & 6 deletions
@@ -429,7 +429,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
 case class InSet(
     child: Expression,
     hset: Set[Any],
-    hsetElemType: Option[DataType] = None) extends UnaryExpression with Predicate {
+    hsetElemType: DataType) extends UnaryExpression with Predicate {
 
   require(hset != null, "hset could not be null")
 
@@ -449,12 +449,12 @@ case class InSet(
     }
   }
 
-  @transient lazy val set: Set[Any] = child.dataType match {
+  @transient lazy val set: Set[Any] = hsetElemType match {
     case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
     case _: NullType => hset
     case _ =>
       // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
-      TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null)
+      TreeSet.empty(TypeUtils.getInterpretedOrdering(hsetElemType)) ++ (hset - null)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -465,7 +465,7 @@ case class InSet(
     }
   }
 
-  private def canBeComputedUsingSwitch: Boolean = child.dataType match {
+  private def canBeComputedUsingSwitch: Boolean = hsetElemType match {
     case ByteType | ShortType | IntegerType | DateType => true
     case _ => false
   }
@@ -523,9 +523,8 @@ case class InSet(
 
   override def sql: String = {
     val valueSQL = child.sql
-    val elemType = hsetElemType.getOrElse(child.dataType)
     val listSQL = hset.toSeq
-      .map(elem => Literal(convertToScala(elem, elemType)).sql)
+      .map(elem => Literal(convertToScala(elem, hsetElemType)).sql)
       .mkString(", ")
     s"($valueSQL IN ($listSQL))"
  }
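
Net effect of this file's change: the element type is no longer recovered from child.dataType but must be supplied at construction, and set membership, switch-based codegen, and sql() all read it from hsetElemType. A minimal sketch of building the expression under the new signature (the values and the IntegerType choice are illustrative, mirroring the updated tests):

    import scala.collection.immutable.HashSet
    import org.apache.spark.sql.catalyst.expressions.{InSet, Literal}
    import org.apache.spark.sql.types.IntegerType

    // hsetElemType is now a required third argument; before this commit it
    // was Option[DataType] = None and the expression fell back to child.dataType.
    val pred = InSet(Literal(1), HashSet[Any](1, 2, 3), IntegerType)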

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
       EqualTo(v, newList.head)
     } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
       val hSet = newList.map(e => e.eval(EmptyRow))
-      InSet(v, HashSet() ++ hSet)
+      InSet(v, HashSet() ++ hSet, v.dataType)
     } else if (newList.length < list.length) {
       expr.copy(list = newList)
     } else { // newList.length == list.length && newList.length > 1
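
For reference, a hedged sketch of the rewrite this rule now performs once an IN list outgrows spark.sql.optimizer.inSetConversionThreshold (default 10); the attribute name follows OptimizeInSuite below, and the rule itself takes the element type from v.dataType:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal}
    import org.apache.spark.sql.types.IntegerType

    val a = UnresolvedAttribute("a")
    // An 11-element IN list exceeds the default threshold of 10 ...
    val before = In(a, (1 to 11).map(Literal(_)))
    // ... so OptimizeIn produces an InSet that now carries the element type.
    val after = InSet(a, (1 to 11).toSet, IntegerType)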

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 12 additions & 10 deletions
@@ -130,7 +130,9 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
   private def checkInAndInSet(in: In, expected: Any): Unit = {
     // expecting all in.list are Literal or NonFoldableLiteral.
     checkEvaluation(in, expected)
-    checkEvaluation(InSet(in.value, HashSet() ++ in.list.map(_.eval())), expected)
+    checkEvaluation(
+      InSet(in.value, HashSet() ++ in.list.map(_.eval()), in.value.dataType),
+      expected)
   }
 
   test("basic IN/INSET predicate test") {
@@ -154,7 +156,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
         Literal(2)))),
       true)
     checkEvaluation(
-      And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))),
+      And(InSet(Literal(1), HashSet(1, 2), IntegerType), InSet(Literal(2), Set(1, 2), IntegerType)),
       true)
 
     val ns = NonFoldableLiteral.create(null, StringType)
@@ -256,12 +258,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     val nullLiteral = Literal(null, presentValue.dataType)
 
-    checkEvaluation(InSet(nullLiteral, values), expected = null)
-    checkEvaluation(InSet(nullLiteral, values + null), expected = null)
-    checkEvaluation(InSet(presentValue, values), expected = true)
-    checkEvaluation(InSet(presentValue, values + null), expected = true)
-    checkEvaluation(InSet(absentValue, values), expected = false)
-    checkEvaluation(InSet(absentValue, values + null), expected = null)
+    checkEvaluation(InSet(nullLiteral, values, nullLiteral.dataType), expected = null)
+    checkEvaluation(InSet(nullLiteral, values + null, nullLiteral.dataType), expected = null)
+    checkEvaluation(InSet(presentValue, values, presentValue.dataType), expected = true)
+    checkEvaluation(InSet(presentValue, values + null, presentValue.dataType), expected = true)
+    checkEvaluation(InSet(absentValue, values, absentValue.dataType), expected = false)
+    checkEvaluation(InSet(absentValue, values + null, absentValue.dataType), expected = null)
   }
 
   def checkAllTypes(): Unit = {
@@ -498,7 +500,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("SPARK-22693: InSet should not use global variables") {
     val ctx = new CodegenContext
-    InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
+    InSet(Literal(1), Set(1, 2, 3, 4), IntegerType).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }
 
@@ -535,7 +537,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("SPARK-29100: InSet with empty input set") {
     val row = create_row(1)
-    val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty)
+    val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty, IntegerType)
     checkEvaluation(inSet, false, row)
   }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ class OptimizeInSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer =
       testRelation
-        .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
+        .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet, IntegerType))
        .analyze
 
    comparePlans(optimized, correctAnswer)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala

Lines changed: 7 additions & 6 deletions
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{ColumnStatsM
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * In this test suite, we test predicates containing the following operators:
@@ -352,15 +353,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("cint IN (3, 4, 5)") {
     validateEstimatedStats(
-      Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)),
+      Filter(InSet(attrInt, Set(3, 4, 5), IntegerType), childStatsTestPlan(Seq(attrInt), 10L)),
       Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5),
         nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
       expectedRowCount = 3)
   }
 
   test("evaluateInSet with all zeros") {
     validateEstimatedStats(
-      Filter(InSet(attrString, Set(3, 4, 5)),
+      Filter(InSet(attrString, Set(3, 4, 5), IntegerType),
         StatsTestPlan(Seq(attrString), 0,
           AttributeMap(Seq(attrString ->
             ColumnStat(distinctCount = Some(0), min = None, max = None,
@@ -371,7 +372,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("evaluateInSet with string") {
     validateEstimatedStats(
-      Filter(InSet(attrString, Set("A0")),
+      Filter(InSet(attrString, Set(UTF8String.fromString("A0")), StringType),
         StatsTestPlan(Seq(attrString), 10,
           AttributeMap(Seq(attrString ->
             ColumnStat(distinctCount = Some(10), min = None, max = None,
@@ -383,14 +384,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("cint NOT IN (3, 4, 5)") {
     validateEstimatedStats(
-      Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)),
+      Filter(Not(InSet(attrInt, Set(3, 4, 5), IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)),
       Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))),
       expectedRowCount = 7)
   }
 
   test("cbool IN (true)") {
     validateEstimatedStats(
-      Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+      Filter(InSet(attrBool, Set(true), BooleanType), childStatsTestPlan(Seq(attrBool), 10L)),
       Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true),
         nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))),
       expectedRowCount = 5)
@@ -510,7 +511,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt))
     )
     validateEstimatedStats(
-      Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
+      Filter(InSet(attrInt, Set(1, 2, 3, 4, 5), IntegerType), cornerChildStatsTestplan),
       Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5),
         nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
      expectedRowCount = 2)

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 2 additions & 1 deletion
@@ -830,7 +830,8 @@ class Column(val expr: Expression) extends Logging {
   def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr {
     val exprValues = values.toSeq.map(lit(_).expr)
     if (exprValues.size > SQLConf.get.optimizerInSetConversionThreshold) {
-      InSet(expr, exprValues.map(_.eval()).toSet, exprValues.headOption.map(_.dataType))
+      val elemType = exprValues.headOption.map(_.dataType).getOrElse(NullType)
+      InSet(expr, exprValues.map(_.eval()).toSet, elemType)
     } else {
       In(expr, exprValues)
    }
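
Here isInCollection derives the element type from the first literal; the NullType fallback is defensive, since an empty collection can never exceed the conversion threshold and therefore stays on the In branch. A usage sketch (the session setup, column name, and data are assumptions for illustration):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("inset-sketch").master("local[*]").getOrCreate()
    val df = spark.range(100).selectExpr("cast(id as int) as id")

    // More than spark.sql.optimizer.inSetConversionThreshold values
    // (default 10), so the InSet branch is taken and the predicate is built
    // with an explicit IntegerType element type.
    val wide = df.filter(df("id").isInCollection(1 to 20))
    wide.explain()  // the physical plan should surface an INSET predicate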

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ case class InSubqueryExec(
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     prepareResult()
-    InSet(child, result.toSet).doGenCode(ctx, ev)
+    InSet(child, result.toSet, child.dataType).doGenCode(ctx, ev)
   }
 
   override lazy val canonicalized: InSubqueryExec = {

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -872,7 +872,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-31563: sql of InSet for UTF8String collection") {
-    val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
+    val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString), StringType)
     assert(inSet.sql === "('a' IN ('a', 'b'))")
   }
 

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala

Lines changed: 3 additions & 1 deletion
@@ -110,7 +110,9 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
   testTranslateFilter(LessThanOrEqual(1, attrInt),
     Some(sources.GreaterThanOrEqual(intColName, 1)))
 
-  testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3))))
+  testTranslateFilter(
+    InSet(attrInt, Set(1, 2, 3), IntegerType),
+    Some(sources.In(intColName, Array(1, 2, 3))))
 
   testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3))))
 
sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 5 additions & 2 deletions
@@ -37,6 +37,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.BitSet
 
@@ -188,8 +189,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
       df)
 
     // Case 4: InSet
-    val inSetExpr = expressions.InSet($"j".expr,
-      Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr))
+    val inSetExpr = expressions.InSet(
+      $"j".expr,
+      Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr),
+      IntegerType)
     checkPrunedAnswers(
       bucketSpec,
      bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3),
