Commit 381a825

cloud-fan authored and davies committed
[SPARK-15241] [SPARK-15242] [SQL] fix 2 decimal-related issues in RowEncoder
## What changes were proposed in this pull request?

SPARK-15241: We now support java decimal and catalyst decimal in external row; it makes sense to also support scala decimal.

SPARK-15242: This is a long-standing bug, exposed after #12364, which eliminates the `If` expression if the field is not nullable:

```
val fieldValue = serializerFor(
  GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
  f.dataType)
if (f.nullable) {
  If(
    Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
    Literal.create(null, f.dataType),
    fieldValue)
} else {
  fieldValue
}
```

Previously, we always used `DecimalType.SYSTEM_DEFAULT` as the output type of a converted decimal field, which is wrong as it doesn't match the real decimal type. However, it worked because we always put the converted field into an `If` expression to do the null check, and `If` uses its `trueValue`'s data type as its output type. Now, if we have a non-nullable decimal field, the converted field's output type will be `DecimalType.SYSTEM_DEFAULT`, and we will write wrong data into the unsafe row.

The fix is simple: use the given decimal type as the output type of the converted decimal field.

These 2 issues were found at #13008.

## How was this patch tested?

New tests in RowEncoderSuite.

Author: Wenchen Fan <[email protected]>

Closes #13019 from cloud-fan/encoder-decimal.

(cherry picked from commit d8935db)
Signed-off-by: Davies Liu <[email protected]>
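A minimal sketch of the failure mode, adapted from the new test added to RowEncoderSuite in this commit (imports and API calls assume the catalyst packages of this Spark version): with a non-nullable decimal field there is no wrapping `If`, so the serializer's own output type must be the field's declared `DecimalType(10, 5)`; before this fix it fell back to `DecimalType.SYSTEM_DEFAULT` and mismatched data was written into the unsafe row.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// Non-nullable decimal field: the serializer is not wrapped in an `If`,
// so its output type must be the field's declared DecimalType(10, 5).
val schema = new StructType().add("decimal", DecimalType(10, 5), nullable = false)
val encoder = RowEncoder(schema)

val decimal = Decimal("67123.45")
val row = encoder.toRow(Row(decimal))

// With the fix, the value read back from the unsafe row keeps the declared
// precision and scale instead of DecimalType.SYSTEM_DEFAULT's.
assert(row.toSeq(schema).head == decimal)
```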
1 parent 403ba65 commit 381a825

4 files changed: +29, -10 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -84,10 +84,10 @@ object RowEncoder {
         "fromJavaDate",
         inputObject :: Nil)
 
-    case _: DecimalType =>
+    case d: DecimalType =>
       StaticInvoke(
         Decimal.getClass,
-        DecimalType.SYSTEM_DEFAULT,
+        d,
         "fromDecimal",
         inputObject :: Nil)
 
@@ -162,7 +162,7 @@ object RowEncoder {
    * `org.apache.spark.sql.types.Decimal`.
    */
   private def externalDataTypeForInput(dt: DataType): DataType = dt match {
-    // In order to support both Decimal and java BigDecimal in external row, we make this
+    // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
     case _ => externalDataTypeFor(dt)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -386,6 +386,7 @@ object Decimal {
   def fromDecimal(value: Any): Decimal = {
     value match {
       case j: java.math.BigDecimal => apply(j)
+      case d: BigDecimal => apply(d)
       case d: Decimal => d
     }
   }
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -108,7 +108,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
 
   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
-  // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
+  encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
 
   encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
 
@@ -336,6 +336,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
           Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
         case (b1: Array[_], b2: Array[_]) =>
           Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+        case (left: Comparable[Any], right: Comparable[Any]) => left.compareTo(right) == 0
         case _ => input == convertedBack
       }
 
```
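The new `Comparable` case above is presumably needed because `java.math.BigDecimal#equals` is scale-sensitive, while round-tripping a decimal through the encoder can change its scale; `compareTo` ignores scale. A small illustration of the plain Java semantics (not Spark-specific):

```scala
// java.math.BigDecimal#equals requires both value AND scale to match;
// compareTo only compares the numeric value.
val a = new java.math.BigDecimal("231341.23123") // scale 5
val b = a.setScale(18)                           // same value, scale 18

assert(!a.equals(b))        // equals is scale-sensitive
assert(a.compareTo(b) == 0) // compareTo treats them as equal
```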

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala

Lines changed: 23 additions & 6 deletions
```diff
@@ -143,21 +143,38 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(input.getStruct(0) == convertedBack.getStruct(0))
   }
 
-  test("encode/decode Decimal") {
+  test("encode/decode decimal type") {
     val schema = new StructType()
       .add("int", IntegerType)
       .add("string", StringType)
       .add("double", DoubleType)
-      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("java_decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
 
     val encoder = RowEncoder(schema)
 
-    val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
+    val javaDecimal = new java.math.BigDecimal("1234.5678")
+    val scalaDecimal = BigDecimal("1234.5678")
+    val catalystDecimal = Decimal("1234.5678")
+
+    val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, catalystDecimal)
     val row = encoder.toRow(input)
     val convertedBack = encoder.fromRow(row)
-    // Decimal inside external row will be converted back to Java BigDecimal when decoding.
-    assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
-      .compareTo(convertedBack.getDecimal(3)) == 0)
+    // Decimal will be converted back to Java BigDecimal when decoding.
+    assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0)
+    assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0)
+    assert(convertedBack.getDecimal(5).compareTo(catalystDecimal.toJavaBigDecimal) == 0)
+  }
+
+  test("RowEncoder should preserve decimal precision and scale") {
+    val schema = new StructType().add("decimal", DecimalType(10, 5), false)
+    val encoder = RowEncoder(schema)
+    val decimal = Decimal("67123.45")
+    val input = Row(decimal)
+    val row = encoder.toRow(input)
+
+    assert(row.toSeq(schema).head == decimal)
   }
 
   test("RowEncoder should preserve schema nullability") {
```
