Skip to content
1 change: 1 addition & 0 deletions dataframe-jdbc/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies {
testImplementation(libs.mssql)
testImplementation(libs.junit)
testImplementation(libs.sl4j)
testImplementation(libs.jts)
testImplementation(libs.kotestAssertions) {
exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,67 @@ import kotlin.reflect.KType
/**
* Represents the H2 database type.
*
* This class provides methods to convert data from a ResultSet to the appropriate type for H2,
* This class provides methods to convert data from a ResultSet to the appropriate type for H2
* and to generate the corresponding column schema.
*
* NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types.
* NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types.
*/
public object H2 : DbType("h2") {
public class H2(public val dialect: DbType = MySql) : DbType("h2") {
init {
require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" }
}

/**
* It contains constants related to different database modes.
*
* The mode value is used in the [extractDBTypeFromConnection] function to determine the corresponding `DbType` for the H2 database connection URL.
* For example, if the URL contains the mode value "MySQL", the H2 instance with the MySQL database type is returned.
* Otherwise, the `DbType` is determined based on the URL without the mode value.
*
* @see [extractDBTypeFromConnection]
* @see [createH2Instance]
*/
public companion object {
/** It represents the mode value "MySQL" for the H2 database. */
public const val MODE_MYSQL: String = "MySQL"

/** It represents the mode value "PostgreSQL" for the H2 database. */
public const val MODE_POSTGRESQL: String = "PostgreSQL"

/** It represents the mode value "MSSQLServer" for the H2 database. */
public const val MODE_MSSQLSERVER: String = "MSSQLServer"

/** It represents the mode value "MariaDB" for the H2 database. */
public const val MODE_MARIADB: String = "MariaDB"
}

override val driverClassName: String
get() = "org.h2.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
return null
return dialect.convertSqlTypeToColumnSchemaValue(tableColumnMetadata)
}

override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
return tableMetadata.name.lowercase(Locale.getDefault()).contains("sys_") ||
tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("information_schema") ?: false
val locale = Locale.getDefault()
fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true
val schemaName = tableMetadata.schemaName

// could be extended for other symptoms of the system tables for H2
val isH2SystemTable = schemaName.containsWithLowercase("information_schema")

return isH2SystemTable || dialect.isSystemTable(tableMetadata)
}

override fun buildTableMetadata(tables: ResultSet): TableMetadata {
return TableMetadata(
tables.getString("TABLE_NAME"),
tables.getString("TABLE_SCHEM"),
tables.getString("TABLE_CAT")
)
return dialect.buildTableMetadata(tables)
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
return null
return dialect.convertSqlTypeToKType(tableColumnMetadata)
}

public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String {
return dialect.sqlQueryLimit(sqlQuery, limit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.*
import java.util.Locale
import kotlin.reflect.KType
import kotlin.reflect.full.createType

/**
* Represents the MSSQL database type.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,75 @@
package org.jetbrains.kotlinx.dataframe.io.db

import io.github.oshai.kotlinlogging.KotlinLogging
import java.sql.Connection
import java.sql.SQLException
import java.util.Locale

private val logger = KotlinLogging.logger {}

/**
* Extracts the database type from the given connection.
*
* @param [connection] the database connection.
* @return the corresponding [DbType].
* @throws [IllegalStateException] if URL information is missing in connection meta-data.
* @throws [IllegalArgumentException] if the URL specifies an unsupported database type.
* @throws [SQLException] if the URL is null.
*/
public fun extractDBTypeFromConnection(connection: Connection): DbType {
val url = connection.metaData?.url ?: throw IllegalStateException("URL information is missing in connection meta data!")
logger.info { "Processing DB type extraction for connection url: $url" }

return if (url.contains(H2().dbTypeInJdbcUrl)) {
// works only for H2 version 2
val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'"
var mode = ""
connection.createStatement().use { st ->
st.executeQuery(
modeQuery
).use { rs ->
if (rs.next()) {
mode = rs.getString("SETTING_VALUE")
logger.debug { "Fetched H2 DB mode: $mode" }
} else {
throw IllegalStateException("The information about H2 mode is not found in the H2 meta-data!")
}
}
}

// H2 doesn't support MariaDB and SQLite
when (mode.lowercase(Locale.getDefault())) {
H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)
else -> {
val message = "Unsupported database type in the url: $url. " +
"Only MySQL, MariaDB, MSSQL and PostgreSQL are supported!"
logger.error { message }

throw IllegalArgumentException(message)
}
}
} else {
val dbType = extractDBTypeFromUrl(url)
logger.info { "Identified DB type as $dbType from url: $url" }
dbType
}
}

/**
* Extracts the database type from the given JDBC URL.
*
* @param [url] the JDBC URL.
* @return the corresponding [DbType].
* @throws RuntimeException if the url is null.
* @throws [RuntimeException] if the url is null.
*/
public fun extractDBTypeFromUrl(url: String?): DbType {
if (url != null) {
val helperH2Instance = H2()
return when {
H2.dbTypeInJdbcUrl in url -> H2
helperH2Instance.dbTypeInJdbcUrl in url -> createH2Instance(url)
MariaDb.dbTypeInJdbcUrl in url -> MariaDb
MySql.dbTypeInJdbcUrl in url -> MySql
Sqlite.dbTypeInJdbcUrl in url -> Sqlite
Expand All @@ -28,6 +85,37 @@ public fun extractDBTypeFromUrl(url: String?): DbType {
}
}

/**
* Creates an instance of DbType based on the provided JDBC URL.
*
* @param [url] The JDBC URL representing the database connection.
* @return The corresponding [DbType] instance.
* @throws [IllegalArgumentException] if the provided URL does not contain a valid mode.
*/
private fun createH2Instance(url: String): DbType {
val modePattern = "MODE=(.*?);".toRegex()
val matchResult = modePattern.find(url)

val mode: String = if (matchResult != null && matchResult.groupValues.size == 2) {
matchResult.groupValues[1]
} else {
throw IllegalArgumentException("The provided URL `$url` does not contain a valid mode.")
}

// H2 doesn't support MariaDB and SQLite
return when (mode.lowercase(Locale.getDefault())) {
H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)

else -> throw IllegalArgumentException(
"Unsupported database mode: $mode. " +
"Only MySQL, MariaDB, MSSQL, PostgreSQL modes are supported!"
)
}
}

/**
* Retrieves the driver class name from the given JDBC URL.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.Infer
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
import org.jetbrains.kotlinx.dataframe.io.db.DbType
import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl
import org.jetbrains.kotlinx.dataframe.io.db.*
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
import java.math.BigDecimal
Expand Down Expand Up @@ -138,7 +137,7 @@ public fun DataFrame.Companion.readSqlTable(
inferNullability: Boolean = true,
): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit)
else "SELECT * FROM $tableName"
Expand Down Expand Up @@ -203,8 +202,7 @@ public fun DataFrame.Companion.readSqlQuery(
"Also it should not contain any separators like `;`."
}

val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery

Expand Down Expand Up @@ -283,8 +281,7 @@ public fun DataFrame.Companion.readResultSet(
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

return readResultSet(resultSet, dbType, limit, inferNullability)
}
Expand Down Expand Up @@ -329,8 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
inferNullability: Boolean = true,
): Map<String, AnyFrame> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

// exclude a system and other tables without data, but it looks like it is supported badly for many databases
val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE"))
Expand Down Expand Up @@ -390,8 +386,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(
connection: Connection,
tableName: String
): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val sqlQuery = "SELECT * FROM $tableName"
val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1)
Expand Down Expand Up @@ -432,8 +427,7 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

connection.createStatement().use { st ->
st.executeQuery(sqlQuery).use { rs ->
Expand Down Expand Up @@ -468,8 +462,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
* @return the schema of the [ResultSet] as a [DataFrameSchema] object.
*/
public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val tableColumns = getTableColumnsMetadata(resultSet)
return buildSchemaByTableColumns(tableColumns, dbType)
Expand All @@ -495,8 +488,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig
*/
public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map<String, DataFrameSchema> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val tableTypes = arrayOf("TABLE")
// exclude a system and other tables without data
Expand Down
Loading