diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 317a88edf8e95..71d55b007aa17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -430,21 +430,23 @@ class DataFrameSuite extends QueryTest test("repartition by MapType") { Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt => - val df = spark.range(20) - .withColumn("c1", - when(col("id") % 3 === 1, typedLit(Map(1 -> 1))) - .when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2))) - .otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>")) - .withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>")) - .withColumn("c3", lit(null).cast(s"map<$dt, $dt>")) - - assertPartitionNumber(df.repartition(4, col("c1")), 2) - assertPartitionNumber(df.repartition(4, col("c2")), 1) - assertPartitionNumber(df.repartition(4, col("c3")), 1) - assertPartitionNumber(df.repartition(4, col("c1"), col("c2")), 2) - assertPartitionNumber(df.repartition(4, col("c1"), col("c3")), 2) - assertPartitionNumber(df.repartition(4, col("c1"), col("c2"), col("c3")), 2) - assertPartitionNumber(df.repartition(4, col("c2"), col("c3")), 2) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df = spark.range(20) + .withColumn("c1", + when(col("id") % 3 === 1, typedLit(Map(1 -> 1))) + .when(col("id") % 3 === 2, typedLit(Map(1 -> 1, 2 -> 2))) + .otherwise(typedLit(Map(2 -> 2, 1 -> 1))).cast(s"map<$dt, $dt>")) + .withColumn("c2", typedLit(Map(1 -> null)).cast(s"map<$dt, $dt>")) + .withColumn("c3", lit(null).cast(s"map<$dt, $dt>")) + + assertPartitionNumber(df.repartition(4, col("c1")), 2) + assertPartitionNumber(df.repartition(4, col("c2")), 1) + assertPartitionNumber(df.repartition(4, col("c3")), 1) + assertPartitionNumber(df.repartition(4, col("c1"), col("c2")), 2) + assertPartitionNumber(df.repartition(4, col("c1"), col("c3")), 2) + assertPartitionNumber(df.repartition(4, col("c1"), col("c2"), col("c3")), 2) + assertPartitionNumber(df.repartition(4, col("c2"), col("c3")), 2) + } } }