Skip to content

Commit 22b0c2a

Browse files
committed
General data types to be mapped to Oracle
1 parent 0b71d9a commit 22b0c2a

File tree

3 files changed

+116
-15
lines changed

3 files changed

+116
-15
lines changed

external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
package org.apache.spark.sql.jdbc
1919

20-
import java.sql.Connection
20+
import java.sql.{Connection, Date, Timestamp}
2121
import java.util.Properties
2222

23+
import org.apache.spark.sql.Row
2324
import org.apache.spark.sql.test.SharedSQLContext
25+
import org.apache.spark.sql.types._
2426
import org.apache.spark.tags.DockerTest
2527

2628
/**
@@ -77,4 +79,74 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
7779
// verify that the inserted value is read back correctly
7880
assert(rows(0).getString(0).equals("foo"))
7981
}
82+
83+
test("SPARK-16625: General data types to be mapped to Oracle") {
84+
val props = new Properties()
85+
props.put("oracle.jdbc.mapDateToTimestamp", "false")
86+
87+
val schema = StructType(Seq(
88+
StructField("boolean_type", BooleanType, true),
89+
StructField("integer_type", IntegerType, true),
90+
StructField("long_type", LongType, true),
91+
StructField("float_Type", FloatType, true),
92+
StructField("double_type", DoubleType, true),
93+
StructField("byte_type", ByteType, true),
94+
StructField("short_type", ShortType, true),
95+
StructField("string_type", StringType, true),
96+
StructField("binary_type", BinaryType, true),
97+
StructField("date_type", DateType, true),
98+
StructField("timestamp_type", TimestampType, true)
99+
))
100+
101+
val tableName = "test_oracle_general_types"
102+
val booleanVal = true
103+
val integerVal = 1
104+
val longVal = 2L
105+
val floatVal = 3.0f
106+
val doubleVal = 4.0
107+
val byteVal = 2.toByte
108+
val shortVal = 5.toShort
109+
val stringVal = "string"
110+
val binaryVal = Array[Byte](6, 7, 8)
111+
val dateVal = Date.valueOf("2016-07-26")
112+
val timestampVal = Timestamp.valueOf("2016-07-26 11:49:45")
113+
114+
val data = spark.sparkContext.parallelize(Seq(
115+
Row(
116+
booleanVal, integerVal, longVal, floatVal, doubleVal, byteVal, shortVal, stringVal,
117+
binaryVal, dateVal, timestampVal
118+
)))
119+
120+
val dfWrite = spark.createDataFrame(data, schema)
121+
dfWrite.write.jdbc(jdbcUrl, tableName, props)
122+
123+
val dfRead = spark.read.jdbc(jdbcUrl, tableName, props)
124+
val rows = dfRead.collect()
125+
// verify the runtime class of each value read back from the table
126+
val types = rows(0).toSeq.map(x => x.getClass.toString)
127+
assert(types(0).equals("class java.lang.Boolean"))
128+
assert(types(1).equals("class java.lang.Integer"))
129+
assert(types(2).equals("class java.lang.Long"))
130+
assert(types(3).equals("class java.lang.Float"))
131+
assert(types(4).equals("class java.lang.Float"))
132+
assert(types(5).equals("class java.lang.Integer"))
133+
assert(types(6).equals("class java.lang.Integer"))
134+
assert(types(7).equals("class java.lang.String"))
135+
assert(types(8).equals("class [B"))
136+
assert(types(9).equals("class java.sql.Date"))
137+
assert(types(10).equals("class java.sql.Timestamp"))
138+
// verify that each inserted value is read back correctly
139+
val values = rows(0)
140+
assert(values.getBoolean(0).equals(booleanVal))
141+
assert(values.getInt(1).equals(integerVal))
142+
assert(values.getLong(2).equals(longVal))
143+
assert(values.getFloat(3).equals(floatVal))
144+
assert(values.getFloat(4).equals(doubleVal.toFloat))
145+
assert(values.getInt(5).equals(byteVal.toInt))
146+
assert(values.getInt(6).equals(shortVal.toInt))
147+
assert(values.getString(7).equals(stringVal))
148+
assert(values.getAs[Array[Byte]](8).mkString.equals("678"))
149+
assert(values.getDate(9).equals(dateVal))
150+
assert(values.getTimestamp(10).equals(timestampVal))
151+
}
80152
}

sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,30 +26,38 @@ private case object OracleDialect extends JdbcDialect {
2626

2727
override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
2828

29-
override def getCatalystType(
30-
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
29+
override def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder):
30+
Option[DataType] = sqlType match {
3131
// Handle NUMBER fields that have no precision/scale in special way
3232
// because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
3333
// For more details, please see
3434
// https://github.com/apache/spark/pull/8780#issuecomment-145598968
3535
// and
3636
// https://github.com/apache/spark/pull/8780#issuecomment-144541760
37-
if (sqlType == Types.NUMERIC && size == 0) {
38-
// This is sub-optimal as we have to pick a precision/scale in advance whereas the data
39-
// in Oracle is allowed to have different precision/scale for each value.
37+
case Types.NUMERIC if size == 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
38+
// Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
39+
// this to NUMERIC with -127 scale
40+
// Not sure if there is a more robust way to identify the field as a float (or other
41+
// numeric types that do not specify a scale).
42+
case Types.NUMERIC if md.build().getLong("scale") == -127 =>
4043
Option(DecimalType(DecimalType.MAX_PRECISION, 10))
41-
} else if (sqlType == Types.NUMERIC && md.build().getLong("scale") == -127) {
42-
// Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
43-
// this to NUMERIC with -127 scale
44-
// Not sure if there is a more robust way to identify the field as a float (or other
45-
// numeric types that do not specify a scale.
46-
Option(DecimalType(DecimalType.MAX_PRECISION, 10))
47-
} else {
48-
None
49-
}
44+
case Types.NUMERIC if size == 1 => Option(BooleanType)
45+
case Types.NUMERIC if size == 3 || size == 5 || size == 10 => Option(IntegerType)
46+
case Types.NUMERIC if size == 19 && md.build().getLong("scale") == 0L => Option(LongType)
47+
case Types.NUMERIC if size == 19 && md.build().getLong("scale") == 4L => Option(FloatType)
48+
case _ => None
5049
}
5150

5251
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
52+
// For more details, please see
53+
// https://docs.oracle.com/cd/E19501-01/819-3659/gcmaz/
54+
case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
55+
case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
56+
case LongType => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
57+
case FloatType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
58+
case DoubleType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
59+
case ByteType => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
60+
case ShortType => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
5361
case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR))
5462
case _ => None
5563
}

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,27 @@ class JDBCSuite extends SparkFunSuite
739739
map(_.databaseTypeDefinition).get == "VARCHAR2(255)")
740740
}
741741

742+
test("SPARK-16625: General data types to be mapped to Oracle") {
743+
744+
def getJdbcType(dialect: JdbcDialect, dt: DataType): String = {
745+
dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).
746+
map(_.databaseTypeDefinition).get
747+
}
748+
749+
val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
750+
assert(getJdbcType(oracleDialect, BooleanType) == "NUMBER(1)")
751+
assert(getJdbcType(oracleDialect, IntegerType) == "NUMBER(10)")
752+
assert(getJdbcType(oracleDialect, LongType) == "NUMBER(19)")
753+
assert(getJdbcType(oracleDialect, FloatType) == "NUMBER(19, 4)")
754+
assert(getJdbcType(oracleDialect, DoubleType) == "NUMBER(19, 4)")
755+
assert(getJdbcType(oracleDialect, ByteType) == "NUMBER(3)")
756+
assert(getJdbcType(oracleDialect, ShortType) == "NUMBER(5)")
757+
assert(getJdbcType(oracleDialect, StringType) == "VARCHAR2(255)")
758+
assert(getJdbcType(oracleDialect, BinaryType) == "BLOB")
759+
assert(getJdbcType(oracleDialect, DateType) == "DATE")
760+
assert(getJdbcType(oracleDialect, TimestampType) == "TIMESTAMP")
761+
}
762+
742763
private def assertEmptyQuery(sqlString: String): Unit = {
743764
assert(sql(sqlString).collect().isEmpty)
744765
}

0 commit comments

Comments
 (0)