diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 615256243ae2..2822e25021f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -227,7 +227,8 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
   override def dataType: DataType = child.dataType match {
     case dt @ DecimalType.Fixed(_, 0) => dt
     case DecimalType.Fixed(precision, scale) =>
-      DecimalType.bounded(precision - scale + 1, 0)
+      val boundedPrecision = if (precision < scale) 1 else precision - scale + 1
+      DecimalType.bounded(boundedPrecision, 0)
     case _ => LongType
   }
 
@@ -344,7 +345,8 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
   override def dataType: DataType = child.dataType match {
     case dt @ DecimalType.Fixed(_, 0) => dt
     case DecimalType.Fixed(precision, scale) =>
-      DecimalType.bounded(precision - scale + 1, 0)
+      val boundedPrecision = if (precision < scale) 1 else precision - scale + 1
+      DecimalType.bounded(boundedPrecision, 0)
     case _ => LongType
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 80916ee9c537..9c8ed6b4e02e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -384,13 +384,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
 
   def floor: Decimal = if (scale == 0) this else {
-    val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+    val boundedPrecision = if (precision < scale) 1 else precision - scale + 1
+    val newPrecision = DecimalType.bounded(boundedPrecision, 0).precision
     toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
       throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
   }
 
   def ceil: Decimal = if (scale == 0) this else {
-    val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+    val boundedPrecision = if (precision < scale) 1 else precision - scale + 1
+    val newPrecision = DecimalType.bounded(boundedPrecision, 0).precision
     toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
       throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 6af0cde73538..414f789f2efb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -254,7 +254,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
 
-    testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1)))
+    testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.0001)))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
     checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
@@ -274,7 +274,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testUnary(Floor, (d: Double) => math.floor(d).toLong)
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
 
-    testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1)))
+    testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.0001)))
    checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
     checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 41e9e2c92ca8..c6c655251ecc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2635,6 +2635,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-20211: should be able to floor or ceil with a decimal when its precision < scale") {
+    val df = Seq(0).toDF("a")
+    withTempView("tb") {
+      df.createOrReplaceTempView("tb")
+      checkAnswer(sql("SELECT 1 > 0.00001 FROM tb"), Row(true))
+      checkAnswer(sql("SELECT floor(0.0001) FROM tb"), Row(0))
+      checkAnswer(sql("SELECT ceil(0.0001) FROM tb"), Row(1))
+      checkAnswer(sql("SELECT floor(0.00123) FROM tb"), Row(0))
+      checkAnswer(sql("SELECT floor(0.00010) FROM tb"), Row(0))
+    }
+  }
+
   test("SPARK-12868: Allow adding jars from hdfs ") {
     val jarFromHdfs = "hdfs://doesnotmatter/test.jar"
     val jarFromInvalidFs = "fffs://doesnotmatter/test.jar"