From d1373fb38d9350e508cab31e2d7aa83403273c5b Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 9 Apr 2021 18:44:12 +0800 Subject: [PATCH 1/7] Support ANSI SQL intervals by the aggregate function `sum` --- .../sql/catalyst/expressions/aggregate/Sum.scala | 11 ++++++++--- .../apache/spark/sql/DataFrameAggregateSuite.scala | 11 ++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index f412a3ec31e0..2ea6fdfdc518 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -46,15 +46,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType)) - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sum") + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case YearMonthIntervalType | DayTimeIntervalType => TypeCheckResult.TypeCheckSuccess + case _ => TypeUtils.checkForNumericExpr(child.dataType, "function sum") + } private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) case _: IntegralType => LongType + case _: YearMonthIntervalType => YearMonthIntervalType + case _: DayTimeIntervalType => DayTimeIntervalType case _ => DoubleType } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 3e137d49e64c..5aa3fe2ed775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import java.time.{Duration, Period} + import scala.util.Random import org.scalatest.matchers.must.Matchers.the - import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1110,6 +1111,14 @@ class DataFrameAggregateSuite extends QueryTest val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count()) assert(e.message.contains("requires integral type")) } + + test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { + val df = Seq((Period.ofMonths(10), Duration.ofDays(10)), + (Period.ofMonths(1), Duration.ofDays(1))) + .toDF("year-month", "day-time") + val sumDF = df.select(sum($"year-month"), sum($"day-time")) + checkAnswer(sumDF, Row(Period.ofMonths(11), Duration.ofDays(11))) + } } case class B(c: Option[Double]) From 7897af450ab3870a1af747feb776ebf43d913767 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sun, 11 Apr 2021 10:44:25 +0900 Subject: [PATCH 2/7] Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5aa3fe2ed775..4e441f9d1a48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -22,6 +22,7 @@ import java.time.{Duration, Period} import scala.util.Random import org.scalatest.matchers.must.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} From 75f383650351f92e99237c251fb189a58f828592 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Tue, 13 Apr 2021 10:55:45 +0800 Subject: [PATCH 3/7] Add tests --- .../catalyst/expressions/aggregate/Sum.scala | 7 +++--- .../ExpressionTypeCheckingSuite.scala | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 23 +++++++++++++++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 2ea6fdfdc518..45b17583e157 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -50,8 +49,10 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType)) override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case YearMonthIntervalType | DayTimeIntervalType => TypeCheckResult.TypeCheckSuccess - case _ => TypeUtils.checkForNumericExpr(child.dataType, "function sum") + case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess + case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess + case other => TypeCheckResult.TypeCheckFailure( + s"function sum requires numeric or interval types, not ${other.catalogString}") } private lazy val resultType = child.dataType match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 44f333342d1c..1b9135eef69f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -158,7 +158,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Min(Symbol("mapField")), "min does not support ordering on type") assertError(Max(Symbol("mapField")), "max does not support ordering on type") - assertError(Sum(Symbol("booleanField")), "function sum requires numeric type") + assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types") assertError(Average(Symbol("booleanField")), "function average requires numeric type") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5aa3fe2ed775..42496b92d6f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -22,6 +22,7 @@ import java.time.{Duration, Period} import scala.util.Random import org.scalatest.matchers.must.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1113,11 +1114,25 @@ class DataFrameAggregateSuite extends QueryTest } test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { - val df = Seq((Period.ofMonths(10), Duration.ofDays(10)), - (Period.ofMonths(1), Duration.ofDays(1))) - .toDF("year-month", "day-time") + val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)), + (2, Period.ofMonths(1), Duration.ofDays(1)), + (2, null, null), + (3, Period.ofMonths(-3), Duration.ofDays(-6)), + (3, Period.ofMonths(21), Duration.ofDays(-5))) + .toDF("class", "year-month", "day-time") + val sumDF = df.select(sum($"year-month"), sum($"day-time")) - checkAnswer(sumDF, Row(Period.ofMonths(11), Duration.ofDays(11))) + checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0))) + assert(sumDF.schema == StructType(Seq(StructField("sum(year-month)", YearMonthIntervalType), + StructField("sum(day-time)", DayTimeIntervalType)))) + + val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time")) + checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: + Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: + Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) ::Nil) + assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false), + StructField("sum(year-month)", YearMonthIntervalType), + StructField("sum(day-time)", DayTimeIntervalType)))) } } From 9b8c4614544c5be1d8a0b08a8a8265d2b932ca50 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Tue, 13 Apr 2021 14:22:00 +0800 Subject: [PATCH 4/7] Add tests --- .../sql/catalyst/expressions/UnsafeRow.java | 4 +++- .../spark/sql/DataFrameAggregateSuite.scala | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 4dc5ce1de047..0c6685d76fd0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -90,7 +90,9 @@ public static int calculateBitSetWidthInBytes(int numFields) { FloatType, DoubleType, DateType, - TimestampType + TimestampType, + YearMonthIntervalType, + DayTimeIntervalType }))); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 42496b92d6f7..92d3dc6fb88e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -23,6 +23,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the +import org.apache.spark.SparkException import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1121,8 +1122,13 @@ class DataFrameAggregateSuite extends QueryTest (3, Period.ofMonths(21), Duration.ofDays(-5))) .toDF("class", "year-month", "day-time") + val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)), + (Period.ofMonths(10), Duration.ofDays(10))) + .toDF("year-month", "day-time") + val sumDF = df.select(sum($"year-month"), sum($"day-time")) checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0))) + assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF.schema == StructType(Seq(StructField("sum(year-month)", YearMonthIntervalType), StructField("sum(day-time)", DayTimeIntervalType)))) @@ -1130,9 +1136,20 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) ::Nil) + assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false), StructField("sum(year-month)", YearMonthIntervalType), StructField("sum(day-time)", DayTimeIntervalType)))) + + val error = intercept[SparkException] { + checkAnswer(df2.select(sum($"year-month")), Nil) + } + assert(error.toString contains "java.lang.ArithmeticException: integer overflow") + + val error2 = intercept[SparkException] { + checkAnswer(df2.select(sum($"day-time")), Nil) + } + assert(error2.toString contains "java.lang.ArithmeticException: long overflow") } } From a4d1214ecc3f7ab846ba043b78d77fda2353f014 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Tue, 13 Apr 2021 18:59:38 +0800 Subject: [PATCH 5/7] extend BufferSetterGetterUtils --- .../spark/sql/execution/aggregate/udaf.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index e6851a9af739..ca8aa69e843c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -187,6 +187,22 @@ sealed trait BufferSetterGetterUtils { row.setNullAt(ordinal) } + case YearMonthIntervalType => + (row: InternalRow, ordinal: Int, value: Any) => + if (value != null) { + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case DayTimeIntervalType => + (row: InternalRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) + } else { + row.setNullAt(ordinal) + } + case other => (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { From 18c17ffa6848575cd589635b4b730a80056b57f7 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Wed, 14 Apr 2021 19:20:37 +0800 Subject: [PATCH 6/7] test --- .../org/apache/spark/sql/execution/aggregate/udaf.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index ca8aa69e843c..8c0a69f97dcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -87,6 +87,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case YearMonthIntervalType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case DayTimeIntervalType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) From 01effb8484723af55905ecab2a95b85e6b23fb48 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Thu, 15 Apr 2021 10:26:48 +0800 Subject: [PATCH 7/7] Update code --- .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 5a7d6cc20971..5942c5f00a71 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -541,14 +541,14 @@ protected void reserveInternal(int newCapacity) { shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || - DecimalType.is32BitDecimalType(type)) { + DecimalType.is32BitDecimalType(type) || type instanceof YearMonthIntervalType) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || - DecimalType.is64BitDecimalType(type)) { + DecimalType.is64BitDecimalType(type) || type instanceof DayTimeIntervalType) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity);