diff --git a/build.gradle.kts b/build.gradle.kts
index 7406202de8..89f64525fc 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -37,6 +37,7 @@ dependencies {
     api(project(":dataframe-arrow"))
     api(project(":dataframe-excel"))
     api(project(":dataframe-openapi"))
+    api(project(":dataframe-jdbc"))
 }

 allprojects {
diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt
index 462fd7b734..43a57c8bce 100644
--- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt
+++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt
@@ -8,21 +8,24 @@ import org.jetbrains.kotlinx.dataframe.io.JSON

 /**
  * Annotation preprocessing will generate a DataSchema interface from the data at `path`.
- * Data must be of supported format: CSV, JSON, Apache Arrow, Excel, OpenAPI (Swagger) in YAML/JSON.
+ * Data must be in a supported format: CSV, JSON, Apache Arrow, Excel, OpenAPI (Swagger) in YAML/JSON, or JDBC.
  * Generated data schema has properties inferred from data and a companion object with `read method`.
  * `read method` is either `readCSV` or `readJson` that returns `DataFrame`
  *
  * @param name name of the generated interface
  * @param path URL or relative path to data.
- * if path starts with protocol (http, https, ftp), it's considered a URL. Otherwise, it's treated as relative path.
+ * If a path starts with a protocol (http, https, ftp, jdbc), it's considered a URL.
+ * Otherwise, it's treated as a relative path.
  * By default, it will be resolved relatively to project dir, i.e. File(projectDir, path)
- * You can configure it by passing `dataframe.resolutionDir` option to preprocessor, see https://kotlinlang.org/docs/ksp-quickstart.html#pass-options-to-processors
+ * You can configure it by passing the `dataframe.resolutionDir` option to the preprocessor,
+ * see https://kotlinlang.org/docs/ksp-quickstart.html#pass-options-to-processors
 * @param visibility visibility of the generated interface.
 * @param normalizationDelimiters if not empty, split property names by delimiters,
 * lowercase parts and join to camel case. Set empty list to disable normalization
 * @param withDefaultPath if `true`, generate `defaultPath` property to the data schema's companion object and make it default argument for a `read method`
 * @param csvOptions options to parse CSV data. Not used when data is not Csv
 * @param jsonOptions options to parse JSON data. Not used when data is not Json
+ * @param jdbcOptions options to read data from a database via JDBC. Not used when the data is not stored in a database
 */
@Retention(AnnotationRetention.SOURCE)
@Target(AnnotationTarget.FILE)
@@ -35,6 +38,7 @@ public annotation class ImportDataSchema(
     val withDefaultPath: Boolean = true,
     val csvOptions: CsvOptions = CsvOptions(','),
     val jsonOptions: JsonOptions = JsonOptions(),
+    val jdbcOptions: JdbcOptions = JdbcOptions(),
 )

 public enum class DataSchemaVisibility {
@@ -45,6 +49,12 @@ public annotation class CsvOptions(
     public val delimiter: Char,
 )

+public annotation class JdbcOptions(
+    public val user: String = "", // TODO: I'm not sure about the default parameters
+    public val password: String = "", // TODO: I'm not sure about the default parameters
+    public val sqlQuery: String = ""
+)
+
 public annotation class JsonOptions(

     /** Allows the choice of how to handle type clashes when reading a JSON file.
diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt
index 5f93bc31a9..e7e97675cb 100644
--- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt
+++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt
@@ -24,6 +24,7 @@ private const val verify = "verify" // cast(true) is obscure, i think it's bette
 private const val readCSV = "readCSV"
 private const val readTSV = "readTSV"
 private const val readJson = "readJson"
+private const val readJdbc = "readJdbc"

 public abstract class AbstractDefaultReadMethod(
     private val path: String?,
diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt
index 711efda09a..791853f9f6 100644
--- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt
+++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt
@@ -5,7 +5,7 @@ import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
 import org.jetbrains.kotlinx.dataframe.schema.CompareResult
 import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema

-internal class DataFrameSchemaImpl(override val columns: Map<String, ColumnSchema>) : DataFrameSchema {
+public class DataFrameSchemaImpl(override val columns: Map<String, ColumnSchema>) : DataFrameSchema {

     override fun compare(other: DataFrameSchema): CompareResult {
         require(other is DataFrameSchemaImpl)
diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt
index f74b139d39..cc215f92d4 100644
--- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt
+++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt
@@ -169,6 +169,7 @@ internal class Integration(
         if (version != null) {
             dependencies(
                 "org.jetbrains.kotlinx:dataframe-excel:$version",
+                "org.jetbrains.kotlinx:dataframe-jdbc:$version",
                 "org.jetbrains.kotlinx:dataframe-arrow:$version",
                 "org.jetbrains.kotlinx:dataframe-openapi:$version",
             )
@@ -176,7 +177,7 @@ internal class Integration(

         try {
             setMinimalKernelVersion(MIN_KERNEL_VERSION)
-        } catch (_: NoSuchMethodError) { // will be thrown on version < 0.11.0.198
+        } catch (_: NoSuchMethodError) { // will be thrown on versions < 0.11.0.198
             throw IllegalStateException(
                 getKernelUpdateMessage(notebook.kernelVersion, MIN_KERNEL_VERSION, notebook.jupyterClientType)
             )
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt
index 462fd7b734..43a57c8bce 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/annotations/ImportDataSchema.kt
@@ -8,21 +8,24 @@ import org.jetbrains.kotlinx.dataframe.io.JSON

 /**
  * Annotation preprocessing will generate a DataSchema interface from the data at `path`.
- * Data must be of supported format: CSV, JSON, Apache Arrow, Excel, OpenAPI (Swagger) in YAML/JSON.
+ * Data must be in a supported format: CSV, JSON, Apache Arrow, Excel, OpenAPI (Swagger) in YAML/JSON, or JDBC.
  * Generated data schema has properties inferred from data and a companion object with `read method`.
  * `read method` is either `readCSV` or `readJson` that returns `DataFrame`
  *
  * @param name name of the generated interface
  * @param path URL or relative path to data.
- * if path starts with protocol (http, https, ftp), it's considered a URL. Otherwise, it's treated as relative path.
+ * If a path starts with a protocol (http, https, ftp, jdbc), it's considered a URL.
+ * Otherwise, it's treated as a relative path.
  * By default, it will be resolved relatively to project dir, i.e. File(projectDir, path)
- * You can configure it by passing `dataframe.resolutionDir` option to preprocessor, see https://kotlinlang.org/docs/ksp-quickstart.html#pass-options-to-processors
+ * You can configure it by passing the `dataframe.resolutionDir` option to the preprocessor,
+ * see https://kotlinlang.org/docs/ksp-quickstart.html#pass-options-to-processors
 * @param visibility visibility of the generated interface.
 * @param normalizationDelimiters if not empty, split property names by delimiters,
 * lowercase parts and join to camel case. Set empty list to disable normalization
 * @param withDefaultPath if `true`, generate `defaultPath` property to the data schema's companion object and make it default argument for a `read method`
 * @param csvOptions options to parse CSV data. Not used when data is not Csv
 * @param jsonOptions options to parse JSON data. Not used when data is not Json
+ * @param jdbcOptions options to read data from a database via JDBC. Not used when the data is not stored in a database
 */
@Retention(AnnotationRetention.SOURCE)
@Target(AnnotationTarget.FILE)
@@ -35,6 +38,7 @@ public annotation class ImportDataSchema(
     val withDefaultPath: Boolean = true,
     val csvOptions: CsvOptions = CsvOptions(','),
     val jsonOptions: JsonOptions = JsonOptions(),
+    val jdbcOptions: JdbcOptions = JdbcOptions(),
 )

 public enum class DataSchemaVisibility {
@@ -45,6 +49,12 @@ public annotation class CsvOptions(
     public val delimiter: Char,
 )

+public annotation class JdbcOptions(
+    public val user: String = "", // TODO: I'm not sure about the default parameters
+    public val password: String = "", // TODO: I'm not sure about the default parameters
+    public val sqlQuery: String = ""
+)
+
 public annotation class JsonOptions(

     /** Allows the choice of how to handle type clashes when reading a JSON file.
      */
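For context, a minimal usage sketch of the new `jdbcOptions` parameter; the URL, credentials, schema name, and query below are illustrative placeholders, not part of this change:

```kotlin
// Hypothetical file-level annotation: asks the preprocessor to generate a
// `Customer` data schema from the result of the given SQL query.
@file:ImportDataSchema(
    name = "Customer",
    path = "jdbc:mariadb://localhost:3306/testDatabase", // jdbc paths are treated as URLs
    jdbcOptions = JdbcOptions(
        user = "root",      // placeholder credentials
        password = "pass",
        sqlQuery = "SELECT * FROM customers"
    )
)

package example

import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions
```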
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt
index 5f93bc31a9..e7e97675cb 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/codeGen/DefaultReadDfMethods.kt
@@ -24,6 +24,7 @@ private const val verify = "verify" // cast(true) is obscure, i think it's bette
 private const val readCSV = "readCSV"
 private const val readTSV = "readTSV"
 private const val readJson = "readJson"
+private const val readJdbc = "readJdbc"

 public abstract class AbstractDefaultReadMethod(
     private val path: String?,
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt
index 711efda09a..791853f9f6 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/schema/DataFrameSchemaImpl.kt
@@ -5,7 +5,7 @@ import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
 import org.jetbrains.kotlinx.dataframe.schema.CompareResult
 import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema

-internal class DataFrameSchemaImpl(override val columns: Map<String, ColumnSchema>) : DataFrameSchema {
+public class DataFrameSchemaImpl(override val columns: Map<String, ColumnSchema>) : DataFrameSchema {

     override fun compare(other: DataFrameSchema): CompareResult {
         require(other is DataFrameSchemaImpl)
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt
index f74b139d39..cc215f92d4 100644
--- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt
+++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/Integration.kt
@@ -169,6 +169,7 @@ internal class Integration(
         if (version != null) {
             dependencies(
                 "org.jetbrains.kotlinx:dataframe-excel:$version",
+                "org.jetbrains.kotlinx:dataframe-jdbc:$version",
                 "org.jetbrains.kotlinx:dataframe-arrow:$version",
                 "org.jetbrains.kotlinx:dataframe-openapi:$version",
             )
@@ -176,7 +177,7 @@ internal class Integration(

         try {
             setMinimalKernelVersion(MIN_KERNEL_VERSION)
-        } catch (_: NoSuchMethodError) { // will be thrown on version < 0.11.0.198
+        } catch (_: NoSuchMethodError) { // will be thrown on versions < 0.11.0.198
             throw IllegalStateException(
                 getKernelUpdateMessage(notebook.kernelVersion, MIN_KERNEL_VERSION, notebook.jupyterClientType)
             )
diff --git a/dataframe-excel/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/xlsx.kt b/dataframe-excel/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/xlsx.kt
index c21455bb34..9b4d3586a4 100644
--- a/dataframe-excel/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/xlsx.kt
+++ b/dataframe-excel/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/xlsx.kt
@@ -57,7 +57,7 @@ internal class DefaultReadExcelMethod(path: String?) : AbstractDefaultReadMethod
 private const val readExcel = "readExcel"

 /**
- * @param sheetName sheet to read. By default, first sheet in the document
+ * @param sheetName sheet to read. By default, the first sheet in the document
  * @param columns comma separated list of Excel column letters and column ranges (e.g. “A:E” or “A,C,E:F”)
  * @param skipRows number of rows before header
  * @param rowsCount number of rows to read.
@@ -77,7 +77,7 @@ public fun DataFrame.Companion.readExcel(
 }

 /**
- * @param sheetName sheet to read. By default, first sheet in the document
+ * @param sheetName sheet to read. By default, the first sheet in the document
  * @param columns comma separated list of Excel column letters and column ranges (e.g. “A:E” or “A,C,E:F”)
  * @param skipRows number of rows before header
  * @param rowsCount number of rows to read.
@@ -97,7 +97,7 @@ public fun DataFrame.Companion.readExcel(
 }

 /**
- * @param sheetName sheet to read. By default, first sheet in the document
+ * @param sheetName sheet to read. By default, the first sheet in the document
  * @param columns comma separated list of Excel column letters and column ranges (e.g. “A:E” or “A,C,E:F”)
  * @param skipRows number of rows before header
  * @param rowsCount number of rows to read.
@@ -114,7 +114,7 @@ public fun DataFrame.Companion.readExcel(
 ): AnyFrame = readExcel(asURL(fileOrUrl), sheetName, skipRows, columns, rowsCount, nameRepairStrategy)

 /**
- * @param sheetName sheet to read. By default, first sheet in the document
+ * @param sheetName sheet to read. By default, the first sheet in the document
  * @param columns comma separated list of Excel column letters and column ranges (e.g. “A:E” or “A,C,E:F”)
  * @param skipRows number of rows before header
  * @param rowsCount number of rows to read.
@@ -134,7 +134,7 @@ public fun DataFrame.Companion.readExcel(
 }

 /**
- * @param sheetName sheet to read. By default, first sheet in the document
+ * @param sheetName sheet to read. By default, the first sheet in the document
  * @param columns comma separated list of Excel column letters and column ranges (e.g. “A:E” or “A,C,E:F”)
  * @param skipRows number of rows before header
  * @param rowsCount number of rows to read.
@@ -446,8 +446,8 @@ private fun Cell.setCellValueByGuessedType(any: Any) {
 /**
  * Set LocalDateTime value correctly also if date have zero value in Excel.
- * Zero date is usually used fore storing time component only,
- * is displayed as 00.01.1900 in Excel and as 30.12.1899 in LibreOffice Calc and also in POI.
+ * Zero dates are usually used for storing a time component only,
+ * and are displayed as 00.01.1900 in Excel and as 30.12.1899 in LibreOffice Calc and also in POI.
  * POI can not set 1899 year directly.
  */
 private fun Cell.setTime(localDateTime: LocalDateTime) {
@@ -455,9 +455,9 @@ private fun Cell.setTime(localDateTime: LocalDateTime) {
 }

 /**
- * Set Date value correctly also if date have zero value in Excel.
- * Zero date is usually used fore storing time component only,
- * is displayed as 00.01.1900 in Excel and as 30.12.1899 in LibreOffice Calc and also in POI.
+ * Set Date value correctly also if the date has a zero value in Excel.
+ * Zero dates are usually used for storing a time component only,
+ * and are displayed as 00.01.1900 in Excel and as 30.12.1899 in LibreOffice Calc and also in POI.
  * POI can not set 1899 year directly.
  */
 private fun Cell.setDate(date: Date) {
diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts
new file mode 100644
index 0000000000..50652715e8
--- /dev/null
+++ b/dataframe-jdbc/build.gradle.kts
@@ -0,0 +1,43 @@
+plugins {
+    kotlin("jvm")
+    kotlin("libs.publisher")
+    id("org.jetbrains.kotlinx.kover")
+    kotlin("jupyter.api")
+}
+
+group = "org.jetbrains.kotlinx"
+
+val jupyterApiTCRepo: String by project
+
+repositories {
+    mavenCentral()
+    maven(jupyterApiTCRepo)
+}
+
+dependencies {
+    api(project(":core"))
+    implementation(libs.mariadb)
+    implementation(libs.kotlinLogging)
+    testImplementation(libs.sqlite)
+    testImplementation(libs.postgresql)
+    testImplementation(libs.mysql)
+    testImplementation(libs.h2db)
+    testImplementation(libs.junit)
+    testImplementation(libs.sl4j)
+    testImplementation(libs.kotestAssertions) {
+        exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")
+    }
+}
+
+kotlinPublications {
+    publication {
+        publicationName.set("dataframeJDBC")
+        artifactId.set(project.name)
+        description.set("JDBC support for Kotlin Dataframe")
+        packageName.set(artifactId)
+    }
+}
+
+kotlin {
+    explicitApi()
+}
diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/Jdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/Jdbc.kt
new file mode 100644
index 0000000000..1e5a3fa911
--- /dev/null
+++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/Jdbc.kt
@@ -0,0 +1,54 @@
+package org.jetbrains.kotlinx.dataframe.io
+
+import org.jetbrains.kotlinx.dataframe.AnyFrame
+import org.jetbrains.kotlinx.dataframe.DataFrame
+import org.jetbrains.kotlinx.dataframe.codeGen.AbstractDefaultReadMethod
+import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadDfMethod
+import org.jetbrains.kotlinx.dataframe.codeGen.MethodArguments
+import org.jetbrains.kotlinx.jupyter.api.Code
+import java.io.File
+import java.io.InputStream
+
+// TODO: https://github.com/Kotlin/dataframe/issues/450
+public class Jdbc : SupportedCodeGenerationFormat, SupportedDataFrameFormat {
+    public override fun readDataFrame(stream: InputStream, header: List<String>): AnyFrame = DataFrame.readJDBC(stream)
+
+    public override fun readDataFrame(file: File, header: List<String>): AnyFrame = DataFrame.readJDBC(file)
+
+    override fun readCodeForGeneration(
+        stream: InputStream,
+        name: String,
+        generateHelperCompanionObject: Boolean
+    ): Code {
+        TODO("Not yet implemented")
+    }
+
+    override fun readCodeForGeneration(
+        file: File,
+        name: String,
+        generateHelperCompanionObject: Boolean
+    ): Code {
+        TODO("Not yet implemented")
+    }
+
+    override fun acceptsExtension(ext: String): Boolean = ext == "jdbc"
+
+    override fun acceptsSample(sample: SupportedFormatSample): Boolean = true // Extension is enough
+
+    override val testOrder: Int = 40000
+
+    override fun createDefaultReadMethod(pathRepresentation: String?): DefaultReadDfMethod {
+        return DefaultReadJdbcMethod(pathRepresentation)
+    }
+}
+
+private fun DataFrame.Companion.readJDBC(file: File): DataFrame<*> {
+    TODO("Not yet implemented")
+}
+
+private fun DataFrame.Companion.readJDBC(stream: InputStream): DataFrame<*> {
+    TODO("Not yet implemented")
+}
+
+internal class DefaultReadJdbcMethod(path: String?) :
+    AbstractDefaultReadMethod(path, MethodArguments.EMPTY, readJDBC)
+
+private const val readJDBC = "readJDBC"
diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt
new file mode 100644
index 0000000000..653c8ff769
--- /dev/null
+++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt
@@ -0,0 +1,43 @@
+package org.jetbrains.kotlinx.dataframe.io.db
+
+import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
+import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
+import java.sql.ResultSet
+import org.jetbrains.kotlinx.dataframe.io.TableMetadata
+
+/**
+ * The `DbType` class represents a database type used for reading a dataframe from a database.
+ *
+ * @property [dbTypeInJdbcUrl] The name of the database as specified in the JDBC URL.
+ */
+public abstract class DbType(public val dbTypeInJdbcUrl: String) {
+    /**
+     * Converts the data in the given [ResultSet] to the type described by [tableColumnMetadata].
+     *
+     * @param rs The [ResultSet] containing the data to be converted.
+     * @param tableColumnMetadata The [TableColumnMetadata] representing the target type of the conversion.
+     * @return The converted data as an instance of [Any].
+     */
+    public abstract fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any?
+
+    /**
+     * Returns a [ColumnSchema] produced from [tableColumnMetadata].
+     */
+    public abstract fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema
+
+    /**
+     * Checks if the given table is a system table for this database type.
+     *
+     * @param [tableMetadata] the table object representing the table from the database.
+     * @return True if the table is a system table for this database type, false otherwise.
+     */
+    public abstract fun isSystemTable(tableMetadata: TableMetadata): Boolean
+
+    /**
+     * Builds the table metadata based on the database type and the ResultSet from the query.
+     *
+     * @param [tables] the ResultSet containing the table's meta-information.
+     * @return the TableMetadata object representing the table metadata.
+     */
+    public abstract fun buildTableMetadata(tables: ResultSet): TableMetadata
+}
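Since `DbType` is a public abstract class, third-party dialects can plug in by subclassing it. Below is a minimal sketch of a hypothetical dialect (name, type mappings, and system-table rule are all illustrative):

```kotlin
package example

import java.sql.ResultSet
import kotlin.reflect.typeOf
import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import org.jetbrains.kotlinx.dataframe.io.db.DbType
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema

// Hypothetical "MyDb" dialect that only distinguishes TEXT and INT columns.
public object MyDb : DbType("mydb") {
    override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? =
        when (tableColumnMetadata.sqlTypeName) {
            "TEXT" -> rs.getString(tableColumnMetadata.name)
            "INT" -> rs.getInt(tableColumnMetadata.name)
            else -> rs.getObject(tableColumnMetadata.name)
        }

    override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema =
        when (tableColumnMetadata.sqlTypeName) {
            "TEXT" -> ColumnSchema.Value(typeOf<String>())
            "INT" -> ColumnSchema.Value(typeOf<Int>())
            else -> ColumnSchema.Value(typeOf<Any>())
        }

    // Treat tables prefixed with "sys_" as system tables so bulk reads skip them.
    override fun isSystemTable(tableMetadata: TableMetadata): Boolean =
        tableMetadata.name.startsWith("sys_")

    override fun buildTableMetadata(tables: ResultSet): TableMetadata =
        TableMetadata(
            tables.getString("TABLE_NAME"),
            tables.getString("TABLE_SCHEM"),
            tables.getString("TABLE_CAT")
        )
}
```

Note that `extractDBTypeFromURL` (added later in this PR) is hardwired to the five bundled dialects, so a custom dialect like this currently has to be passed explicitly, for example to `readResultSet`.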
diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt
new file mode 100644
index 0000000000..575c520dda
--- /dev/null
+++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt
@@ -0,0 +1,99 @@
+package org.jetbrains.kotlinx.dataframe.io.db
+
+import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
+import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
+import java.sql.ResultSet
+import java.util.Locale
+import org.jetbrains.kotlinx.dataframe.DataRow
+import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
+import org.jetbrains.kotlinx.dataframe.io.TableMetadata
+import kotlin.reflect.typeOf
+
+/**
+ * Represents the H2 database type.
+ *
+ * This class provides methods to convert data from a ResultSet to the appropriate type for H2,
+ * and to generate the corresponding column schema.
+ *
+ * NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types.
+ */
+public object H2 : DbType("h2") {
+    override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
+        val name = tableColumnMetadata.name
+        return when (tableColumnMetadata.sqlTypeName) {
+            "CHARACTER", "CHAR" -> rs.getString(name)
+            "CHARACTER VARYING", "CHAR VARYING", "VARCHAR" -> rs.getString(name)
+            "CHARACTER LARGE OBJECT", "CHAR LARGE OBJECT", "CLOB" -> rs.getString(name)
+            "MEDIUMTEXT" -> rs.getString(name)
+            "VARCHAR_IGNORECASE" -> rs.getString(name)
+            "BINARY" -> rs.getBytes(name)
+            "BINARY VARYING", "VARBINARY" -> rs.getBytes(name)
+            "BINARY LARGE OBJECT", "BLOB" -> rs.getBytes(name)
+            "BOOLEAN" -> rs.getBoolean(name)
+            "TINYINT" -> rs.getByte(name)
+            "SMALLINT" -> rs.getShort(name)
+            "INTEGER", "INT" -> rs.getInt(name)
+            "BIGINT" -> rs.getLong(name)
+            "NUMERIC", "DECIMAL", "DEC" -> rs.getFloat(name) // not a BigDecimal
+            "REAL", "FLOAT" -> rs.getFloat(name)
+            "DOUBLE PRECISION" -> rs.getDouble(name)
+            "DECFLOAT" -> rs.getDouble(name)
+            "DATE" -> rs.getDate(name).toString()
+            "TIME" -> rs.getTime(name).toString()
+            "TIME WITH TIME ZONE" -> rs.getTime(name).toString()
+            "TIMESTAMP" -> rs.getTimestamp(name).toString()
+            "TIMESTAMP WITH TIME ZONE" -> rs.getTimestamp(name).toString()
+            "INTERVAL" -> rs.getObject(name).toString()
+            "JAVA_OBJECT" -> rs.getObject(name)
+            "ENUM" -> rs.getString(name)
+            "JSON" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462
+            "UUID" -> rs.getString(name)
+            else -> throw IllegalArgumentException("Unsupported H2 type: ${tableColumnMetadata.sqlTypeName}")
+        }
+    }
+
+    override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema {
+        return when (tableColumnMetadata.sqlTypeName) {
+            "CHARACTER", "CHAR" -> ColumnSchema.Value(typeOf<String>())
+            "CHARACTER VARYING", "CHAR VARYING", "VARCHAR" -> ColumnSchema.Value(typeOf<String>())
+            "CHARACTER LARGE OBJECT", "CHAR LARGE OBJECT", "CLOB" -> ColumnSchema.Value(typeOf<String>())
+            "MEDIUMTEXT" -> ColumnSchema.Value(typeOf<String>())
+            "VARCHAR_IGNORECASE" -> ColumnSchema.Value(typeOf<String>())
+            "BINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
+            "BINARY VARYING", "VARBINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
+            "BINARY LARGE OBJECT", "BLOB" -> ColumnSchema.Value(typeOf<ByteArray>())
+            "BOOLEAN" -> ColumnSchema.Value(typeOf<Boolean>())
+            "TINYINT" -> ColumnSchema.Value(typeOf<Byte>())
+            "SMALLINT" -> ColumnSchema.Value(typeOf<Short>())
+            "INTEGER", "INT" -> ColumnSchema.Value(typeOf<Int>())
+            "BIGINT" -> ColumnSchema.Value(typeOf<Long>())
+            "NUMERIC", "DECIMAL", "DEC" -> ColumnSchema.Value(typeOf<Float>())
+            "REAL", "FLOAT" -> ColumnSchema.Value(typeOf<Float>())
+            "DOUBLE PRECISION" -> ColumnSchema.Value(typeOf<Double>())
+            "DECFLOAT" -> ColumnSchema.Value(typeOf<Double>())
+            "DATE" -> ColumnSchema.Value(typeOf<String>())
+            "TIME" -> ColumnSchema.Value(typeOf<String>())
+            "TIME WITH TIME ZONE" -> ColumnSchema.Value(typeOf<String>())
+            "TIMESTAMP" -> ColumnSchema.Value(typeOf<String>())
+            "TIMESTAMP WITH TIME ZONE" -> ColumnSchema.Value(typeOf<String>())
+            "INTERVAL" -> ColumnSchema.Value(typeOf<String>())
+            "JAVA_OBJECT" -> ColumnSchema.Value(typeOf<Any>())
+            "ENUM" -> ColumnSchema.Value(typeOf<String>())
+            "JSON" -> ColumnSchema.Value(typeOf<String>()) // TODO: https://github.com/Kotlin/dataframe/issues/462
+            "UUID" -> ColumnSchema.Value(typeOf<String>())
+            else -> throw IllegalArgumentException("Unsupported H2 type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}")
+        }
+    }
+
+    override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
+        return tableMetadata.name.lowercase(Locale.getDefault()).contains("sys_")
+            || tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("information_schema") ?: false
+    }
+
+    override fun buildTableMetadata(tables: ResultSet): TableMetadata {
+        return TableMetadata(
+            tables.getString("TABLE_NAME"),
+            tables.getString("TABLE_SCHEM"),
+            tables.getString("TABLE_CAT"))
+    }
+}
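A quick way to exercise the H2 mapping end to end, using the `readResultSet` API added later in this PR (the table and data are illustrative; H2 is available on the test classpath via `libs.h2db`):

```kotlin
import java.sql.DriverManager
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.db.H2
import org.jetbrains.kotlinx.dataframe.io.readResultSet

fun main() {
    DriverManager.getConnection("jdbc:h2:mem:test").use { connection ->
        connection.createStatement().use { st ->
            st.execute("CREATE TABLE person (id INT, name VARCHAR(50))")
            st.execute("INSERT INTO person VALUES (1, 'Alice'), (2, 'Bob')")
            st.executeQuery("SELECT * FROM person").use { rs ->
                // The dialect is passed explicitly; a bare ResultSet carries no URL to infer it from.
                val df = DataFrame.readResultSet(rs, H2)
                println(df)
            }
        }
    }
}
```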
tables.getString("TABLE_NAME"), + tables.getString("TABLE_SCHEM"), + tables.getString("TABLE_CAT")) + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt new file mode 100644 index 0000000000..ace43ea999 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt @@ -0,0 +1,95 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema +import java.sql.ResultSet +import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import kotlin.reflect.typeOf + +/** + * Represents the MariaDb database type. + * + * This class provides methods to convert data from a ResultSet to the appropriate type for MariaDb, + * and to generate the corresponding column schema. + */ +public object MariaDb : DbType("mariadb") { + override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { + val name = tableColumnMetadata.name + return when (tableColumnMetadata.sqlTypeName) { + "BIT" -> rs.getBytes(name) + "TINYINT" -> rs.getInt(name) + "SMALLINT" -> rs.getInt(name) + "MEDIUMINT"-> rs.getInt(name) + "MEDIUMINT UNSIGNED" -> rs.getLong(name) + "INTEGER", "INT" -> rs.getInt(name) + "INTEGER UNSIGNED", "INT UNSIGNED" -> rs.getLong(name) + "BIGINT" -> rs.getLong(name) + "FLOAT" -> rs.getFloat(name) + "DOUBLE" -> rs.getDouble(name) + "DECIMAL" -> rs.getBigDecimal(name) + "DATE" -> rs.getDate(name).toString() + "DATETIME" -> rs.getTimestamp(name).toString() + "TIMESTAMP" -> rs.getTimestamp(name).toString() + "TIME"-> rs.getTime(name).toString() + "YEAR" -> rs.getDate(name).toString() + "VARCHAR", "CHAR" -> rs.getString(name) + "BINARY" -> rs.getBytes(name) + "VARBINARY" -> rs.getBytes(name) + "TINYBLOB"-> rs.getBytes(name) + "BLOB"-> rs.getBytes(name) + "MEDIUMBLOB" -> rs.getBytes(name) + "LONGBLOB" -> rs.getBytes(name) + "TEXT" -> rs.getString(name) + "MEDIUMTEXT" -> rs.getString(name) + "LONGTEXT" -> rs.getString(name) + "ENUM" -> rs.getString(name) + "SET" -> rs.getString(name) + else -> throw IllegalArgumentException("Unsupported MariaDB type: ${tableColumnMetadata.sqlTypeName}") + } + } + + override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { + return when (tableColumnMetadata.sqlTypeName) { + "BIT" -> ColumnSchema.Value(typeOf()) + "TINYINT" -> ColumnSchema.Value(typeOf()) + "SMALLINT" -> ColumnSchema.Value(typeOf()) + "MEDIUMINT"-> ColumnSchema.Value(typeOf()) + "MEDIUMINT UNSIGNED" -> ColumnSchema.Value(typeOf()) + "INTEGER", "INT" -> ColumnSchema.Value(typeOf()) + "INTEGER UNSIGNED", "INT UNSIGNED" -> ColumnSchema.Value(typeOf()) + "BIGINT" -> ColumnSchema.Value(typeOf()) + "FLOAT" -> ColumnSchema.Value(typeOf()) + "DOUBLE" -> ColumnSchema.Value(typeOf()) + "DECIMAL" -> ColumnSchema.Value(typeOf()) + "DATE" -> ColumnSchema.Value(typeOf()) + "DATETIME" -> ColumnSchema.Value(typeOf()) + "TIMESTAMP" -> ColumnSchema.Value(typeOf()) + "TIME"-> ColumnSchema.Value(typeOf()) + "YEAR" -> ColumnSchema.Value(typeOf()) + "VARCHAR", "CHAR" -> ColumnSchema.Value(typeOf()) + "BINARY" -> ColumnSchema.Value(typeOf()) + "VARBINARY" -> ColumnSchema.Value(typeOf()) + "TINYBLOB"-> ColumnSchema.Value(typeOf()) + "BLOB"-> ColumnSchema.Value(typeOf()) + "MEDIUMBLOB" -> ColumnSchema.Value(typeOf()) + "LONGBLOB" -> ColumnSchema.Value(typeOf()) + "TEXT" -> ColumnSchema.Value(typeOf()) + 
"MEDIUMTEXT" -> ColumnSchema.Value(typeOf()) + "LONGTEXT" -> ColumnSchema.Value(typeOf()) + "ENUM" -> ColumnSchema.Value(typeOf()) + "SET" -> ColumnSchema.Value(typeOf()) + else -> throw IllegalArgumentException("Unsupported MariaDB type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") + } + } + + override fun isSystemTable(tableMetadata: TableMetadata): Boolean { + return MySql.isSystemTable(tableMetadata) + } + + override fun buildTableMetadata(tables: ResultSet): TableMetadata { + return TableMetadata( + tables.getString("table_name"), + tables.getString("table_schem"), + tables.getString("table_cat")) + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt new file mode 100644 index 0000000000..6fc899254a --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt @@ -0,0 +1,116 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema +import java.sql.ResultSet +import java.util.Locale +import org.jetbrains.kotlinx.dataframe.DataRow +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup +import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import kotlin.reflect.typeOf + +/** + * Represents the MySql database type. + * + * This class provides methods to convert data from a ResultSet to the appropriate type for MySql, + * and to generate the corresponding column schema. + */ +public object MySql : DbType("mysql") { + override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { + val name = tableColumnMetadata.name + return when (tableColumnMetadata.sqlTypeName) { + "BIT" -> rs.getBytes(name) + "TINYINT" -> rs.getInt(name) + "SMALLINT" -> rs.getInt(name) + "MEDIUMINT"-> rs.getInt(name) + "MEDIUMINT UNSIGNED" -> rs.getLong(name) + "INTEGER", "INT" -> rs.getInt(name) + "INTEGER UNSIGNED", "INT UNSIGNED" -> rs.getLong(name) + "BIGINT" -> rs.getLong(name) + "FLOAT" -> rs.getFloat(name) + "DOUBLE" -> rs.getDouble(name) + "DECIMAL" -> rs.getBigDecimal(name) + "DATE" -> rs.getDate(name).toString() + "DATETIME" -> rs.getTimestamp(name).toString() + "TIMESTAMP" -> rs.getTimestamp(name).toString() + "TIME"-> rs.getTime(name).toString() + "YEAR" -> rs.getDate(name).toString() + "VARCHAR", "CHAR" -> rs.getString(name) + "BINARY" -> rs.getBytes(name) + "VARBINARY" -> rs.getBytes(name) + "TINYBLOB"-> rs.getBytes(name) + "BLOB"-> rs.getBytes(name) + "MEDIUMBLOB" -> rs.getBytes(name) + "LONGBLOB" -> rs.getBytes(name) + "TEXT" -> rs.getString(name) + "MEDIUMTEXT" -> rs.getString(name) + "LONGTEXT" -> rs.getString(name) + "ENUM" -> rs.getString(name) + "SET" -> rs.getString(name) + // special mysql types + "JSON" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462 + "GEOMETRY" -> rs.getBytes(name) + else -> throw IllegalArgumentException("Unsupported MySQL type: ${tableColumnMetadata.sqlTypeName}") + } + } + + override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { + return when (tableColumnMetadata.sqlTypeName) { + "BIT" -> ColumnSchema.Value(typeOf()) + "TINYINT" -> ColumnSchema.Value(typeOf()) + "SMALLINT" -> ColumnSchema.Value(typeOf()) + "MEDIUMINT"-> ColumnSchema.Value(typeOf()) + "MEDIUMINT UNSIGNED" -> ColumnSchema.Value(typeOf()) + "INTEGER", "INT" -> ColumnSchema.Value(typeOf()) + 
"INTEGER UNSIGNED", "INT UNSIGNED" -> ColumnSchema.Value(typeOf()) + "BIGINT" -> ColumnSchema.Value(typeOf()) + "FLOAT" -> ColumnSchema.Value(typeOf()) + "DOUBLE" -> ColumnSchema.Value(typeOf()) + "DECIMAL" -> ColumnSchema.Value(typeOf()) + "DATE" -> ColumnSchema.Value(typeOf()) + "DATETIME" -> ColumnSchema.Value(typeOf()) + "TIMESTAMP" -> ColumnSchema.Value(typeOf()) + "TIME"-> ColumnSchema.Value(typeOf()) + "YEAR" -> ColumnSchema.Value(typeOf()) + "VARCHAR", "CHAR" -> ColumnSchema.Value(typeOf()) + "BINARY" -> ColumnSchema.Value(typeOf()) + "VARBINARY" -> ColumnSchema.Value(typeOf()) + "TINYBLOB"-> ColumnSchema.Value(typeOf()) + "BLOB"-> ColumnSchema.Value(typeOf()) + "MEDIUMBLOB" -> ColumnSchema.Value(typeOf()) + "LONGBLOB" -> ColumnSchema.Value(typeOf()) + "TEXT" -> ColumnSchema.Value(typeOf()) + "MEDIUMTEXT" -> ColumnSchema.Value(typeOf()) + "LONGTEXT" -> ColumnSchema.Value(typeOf()) + "ENUM" -> ColumnSchema.Value(typeOf()) + "SET" -> ColumnSchema.Value(typeOf()) + // special mysql types + "JSON" -> ColumnSchema.Value(typeOf>>()) // TODO: https://github.com/Kotlin/dataframe/issues/462 + "GEOMETRY" -> ColumnSchema.Value(typeOf()) + else -> throw IllegalArgumentException("Unsupported MySQL type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") + } + } + + override fun isSystemTable(tableMetadata: TableMetadata): Boolean { + val locale = Locale.getDefault() + + fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true + + val schemaName = tableMetadata.schemaName + val name = tableMetadata.name + + return schemaName.containsWithLowercase("information_schema") + || tableMetadata.catalogue.containsWithLowercase("performance_schema") + || tableMetadata.catalogue.containsWithLowercase("mysql") + || schemaName?.contains("mysql.") == true + || name.contains("mysql.") + || name.contains("sys_config") + } + + override fun buildTableMetadata(tables: ResultSet): TableMetadata { + return TableMetadata( + tables.getString("table_name"), + tables.getString("table_schem"), + tables.getString("table_cat")) + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt new file mode 100644 index 0000000000..3aa0808559 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt @@ -0,0 +1,103 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema +import java.sql.ResultSet +import java.util.Locale +import org.jetbrains.kotlinx.dataframe.DataRow +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup +import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import kotlin.reflect.typeOf + +/** + * Represents the PostgreSql database type. + * + * This class provides methods to convert data from a ResultSet to the appropriate type for PostgreSql, + * and to generate the corresponding column schema. + */ +public object PostgreSql : DbType("postgresql") { + override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? 
+
+    override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema {
+        return when (tableColumnMetadata.sqlTypeName) {
+            "serial" -> ColumnSchema.Value(typeOf<Int>())
+            "int8", "bigint", "bigserial" -> ColumnSchema.Value(typeOf<Long>())
+            "bool" -> ColumnSchema.Value(typeOf<Boolean>())
+            "box" -> ColumnSchema.Value(typeOf<String>())
+            "bytea" -> ColumnSchema.Value(typeOf<ByteArray>())
+            "character", "bpchar" -> ColumnSchema.Value(typeOf<String>())
+            "circle" -> ColumnSchema.Value(typeOf<String>())
+            "date" -> ColumnSchema.Value(typeOf<String>())
+            "float8", "double precision" -> ColumnSchema.Value(typeOf<Double>())
+            "int4", "integer" -> ColumnSchema.Value(typeOf<Int>())
+            "interval" -> ColumnSchema.Value(typeOf<String>())
+            "json", "jsonb" -> ColumnSchema.Value(typeOf<DataRow<Map<String, Any?>>>()) // TODO: https://github.com/Kotlin/dataframe/issues/462
+            "line" -> ColumnSchema.Value(typeOf<String>())
+            "lseg" -> ColumnSchema.Value(typeOf<String>())
+            "macaddr" -> ColumnSchema.Value(typeOf<String>())
+            "money" -> ColumnSchema.Value(typeOf<String>())
+            "numeric" -> ColumnSchema.Value(typeOf<String>())
+            "path" -> ColumnSchema.Value(typeOf<String>())
+            "point" -> ColumnSchema.Value(typeOf<String>())
+            "polygon" -> ColumnSchema.Value(typeOf<String>())
+            "float4", "real" -> ColumnSchema.Value(typeOf<Float>())
+            "int2", "smallint" -> ColumnSchema.Value(typeOf<Short>())
+            "smallserial" -> ColumnSchema.Value(typeOf<Int>())
+            "text" -> ColumnSchema.Value(typeOf<String>())
+            "time" -> ColumnSchema.Value(typeOf<String>())
+            "timetz", "time with time zone" -> ColumnSchema.Value(typeOf<String>())
+            "timestamp" -> ColumnSchema.Value(typeOf<String>())
+            "timestamptz", "timestamp with time zone" -> ColumnSchema.Value(typeOf<String>())
+            "uuid" -> ColumnSchema.Value(typeOf<String>())
+            "xml" -> ColumnSchema.Value(typeOf<String>())
+            else -> throw IllegalArgumentException("Unsupported PostgreSQL type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}")
+        }
+    }
+
+    override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
+        return tableMetadata.name.lowercase(Locale.getDefault()).contains("pg_")
+            || tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("pg_catalog.") ?: false
+    }
+
+    override fun buildTableMetadata(tables: ResultSet): TableMetadata {
+        return TableMetadata(
+            tables.getString("table_name"),
tables.getString("table_schem"), + tables.getString("table_cat")) + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt new file mode 100644 index 0000000000..28b47729bf --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt @@ -0,0 +1,49 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema +import java.sql.ResultSet +import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import kotlin.reflect.typeOf + +/** + * Represents the Sqlite database type. + * + * This class provides methods to convert data from a ResultSet to the appropriate type for Sqlite, + * and to generate the corresponding column schema. + */ +public object Sqlite : DbType("sqlite") { + override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { + val name = tableColumnMetadata.name + return when (tableColumnMetadata.sqlTypeName) { + "INTEGER", "INTEGER AUTO_INCREMENT" -> rs.getInt(name) + "TEXT" -> rs.getString(name) + "REAL" -> rs.getDouble(name) + "NUMERIC" -> rs.getDouble(name) + "BLOB" -> rs.getBytes(name) + else -> throw IllegalArgumentException("Unsupported SQLite type: ${tableColumnMetadata.sqlTypeName}") + } + } + + override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { + return when (tableColumnMetadata.sqlTypeName) { + "INTEGER", "INTEGER AUTO_INCREMENT" -> ColumnSchema.Value(typeOf()) + "TEXT" -> ColumnSchema.Value(typeOf()) + "REAL" -> ColumnSchema.Value(typeOf()) + "NUMERIC" -> ColumnSchema.Value(typeOf()) + "BLOB" -> ColumnSchema.Value(typeOf()) + else -> throw IllegalArgumentException("Unsupported SQLite type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") + } + } + + override fun isSystemTable(tableMetadata: TableMetadata): Boolean { + return tableMetadata.name.startsWith("sqlite_") + } + + override fun buildTableMetadata(tables: ResultSet): TableMetadata { + return TableMetadata( + tables.getString("TABLE_NAME"), + tables.getString("TABLE_SCHEM"), + tables.getString("TABLE_CAT")) + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt new file mode 100644 index 0000000000..794247a30f --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -0,0 +1,26 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import java.sql.SQLException + +/** + * Extracts the database type from the given JDBC URL. + * + * @param [url] the JDBC URL. + * @return the corresponding [DbType]. + * @throws RuntimeException if the url is null. + */ +public fun extractDBTypeFromURL(url: String?): DbType { + if (url != null) { + return when { + H2.dbTypeInJdbcUrl in url -> H2 + MariaDb.dbTypeInJdbcUrl in url -> MariaDb + MySql.dbTypeInJdbcUrl in url -> MySql + Sqlite.dbTypeInJdbcUrl in url -> Sqlite + PostgreSql.dbTypeInJdbcUrl in url -> PostgreSql + else -> throw IllegalArgumentException("Unsupported database type in the url: $url. " + + "Only H2, MariaDB, MySQL, SQLite and PostgreSQL are supported!") + } + } else { + throw SQLException("Database URL could not be null. 
The existing value is $url") + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/jdbcSchema.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/jdbcSchema.kt new file mode 100644 index 0000000000..c599dafe57 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/jdbcSchema.kt @@ -0,0 +1,31 @@ +package org.jetbrains.kotlinx.dataframe.io + +import org.jetbrains.dataframe.impl.codeGen.CodeGenerator +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.impl.codeGen.CodeGenerationReadResult +import org.jetbrains.kotlinx.jupyter.api.Code +import java.net.URL + +// TODO: helper functions created to support existing hierarchy https://github.com/Kotlin/dataframe/issues/450 +public val CodeGenerator.Companion.databaseCodeGenReader: ( + url: URL, + name: String +) -> CodeGenerationReadResult + get() = { url, name -> + try { + val code = buildCodeForDB(url, name) + throw RuntimeException() + CodeGenerationReadResult.Success(code, Jdbc()) + } catch (e: Throwable) { + CodeGenerationReadResult.Error(e) + } + } + +public fun buildCodeForDB(url: URL, name: String): Code { + val annotationName = DataSchema::class.simpleName + val visibility = "public " + val propertyVisibility = "public " + + val declarations = mutableListOf() + return declarations.joinToString() +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt new file mode 100644 index 0000000000..108271cea6 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -0,0 +1,597 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.github.oshai.kotlinlogging.KotlinLogging +import java.sql.Connection +import java.sql.DatabaseMetaData +import java.sql.DriverManager +import java.sql.ResultSet +import java.sql.ResultSetMetaData +import org.jetbrains.kotlinx.dataframe.AnyFrame +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.api.toDataFrame +import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl +import org.jetbrains.kotlinx.dataframe.io.db.DbType +import org.jetbrains.kotlinx.dataframe.io.db.H2 +import org.jetbrains.kotlinx.dataframe.io.db.Sqlite +import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromURL +import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema + +private val logger = KotlinLogging.logger {} + +/** + * The default limit value. + * + * This constant represents the default limit value to be used in cases where no specific limit + * is provided. + * + * @see Int.MIN_VALUE + */ +private const val DEFAULT_LIMIT = Int.MIN_VALUE + +/** + * Represents a column in a database table to keep all required meta-information. + * + * @property [name] the name of the column. + * @property [sqlTypeName] the SQL data type of the column. + * @property [jdbcType] the JDBC data type of the column produced from [java.sql.Types]. + * @property [size] the size of the column. + */ +public data class TableColumnMetadata(val name: String, val sqlTypeName: String, val jdbcType: Int, val size: Int) + +/** + * Represents a table metadata to store information about a database table, + * including its name, schema name, and catalogue name. + * + * NOTE: we need to extract both, [schemaName] and [catalogue] + * because the different databases have different implementations of metadata. 
diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt
new file mode 100644
index 0000000000..108271cea6
--- /dev/null
+++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt
@@ -0,0 +1,597 @@
+package org.jetbrains.kotlinx.dataframe.io
+
+import io.github.oshai.kotlinlogging.KotlinLogging
+import java.sql.Connection
+import java.sql.DatabaseMetaData
+import java.sql.DriverManager
+import java.sql.ResultSet
+import java.sql.ResultSetMetaData
+import org.jetbrains.kotlinx.dataframe.AnyFrame
+import org.jetbrains.kotlinx.dataframe.DataFrame
+import org.jetbrains.kotlinx.dataframe.api.toDataFrame
+import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
+import org.jetbrains.kotlinx.dataframe.io.db.DbType
+import org.jetbrains.kotlinx.dataframe.io.db.H2
+import org.jetbrains.kotlinx.dataframe.io.db.Sqlite
+import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromURL
+import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
+
+private val logger = KotlinLogging.logger {}
+
+/**
+ * The default limit value.
+ *
+ * This constant represents the default limit value to be used in cases where no specific limit
+ * is provided.
+ *
+ * @see Int.MIN_VALUE
+ */
+private const val DEFAULT_LIMIT = Int.MIN_VALUE
+
+/**
+ * Represents a column in a database table to keep all required meta-information.
+ *
+ * @property [name] the name of the column.
+ * @property [sqlTypeName] the SQL data type of the column.
+ * @property [jdbcType] the JDBC data type of the column produced from [java.sql.Types].
+ * @property [size] the size of the column.
+ */
+public data class TableColumnMetadata(val name: String, val sqlTypeName: String, val jdbcType: Int, val size: Int)
+
+/**
+ * Represents a table metadata to store information about a database table,
+ * including its name, schema name, and catalogue name.
+ *
+ * NOTE: we need to extract both [schemaName] and [catalogue]
+ * because different databases implement metadata differently.
+ *
+ * @property [name] the name of the table.
+ * @property [schemaName] the name of the schema the table belongs to (optional).
+ * @property [catalogue] the name of the catalogue the table belongs to (optional).
+ */
+public data class TableMetadata(val name: String, val schemaName: String?, val catalogue: String?)
+
+/**
+ * Represents the configuration for a database connection.
+ *
+ * @property [url] the URL of the database. Keep it in the following form: `jdbc:subprotocol:subname`.
+ * @property [user] the username used for authentication (optional, default is empty string).
+ * @property [password] the password used for authentication (optional, default is empty string).
+ */
+public data class DatabaseConfiguration(val url: String, val user: String = "", val password: String = "")
+
+/**
+ * Reads data from an SQL table and converts it into a DataFrame.
+ *
+ * @param [dbConfig] the configuration for the database, including URL, user, and password.
+ * @param [tableName] the name of the table to read data from.
+ * @return the DataFrame containing the data from the SQL table.
+ */
+public fun DataFrame.Companion.readSqlTable(dbConfig: DatabaseConfiguration, tableName: String): AnyFrame {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return readSqlTable(connection, tableName, DEFAULT_LIMIT)
+    }
+}
+
+/**
+ * Reads data from an SQL table and converts it into a DataFrame.
+ *
+ * @param [dbConfig] the configuration for the database, including URL, user, and password.
+ * @param [tableName] the name of the table to read data from.
+ * @param [limit] the maximum number of rows to retrieve from the table.
+ * @return the DataFrame containing the data from the SQL table.
+ */
+public fun DataFrame.Companion.readSqlTable(dbConfig: DatabaseConfiguration, tableName: String, limit: Int): AnyFrame {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return readSqlTable(connection, tableName, limit)
+    }
+}
+
+/**
+ * Reads data from an SQL table and converts it into a DataFrame.
+ *
+ * @param [connection] the database connection to read tables from.
+ * @param [tableName] the name of the table to read data from.
+ * @return the DataFrame containing the data from the SQL table.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: String): AnyFrame {
+    return readSqlTable(connection, tableName, DEFAULT_LIMIT)
+}
+
+/**
+ * Reads data from an SQL table and converts it into a DataFrame.
+ *
+ * @param [connection] the database connection to read tables from.
+ * @param [tableName] the name of the table to read data from.
+ * @param [limit] the maximum number of rows to retrieve from the table.
+ * @return the DataFrame containing the data from the SQL table.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: String, limit: Int): AnyFrame {
+    var preparedQuery = "SELECT * FROM $tableName"
+    if (limit > 0) preparedQuery += " LIMIT $limit"
+
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    connection.createStatement().use { st ->
+        logger.debug { "Connection with url: $url is established successfully." }
+        val tableColumns = getTableColumnsMetadata(connection, tableName)
+
+        st.executeQuery(
+            preparedQuery
+        ).use { rs ->
+            val data = fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit)
+            return data.toDataFrame()
+        }
+    }
+}
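The typical entry points then look like the sketch below; the H2 in-memory URL, table name, and query are illustrative placeholders:

```kotlin
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration
import org.jetbrains.kotlinx.dataframe.io.readSqlQuery
import org.jetbrains.kotlinx.dataframe.io.readSqlTable

fun main() {
    val dbConfig = DatabaseConfiguration(url = "jdbc:h2:mem:test") // user/password default to ""
    // Whole table, first 100 rows only.
    val customers = DataFrame.readSqlTable(dbConfig, "customers", limit = 100)
    // Arbitrary query; the limit is appended to the query as "LIMIT 100".
    val adults = DataFrame.readSqlQuery(dbConfig, "SELECT name, age FROM customers WHERE age > 30", limit = 100)
    println(customers)
    println(adults)
}
```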
+
+/**
+ * Converts the result of an SQL query to a DataFrame.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @param [sqlQuery] the SQL query to execute.
+ * @return the DataFrame containing the result of the SQL query.
+ */
+public fun DataFrame.Companion.readSqlQuery(dbConfig: DatabaseConfiguration, sqlQuery: String): AnyFrame {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return readSqlQuery(connection, sqlQuery, DEFAULT_LIMIT)
+    }
+}
+
+/**
+ * Converts the result of an SQL query to a DataFrame.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @param [sqlQuery] the SQL query to execute.
+ * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
+ * @return the DataFrame containing the result of the SQL query.
+ */
+public fun DataFrame.Companion.readSqlQuery(dbConfig: DatabaseConfiguration, sqlQuery: String, limit: Int): AnyFrame {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return readSqlQuery(connection, sqlQuery, limit)
+    }
+}
+
+/**
+ * Converts the result of an SQL query to a DataFrame.
+ *
+ * @param [connection] the database connection to execute the SQL query.
+ * @param [sqlQuery] the SQL query to execute.
+ * @return the DataFrame containing the result of the SQL query.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: String): AnyFrame {
+    return readSqlQuery(connection, sqlQuery, DEFAULT_LIMIT)
+}
+
+/**
+ * Converts the result of an SQL query to a DataFrame.
+ *
+ * @param [connection] the database connection to execute the SQL query.
+ * @param [sqlQuery] the SQL query to execute.
+ * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
+ * @return the DataFrame containing the result of the SQL query.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: String, limit: Int): AnyFrame {
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    var internalSqlQuery = sqlQuery
+    if (limit > 0) internalSqlQuery += " LIMIT $limit"
+
+    logger.debug { "Executing SQL query: $internalSqlQuery" }
+
+    connection.createStatement().use { st ->
+        st.executeQuery(internalSqlQuery).use { rs ->
+            val tableColumns = getTableColumnsMetadata(rs)
+            // DEFAULT_LIMIT is passed on purpose: the limit was already applied via the LIMIT clause above.
+            val data = fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT)
+
+            logger.debug { "SQL query executed successfully. Converting data to DataFrame." }
+
+            return data.toDataFrame()
+        }
+    }
+}
+
+/**
+ * Reads the data from a [ResultSet] and converts it into a DataFrame.
+ *
+ * @param [resultSet] the ResultSet containing the data to read.
+ * @param [dbType] the type of database that the ResultSet belongs to.
+ * @return the DataFrame generated from the ResultSet data.
+ */
+public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, dbType: DbType): AnyFrame {
+    return readResultSet(resultSet, dbType, DEFAULT_LIMIT)
+}
+
+/**
+ * Reads the data from a ResultSet and converts it into a DataFrame.
+ *
+ * @param [resultSet] the ResultSet containing the data to read.
+ * @param [dbType] the type of database that the ResultSet belongs to.
+ * @param [limit] the maximum number of rows to read from the ResultSet.
+ * @return the DataFrame generated from the ResultSet data.
+ */
+public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, dbType: DbType, limit: Int): AnyFrame {
+    val tableColumns = getTableColumnsMetadata(resultSet)
+    val data = fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit)
+    return data.toDataFrame()
+}
+
+/**
+ * Reads the data from a ResultSet and converts it into a DataFrame.
+ *
+ * @param [resultSet] the ResultSet containing the data to read.
+ * @param [connection] the connection to the database (it's required to extract the database type).
+ * @return the DataFrame generated from the ResultSet data.
+ */
+public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, connection: Connection): AnyFrame {
+    return readResultSet(resultSet, connection, DEFAULT_LIMIT)
+}
+
+/**
+ * Reads the data from a ResultSet and converts it into a DataFrame.
+ *
+ * @param [resultSet] the ResultSet containing the data to read.
+ * @param [connection] the connection to the database (it's required to extract the database type).
+ * @param [limit] the maximum number of rows to read from the ResultSet.
+ * @return the DataFrame generated from the ResultSet data.
+ */
+public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, connection: Connection, limit: Int): AnyFrame {
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    return readResultSet(resultSet, dbType, limit)
+}
+
+/**
+ * Reads all non-system tables from a database and returns them as a list of data frames
+ * using the provided database configuration.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+ */
+public fun DataFrame.Companion.readAllSqlTables(dbConfig: DatabaseConfiguration): List<AnyFrame> {
+    return readAllSqlTables(dbConfig, DEFAULT_LIMIT)
+}
+
+/**
+ * Reads all tables from the given database using the provided database configuration and limit.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @param [limit] the maximum number of rows to read from each table.
+ * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+ */
+public fun DataFrame.Companion.readAllSqlTables(dbConfig: DatabaseConfiguration, limit: Int): List<AnyFrame> {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return readAllSqlTables(connection, limit)
+    }
+}
+
+/**
+ * Reads all non-system tables from a database and returns them as a list of data frames.
+ *
+ * @param [connection] the database connection to read tables from.
+ * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+
+/**
+ * Reads all non-system tables from a database and returns them as a list of data frames.
+ *
+ * @param [connection] the database connection to read tables from.
+ * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.readAllSqlTables(connection: Connection): List<AnyFrame> {
+    return readAllSqlTables(connection, DEFAULT_LIMIT)
+}
+
+/**
+ * Reads all non-system tables from a database and returns them as a list of data frames.
+ *
+ * @param [connection] the database connection to read tables from.
+ * @param [limit] the maximum number of rows to read from each table.
+ * @return a list of [AnyFrame] objects representing the non-system tables from the database.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.readAllSqlTables(connection: Connection, limit: Int): List<AnyFrame> {
+    val metaData = connection.metaData
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    // request only tables, to exclude system tables and other tables without data;
+    // note that this kind of filtering is poorly supported by many databases
+    val tables = metaData.getTables(null, null, null, arrayOf("TABLE"))
+
+    val dataFrames = mutableListOf<AnyFrame>()
+
+    while (tables.next()) {
+        val table = dbType.buildTableMetadata(tables)
+        if (!dbType.isSystemTable(table)) {
+            // we filter a second time here because of SQLite-specific logic and possible issues with future databases
+            logger.debug { "Reading table: ${table.name}" }
+            val dataFrame = readSqlTable(connection, table.name, limit)
+            dataFrames += dataFrame
+            logger.debug { "Finished reading table: ${table.name}" }
+        }
+    }
+
+    return dataFrames
+}
+
+/**
+ * Retrieves the schema for an SQL table using the provided database configuration.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @param [tableName] the name of the SQL table for which to retrieve the schema.
+ * @return the [DataFrameSchema] object representing the schema of the SQL table.
+ */
+public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DatabaseConfiguration, tableName: String): DataFrameSchema {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return getSchemaForSqlTable(connection, tableName)
+    }
+}
+
+/**
+ * Retrieves the schema for an SQL table using the provided database connection.
+ *
+ * @param [connection] the database connection.
+ * @param [tableName] the name of the SQL table for which to retrieve the schema.
+ * @return the schema of the SQL table as a [DataFrameSchema] object.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.getSchemaForSqlTable(
+    connection: Connection,
+    tableName: String
+): DataFrameSchema {
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    connection.createStatement().use {
+        logger.debug { "Connection with url: $url established successfully." }
+
+        val tableColumns = getTableColumnsMetadata(connection, tableName)
+
+        return buildSchemaByTableColumns(tableColumns, dbType)
+    }
+}
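A short usage sketch for the table-level API above (the Customer table name comes from the H2 test fixture below; the limit of 100 is arbitrary):

// Cap each table at 100 rows while exploring an unfamiliar database.
val tables = DataFrame.readAllSqlTables(connection, 100)
tables.forEach { it.print() }

// Inspect a table's schema from JDBC metadata alone, without fetching rows.
val schema = DataFrame.getSchemaForSqlTable(connection, "Customer")
schema.print()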
+
+/**
+ * Retrieves the schema of an SQL query result using the provided database configuration.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @param [sqlQuery] the SQL query to execute and retrieve the schema from.
+ * @return the schema of the SQL query as a [DataFrameSchema] object.
+ */
+public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DatabaseConfiguration, sqlQuery: String): DataFrameSchema {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return getSchemaForSqlQuery(connection, sqlQuery)
+    }
+}
+
+/**
+ * Retrieves the schema of an SQL query result using the provided database connection.
+ *
+ * @param [connection] the database connection.
+ * @param [sqlQuery] the SQL query to execute and retrieve the schema from.
+ * @return the schema of the SQL query as a [DataFrameSchema] object.
+ *
+ * @see DriverManager.getConnection
+ */
+public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema {
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    connection.createStatement().use { st ->
+        st.executeQuery(sqlQuery).use { rs ->
+            val tableColumns = getTableColumnsMetadata(rs)
+            return buildSchemaByTableColumns(tableColumns, dbType)
+        }
+    }
+}
+
+/**
+ * Retrieves the schema from a [ResultSet].
+ *
+ * NOTE: This function does not close the result set and does not read any data from it.
+ *
+ * @param [resultSet] the ResultSet obtained from executing a database query.
+ * @param [dbType] the type of database that the ResultSet belongs to.
+ * @return the schema of the ResultSet as a [DataFrameSchema] object.
+ */
+public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbType: DbType): DataFrameSchema {
+    val tableColumns = getTableColumnsMetadata(resultSet)
+    return buildSchemaByTableColumns(tableColumns, dbType)
+}
+
+/**
+ * Retrieves the schema from a [ResultSet].
+ *
+ * NOTE: [connection] is required only to extract the database type.
+ * This function does not close the connection or the result set, and does not read any data from the result set.
+ *
+ * @param [resultSet] the ResultSet obtained from executing a database query.
+ * @param [connection] the connection to the database (it's required to extract the database type).
+ * @return the schema of the ResultSet as a [DataFrameSchema] object.
+ */
+public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema {
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    val tableColumns = getTableColumnsMetadata(resultSet)
+    return buildSchemaByTableColumns(tableColumns, dbType)
+}
+
+/**
+ * Retrieves the schema of all non-system tables in the database using the provided database configuration.
+ *
+ * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
+ * @return a list of [DataFrameSchema] objects representing the schema of each non-system table.
+ */
+public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List<DataFrameSchema> {
+    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
+        return getSchemaForAllSqlTables(connection)
+    }
+}
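Note that, unlike readSqlQuery, getSchemaForSqlQuery executes the query as-is without appending a LIMIT, so on databases that support LIMIT a cheap schema-only inspection can bound the query itself (a sketch; the query text is illustrative):

val schema = DataFrame.getSchemaForSqlQuery(connection, "SELECT name, age FROM Customer LIMIT 1")
schema.print()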
+
+/**
+ * Retrieves the schema of all non-system tables in the database using the provided database connection.
+ *
+ * @param [connection] the database connection.
+ * @return a list of [DataFrameSchema] objects representing the schema of each non-system table.
+ */
+public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): List<DataFrameSchema> {
+    val metaData = connection.metaData
+    val url = connection.metaData.url
+    val dbType = extractDBTypeFromURL(url)
+
+    val tableTypes = arrayOf("TABLE")
+    // exclude system tables and other tables without data
+    val tables = metaData.getTables(null, null, null, tableTypes)
+
+    val dataFrameSchemas = mutableListOf<DataFrameSchema>()
+
+    while (tables.next()) {
+        val jdbcTable = dbType.buildTableMetadata(tables)
+        if (!dbType.isSystemTable(jdbcTable)) {
+            // we filter a second time here because of SQLite-specific logic and possible issues with future databases
+            val dataFrameSchema = getSchemaForSqlTable(connection, jdbcTable.name)
+            dataFrameSchemas += dataFrameSchema
+        }
+    }
+
+    return dataFrameSchemas
+}
+
+/**
+ * Builds a DataFrame schema based on the given table columns.
+ *
+ * @param [tableColumns] a mutable map containing the table columns, where the key represents the column name
+ * and the value represents the metadata of the column.
+ * @param [dbType] the type of database.
+ * @return a [DataFrameSchema] object representing the schema built from the table columns.
+ */
+private fun buildSchemaByTableColumns(tableColumns: MutableMap<String, TableColumnMetadata>, dbType: DbType): DataFrameSchema {
+    val schemaColumns = tableColumns.map {
+        Pair(it.key, dbType.toColumnSchema(it.value))
+    }.toMap()
+
+    return DataFrameSchemaImpl(
+        columns = schemaColumns
+    )
+}
+
+/**
+ * Retrieves the metadata of the columns in the result set.
+ *
+ * @param [rs] the result set.
+ * @return a mutable map of column names to [TableColumnMetadata] objects,
+ * where each TableColumnMetadata object contains information such as the column type,
+ * JDBC type, size, and name.
+ */
+private fun getTableColumnsMetadata(rs: ResultSet): MutableMap<String, TableColumnMetadata> {
+    val metaData: ResultSetMetaData = rs.metaData
+    val numberOfColumns: Int = metaData.columnCount
+
+    val tableColumns = mutableMapOf<String, TableColumnMetadata>()
+
+    for (i in 1..numberOfColumns) {
+        val name = metaData.getColumnName(i)
+        val size = metaData.getColumnDisplaySize(i)
+        val type = metaData.getColumnTypeName(i)
+        val jdbcType = metaData.getColumnType(i)
+
+        // TODO:
+        // add strategy for multiple columns handling (throw exception, ignore,
+        // create columns with additional indexes in name)
+        // column names should be unique
+        check(!tableColumns.containsKey(name)) {
+            "Multiple columns with name $name from table ${metaData.getTableName(i)}. Rename columns to make them unique."
+        }
+
+        tableColumns += Pair(name, TableColumnMetadata(name, type, jdbcType, size))
+    }
+    return tableColumns
+}
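One possible shape for the strategy mentioned in the TODO above: rename repeated columns with a running index instead of failing, the way CSV readers commonly do. The uniqueName helper is hypothetical and not part of this change:

// Hypothetical helper: yields "name" for the first occurrence, then "name1", "name2", ...
private fun uniqueName(name: String, seen: MutableMap<String, Int>): String {
    val n = seen.getOrDefault(name, 0)
    seen[name] = n + 1
    return if (n == 0) name else "$name$n"
}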
+
+/**
+ * Retrieves the metadata of columns for a given table.
+ *
+ * @param [connection] the database connection.
+ * @param [tableName] the name of the table.
+ * @return a mutable map of column names to [TableColumnMetadata] objects,
+ * where each TableColumnMetadata object contains information such as the column type,
+ * JDBC type, size, and name.
+ */
+private fun getTableColumnsMetadata(connection: Connection, tableName: String): MutableMap<String, TableColumnMetadata> {
+    val dbMetaData: DatabaseMetaData = connection.metaData
+    val columns: ResultSet = dbMetaData.getColumns(null, null, tableName, null)
+    val tableColumns = mutableMapOf<String, TableColumnMetadata>()
+
+    while (columns.next()) {
+        val name = columns.getString("COLUMN_NAME")
+        val type = columns.getString("TYPE_NAME")
+        val jdbcType = columns.getInt("DATA_TYPE")
+        val size = columns.getInt("COLUMN_SIZE")
+        tableColumns += Pair(name, TableColumnMetadata(name, type, jdbcType, size))
+    }
+    return tableColumns
+}
+
+/**
+ * Fetches and converts data from a ResultSet into a mutable map.
+ *
+ * @param [tableColumns] a map containing the column metadata for the table.
+ * @param [rs] the ResultSet object containing the data to be fetched and converted.
+ * @param [dbType] the type of the database.
+ * @param [limit] the maximum number of rows to fetch and convert.
+ * @return a mutable map containing the fetched and converted data.
+ */
+private fun fetchAndConvertDataFromResultSet(
+    tableColumns: MutableMap<String, TableColumnMetadata>,
+    rs: ResultSet,
+    dbType: DbType,
+    limit: Int
+): MutableMap<String, MutableList<Any?>> {
+    // maps a column name to the list of its values
+    val data = mutableMapOf<String, MutableList<Any?>>()
+
+    // init data
+    tableColumns.forEach { (columnName, _) ->
+        data[columnName] = mutableListOf()
+    }
+
+    var counter = 0
+
+    // a non-positive limit means "read all rows"
+    while (rs.next() && (limit <= 0 || counter < limit)) {
+        handleRow(tableColumns, data, dbType, rs)
+        counter++
+        // if (counter % 1000 == 0) logger.debug { "Loaded $counter rows." } // TODO: https://github.com/Kotlin/dataframe/issues/455
+    }
+
+    return data
+}
+
+private fun handleRow(
+    tableColumns: MutableMap<String, TableColumnMetadata>,
+    data: MutableMap<String, MutableList<Any?>>,
+    dbType: DbType,
+    rs: ResultSet
+) {
+    tableColumns.forEach { (columnName, jdbcColumn) ->
+        data[columnName] = (data[columnName]!!
+ dbType.convertDataFromResultSet(rs, jdbcColumn)).toMutableList() + } +} + + + diff --git a/dataframe-jdbc/src/main/resources/META-INF/services/org.jetbrains.kotlinx.dataframe.io.SupportedFormat b/dataframe-jdbc/src/main/resources/META-INF/services/org.jetbrains.kotlinx.dataframe.io.SupportedFormat new file mode 100644 index 0000000000..7698e519eb --- /dev/null +++ b/dataframe-jdbc/src/main/resources/META-INF/services/org.jetbrains.kotlinx.dataframe.io.SupportedFormat @@ -0,0 +1 @@ +org.jetbrains.kotlinx.dataframe.io.Jdbc diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt new file mode 100644 index 0000000000..e73656320c --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt @@ -0,0 +1,465 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.shouldBe +import org.h2.jdbc.JdbcSQLSyntaxErrorException +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.io.db.H2 +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.sql.Connection +import java.sql.DriverManager +import java.sql.ResultSet +import java.sql.SQLException +import org.jetbrains.kotlinx.dataframe.DataRow +import org.jetbrains.kotlinx.dataframe.api.print +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup +import kotlin.reflect.typeOf + +private const val URL = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false" + +@DataSchema +interface Customer { + val id: Int + val name: String + val age: Int +} + +@DataSchema +interface Sale { + val id: Int + val customerId: Int + val amount: Double +} + +@DataSchema +interface CustomerSales { + val customerName: String + val totalSalesAmount: Double +} + +@DataSchema +interface TestTableData { + val characterCol: String + val characterVaryingCol: String + val characterLargeObjectCol: String + val mediumTextCol: String + val varcharIgnoreCaseCol: String + val binaryCol: ByteArray + val binaryVaryingCol: ByteArray + val binaryLargeObjectCol: ByteArray + val booleanCol: Boolean + val tinyIntCol: Byte + val smallIntCol: Short + val integerCol: Int + val bigIntCol: Long + val numericCol: Double + val realCol: Float + val doublePrecisionCol: Double + val decFloatCol: Double + val dateCol: String + val timeCol: String + val timeWithTimeZoneCol: String + val timestampCol: String + val timestampWithTimeZoneCol: String + val intervalCol: String + val javaObjectCol: Any? 
+ val enumCol: String + val jsonCol: String + val uuidCol: String +} + +class JdbcTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = + DriverManager.getConnection(URL) + + + // Crate table Customer + @Language("SQL") + val createCustomerTableQuery = """ + CREATE TABLE Customer ( + id INT PRIMARY KEY, + name VARCHAR(50), + age INT + ) + """ + + connection.createStatement().execute(createCustomerTableQuery) + + // Create table Sale + @Language("SQL") + val createSaleTableQuery = """ + CREATE TABLE Sale ( + id INT PRIMARY KEY, + customerId INT, + amount DECIMAL(10, 2) + ) + """ + + connection.createStatement().execute( + createSaleTableQuery + ) + + // add data to the Customer table + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (1, 'John', 40)") + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (2, 'Alice', 25)") + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (3, 'Bob', 47)") + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (4, NULL, NULL)") + + // add data to the Sale table + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (1, 1, 100.50)") + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (2, 2, 50.00)") + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (3, 1, 75.25)") + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (4, 3, 35.15)") + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `read from huge table`() { + @Language("SQL") + val createTableQuery = """ + CREATE TABLE TestTable ( + characterCol CHAR(10), + characterVaryingCol VARCHAR(20), + characterLargeObjectCol CLOB, + mediumTextCol CLOB, + varcharIgnoreCaseCol VARCHAR_IGNORECASE(30), + binaryCol BINARY(8), + binaryVaryingCol VARBINARY(16), + binaryLargeObjectCol BLOB, + booleanCol BOOLEAN, + tinyIntCol TINYINT, + smallIntCol SMALLINT, + integerCol INT, + bigIntCol BIGINT, + numericCol NUMERIC(10, 2), + realCol REAL, + doublePrecisionCol DOUBLE PRECISION, + decFloatCol DECFLOAT(16), + dateCol DATE, + timeCol TIME, + timeWithTimeZoneCol TIME WITH TIME ZONE, + timestampCol TIMESTAMP, + timestampWithTimeZoneCol TIMESTAMP WITH TIME ZONE, + javaObjectCol OBJECT, + enumCol VARCHAR(10), + jsonCol JSON, + uuidCol UUID + ) + """ + + connection.createStatement().execute(createTableQuery.trimIndent()) + + connection.prepareStatement( + """ + INSERT INTO TestTable VALUES ( + 'ABC', 'XYZ', 'Long text data for CLOB', 'Medium text data for CLOB', + 'Varchar IgnoreCase', X'010203', X'040506', X'070809', + TRUE, 1, 100, 1000, 100000, + 123.45, 1.23, 3.14, 2.71, + '2023-07-20', '08:30:00', '18:15:00', '2023-07-19 12:45:30', + '2023-07-18 12:45:30', NULL, + 'Option1', '{"key": "value"}', '123e4567-e89b-12d3-a456-426655440000' + ) + """.trimIndent() + ).executeUpdate() + + connection.prepareStatement( + """ + INSERT INTO TestTable VALUES ( + 'DEF', 'LMN', 'Another CLOB data', 'Different CLOB data', + 'Another Varchar', X'101112', X'131415', X'161718', + FALSE, 2, 200, 2000, 200000, + 234.56, 2.34, 4.56, 3.14, + '2023-07-21', '14:30:00', '22:45:00', '2023-07-20 18:15:30', + '2023-07-19 18:15:30', NULL, + 'Option2', '{"key": "another_value"}', 
'234e5678-e89b-12d3-a456-426655440001' + ) + """.trimIndent() + ).executeUpdate() + + connection.prepareStatement( + """ + INSERT INTO TestTable VALUES ( + 'GHI', 'OPQ', 'Third CLOB entry', 'Yet another CLOB data', + 'Yet Another Varchar', X'192021', X'222324', X'252627', + TRUE, 3, 300, 3000, 300000, + 345.67, 3.45, 5.67, 4.71, + '2023-07-22', '20:45:00', '03:30:00', '2023-07-21 23:45:15', + '2023-07-20 23:45:15', NULL, + 'Option3', '{ "person": { "name": "John Doe", "age": 30 }, ' || + '"address": { "street": "123 Main St", "city": "Exampleville", "zipcode": "12345"}}', + '345e6789-e89b-12d3-a456-426655440002' + ) + """.trimIndent() + ).executeUpdate() + + val df = DataFrame.readSqlTable(connection, "TestTable").cast() + df.rowsCount() shouldBe 3 + df.filter { it[TestTableData::integerCol] > 1000}.rowsCount() shouldBe 2 + } + + @Test + fun `read from table`() { + val tableName = "Customer" + val df = DataFrame.readSqlTable(connection, tableName).cast() + + df.rowsCount() shouldBe 4 + df.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df[0][1] shouldBe "John" + + val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() + + df1.rowsCount() shouldBe 1 + df1.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df1[0][1] shouldBe "John" + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) + dataSchema.columns.size shouldBe 3 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + val dbConfig = DatabaseConfiguration(url = URL) + val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() + + df2.rowsCount() shouldBe 4 + df2.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df2[0][1] shouldBe "John" + + val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() + + df3.rowsCount() shouldBe 1 + df3.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df3[0][1] shouldBe "John" + + val dataSchema1 = DataFrame.getSchemaForSqlTable(dbConfig, tableName) + dataSchema1.columns.size shouldBe 3 + dataSchema.columns["name"]!!.type shouldBe typeOf() + } + + @Test + fun `read from ResultSet`() { + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> + @Language("SQL") + val selectStatement = "SELECT * FROM Customer" + + st.executeQuery(selectStatement).use { rs -> + val df = DataFrame.readResultSet(rs, H2).cast() + + df.rowsCount() shouldBe 4 + df.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df[0][1] shouldBe "John" + + rs.beforeFirst() + + val df1 = DataFrame.readResultSet(rs, H2, 1).cast() + + df1.rowsCount() shouldBe 1 + df1.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df1[0][1] shouldBe "John" + + rs.beforeFirst() + + val dataSchema = DataFrame.getSchemaForResultSet(rs, H2) + dataSchema.columns.size shouldBe 3 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + rs.beforeFirst() + + val df2 = DataFrame.readResultSet(rs, connection).cast() + + df2.rowsCount() shouldBe 4 + df2.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df2[0][1] shouldBe "John" + + rs.beforeFirst() + + val df3 = DataFrame.readResultSet(rs, connection, 1).cast() + + df3.rowsCount() shouldBe 1 + df3.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df3[0][1] shouldBe "John" + + rs.beforeFirst() + + val dataSchema1 = DataFrame.getSchemaForResultSet(rs, connection) + dataSchema1.columns.size shouldBe 3 + dataSchema.columns["name"]!!.type shouldBe typeOf() + } + } + } + + @Test + fun `read from non-existing table`() { + shouldThrow { + DataFrame.readSqlTable(connection, 
"WrongTableName").cast() + } + } + + @Test + fun `read from non-existing jdbc url`() { + shouldThrow { + DataFrame.readSqlTable(DriverManager.getConnection("ddd"), "WrongTableName") + } + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount + FROM Sale s + INNER JOIN Customer c ON s.customerId = c.id + WHERE c.age > 35 + GROUP BY s.customerId, c.name + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() + + df.rowsCount() shouldBe 2 + df.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df[0][0] shouldBe "John" + + val df1 = DataFrame.readSqlQuery(connection, sqlQuery, 1).cast() + + df1.rowsCount() shouldBe 1 + df1.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df1[0][0] shouldBe "John" + + val dataSchema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + dataSchema.columns.size shouldBe 2 + dataSchema.columns["name"]!!.type shouldBe typeOf() + + val dbConfig = DatabaseConfiguration(url = URL) + val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery).cast() + + df2.rowsCount() shouldBe 2 + df2.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df2[0][0] shouldBe "John" + + val df3 = DataFrame.readSqlQuery(dbConfig, sqlQuery, 1).cast() + + df3.rowsCount() shouldBe 1 + df3.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df3[0][0] shouldBe "John" + + val dataSchema1 = DataFrame.getSchemaForSqlQuery(dbConfig, sqlQuery) + dataSchema1.columns.size shouldBe 2 + dataSchema.columns["name"]!!.type shouldBe typeOf() + } + + @Test + fun `read from sql query with repeated columns` () { + @Language("SQL") + val sqlQuery = """ + SELECT c1.name, c2.name + FROM Customer c1 + INNER JOIN Customer c2 ON c1.id = c2.id + """.trimIndent() + + shouldThrow { + DataFrame.readSqlQuery(connection, sqlQuery) + } + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection) + + val customerDf = dataframes[0].cast() + + customerDf.rowsCount() shouldBe 4 + customerDf.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + customerDf[0][1] shouldBe "John" + + val saleDf = dataframes[1].cast() + + saleDf.rowsCount() shouldBe 4 + saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + saleDf[0][2] shouldBe 100.5f + + val dataframes1 = DataFrame.readAllSqlTables(connection, 1) + + val customerDf1 = dataframes1[0].cast() + + customerDf1.rowsCount() shouldBe 1 + customerDf1.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + customerDf1[0][1] shouldBe "John" + + val saleDf1 = dataframes1[1].cast() + + saleDf1.rowsCount() shouldBe 1 + saleDf1.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 + saleDf1[0][2] shouldBe 100.5f + + val dataSchemas = DataFrame.getSchemaForAllSqlTables(connection) + + val customerDataSchema = dataSchemas[0] + customerDataSchema.columns.size shouldBe 3 + customerDataSchema.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema = dataSchemas[1] + saleDataSchema.columns.size shouldBe 3 + saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + + val dbConfig = DatabaseConfiguration(url = URL) + val dataframes2 = DataFrame.readAllSqlTables(dbConfig) + + val customerDf2 = dataframes2[0].cast() + + customerDf2.rowsCount() shouldBe 4 + customerDf2.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + customerDf2[0][1] shouldBe "John" + + val saleDf2 = dataframes2[1].cast() + + 
saleDf2.rowsCount() shouldBe 4 + saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + saleDf2[0][2] shouldBe 100.5f + + val dataframes3 = DataFrame.readAllSqlTables(dbConfig, 1) + + val customerDf3 = dataframes3[0].cast() + + customerDf3.rowsCount() shouldBe 1 + customerDf3.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + customerDf3[0][1] shouldBe "John" + + val saleDf3 = dataframes3[1].cast() + + saleDf3.rowsCount() shouldBe 1 + saleDf3.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 + saleDf3[0][2] shouldBe 100.5f + + val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig) + + val customerDataSchema1 = dataSchemas1[0] + customerDataSchema1.columns.size shouldBe 3 + customerDataSchema1.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema1 = dataSchemas1[1] + saleDataSchema1.columns.size shouldBe 3 + saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() + } +} + diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt new file mode 100644 index 0000000000..0a7bda75bf --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt @@ -0,0 +1,76 @@ +package org.jetbrains.kotlinx.dataframe.io + +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.print +import org.junit.Test +import java.sql.DriverManager +import java.util.Properties +import org.junit.Ignore + +private const val URL = "jdbc:mariadb://localhost:3306/imdb" +private const val USER_NAME = "root" +private const val PASSWORD = "pass" + +@DataSchema +interface ActorKDF { + val id: Int + val firstName: String + val lastName: String + val gender: String +} + +@DataSchema +interface RankedMoviesWithGenres { + val name: String + val year: Int + val rank: Float + val genres: String +} + +@Ignore +class ImdbTestTest { + @Test + fun `read table`() { + val props = Properties() + props.setProperty("user", USER_NAME) + props.setProperty("password", PASSWORD) + + // generate kdf schemas by database metadata (as interfaces or extensions) + // for gradle or as classes under the hood in KNB + + DriverManager.getConnection(URL, props).use { connection -> + val df = DataFrame.readSqlTable(connection, "actors", 100).cast() + df.print() + } + } + + + @Test + fun `read sql query`() { + val sqlQuery = "select name, year, rank,\n" + + "GROUP_CONCAT (genre) as \"genres\"\n" + + "from movies join movies_directors on movie_id = movies.id\n" + + " join directors on directors.id=director_id left join movies_genres on movies.id = movies_genres.movie_id \n" + + "where directors.first_name = \"Quentin\" and directors.last_name = \"Tarantino\"\n" + + "group by name, year, rank\n" + + "order by year" + val props = Properties() + props.setProperty("user", USER_NAME) + props.setProperty("password", PASSWORD) + + // generate kdf schemas by database metadata (as interfaces or extensions) + // for gradle or as classes under the hood in KNB + + DriverManager.getConnection(URL, props).use { connection -> + val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() + //df.filter { year > 2000 }.print() + df.print() + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + schema.print() + } + + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt 
b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt new file mode 100644 index 0000000000..abf1d3ebd8 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt @@ -0,0 +1,328 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.print +import org.junit.AfterClass +import org.junit.Assert.assertEquals +import org.junit.BeforeClass +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import org.junit.Ignore + +private const val URL = "jdbc:mariadb://localhost:3307" +private const val USER_NAME = "root" +private const val PASSWORD = "pass" +private const val TEST_DATABASE_NAME = "testKDFdatabase" + +@DataSchema +interface Table1MariaDb { + val id: Int + val bitCol: Boolean + val tinyintCol: Int + val smallintCol: Int + val mediumintCol: Int + val mediumintUnsignedCol: Long + val integerCol: Int + val intCol: Int + val integerUnsignedCol: Long + val bigintCol: Long + val floatCol: Float + val doubleCol: Double + val decimalCol: Double + val dateCol: String + val datetimeCol: String + val timestampCol: String + val timeCol: String + val yearCol: String + val varcharCol: String + val charCol: String + val binaryCol: ByteArray + val varbinaryCol: ByteArray + val tinyblobCol: ByteArray + val blobCol: ByteArray + val mediumblobCol: ByteArray + val longblobCol: ByteArray + val textCol: String + val mediumtextCol: String + val longtextCol: String + val enumCol: String + val setCol: String +} + +@DataSchema +interface Table2MariaDb { + val id: Int + val enumCol: String + val setCol: String +} + +@Ignore +class MariadbTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL, USER_NAME, PASSWORD) + + connection.createStatement().use { st -> + // Drop the test database if it exists + val dropDatabaseQuery = "DROP DATABASE IF EXISTS $TEST_DATABASE_NAME" + st.executeUpdate(dropDatabaseQuery) + + // Create the test database + val createDatabaseQuery = "CREATE DATABASE $TEST_DATABASE_NAME" + st.executeUpdate(createDatabaseQuery) + + // Use the newly created database + val useDatabaseQuery = "USE $TEST_DATABASE_NAME" + st.executeUpdate(useDatabaseQuery) + } + + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS table1 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol 
LONGBLOB, + textCol TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3'), + setCol SET('Option1', 'Option2', 'Option3') + ) + """ + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val createTableQuery2 = """ + CREATE TABLE IF NOT EXISTS table2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol LONGBLOB, + textCol TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3'), + setCol SET('Option1', 'Option2', 'Option3') + ) + """ + connection.createStatement().execute( + createTableQuery2.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO table1 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() + + + @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """.trimIndent() + + connection.prepareStatement(insertData1).use { st -> + // Insert data into table1 + for (i in 1..3) { + st.setBoolean(1, true) + st.setByte(2, i.toByte()) + st.setShort(3, (i * 10).toShort()) + st.setInt(4, i * 100) + st.setInt(5, i * 100) + st.setInt(6, i * 100) + st.setInt(7, i * 100) + st.setInt(8, i * 100) + st.setInt(9, i * 100) + st.setFloat(10, i * 10.0f) + st.setDouble(11, i * 10.0) + st.setBigDecimal(12, BigDecimal(i * 10)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, "textValue$i") + st.setString(27, "mediumtextValue$i") + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "Option$i") + + st.executeUpdate() + } + } + + connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setBoolean(1, false) + st.setByte(2, (i * 2).toByte()) + st.setShort(3, (i * 20).toShort()) + st.setInt(4, i * 200) + st.setInt(5, i * 200) + st.setInt(6, i * 200) + st.setInt(7, i * 200) + st.setInt(8, i * 200) + st.setInt(9, i * 200) + st.setFloat(10, i * 20.0f) + st.setDouble(11, i * 20.0) + st.setBigDecimal(12, BigDecimal(i * 20)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, "textValue$i") + st.setString(27, "mediumtextValue$i") + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "Option$i") + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + connection.createStatement().use { st -> st.execute("DROP DATABASE IF EXISTS $TEST_DATABASE_NAME") } + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + df1.print() + assertEquals(3, df1.rowsCount()) + + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + df2.print() + assertEquals(3, df2.rowsCount()) + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + t1.id, + t2.enumCol, + t2.setCol + FROM table1 t1 + JOIN table2 t2 ON t1.id = t2.id; + """.trimIndent() + + val df = 
DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + df.rowsCount() shouldBe 3 + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection) + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1MariaDb::integerCol] > 100 }.rowsCount() shouldBe 2 + table1Df[0][11] shouldBe 10.0 + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table1MariaDb::integerCol] > 400 }.rowsCount() shouldBe 1 + table2Df[0][11] shouldBe 20.0 + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt new file mode 100644 index 0000000000..bb717fdc81 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt @@ -0,0 +1,332 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import org.junit.Ignore + +private const val URL = "jdbc:mysql://localhost:3306" +private const val USER_NAME = "root" +private const val PASSWORD = "pass" +private const val TEST_DATABASE_NAME = "testKDFdatabase" + +@DataSchema +interface Table1MySql { + val id: Int + val bitCol: Boolean + val tinyintCol: Int + val smallintCol: Int + val mediumintCol: Int + val mediumintUnsignedCol: Long + val integerCol: Int + val intCol: Int + val integerUnsignedCol: Long + val bigintCol: Long + val floatCol: Float + val doubleCol: Double + val decimalCol: Double + val dateCol: String + val datetimeCol: String + val timestampCol: String + val timeCol: String + val yearCol: String + val varcharCol: String + val charCol: String + val binaryCol: ByteArray + val varbinaryCol: ByteArray + val tinyblobCol: ByteArray + val blobCol: ByteArray + val mediumblobCol: ByteArray + val longblobCol: ByteArray + val textCol: String + val mediumtextCol: String + val longtextCol: String + val enumCol: String + val setCol: String +} + +@DataSchema +interface Table2MySql { + val id: Int + val enumCol: String + val setCol: String +} + +@Ignore +class MySqlTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL, USER_NAME, PASSWORD) + + connection.createStatement().use { st -> + // Drop the test database if it exists + val dropDatabaseQuery = "DROP DATABASE IF EXISTS $TEST_DATABASE_NAME" + st.executeUpdate(dropDatabaseQuery) + + // Create the test database + val createDatabaseQuery = "CREATE DATABASE $TEST_DATABASE_NAME" + st.executeUpdate(createDatabaseQuery) + + // Use the newly created database + val useDatabaseQuery = "USE $TEST_DATABASE_NAME" + st.executeUpdate(useDatabaseQuery) + } + + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS table1 ( + id INT AUTO_INCREMENT 
PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol LONGBLOB, + textCol TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3'), + setCol SET('Option1', 'Option2', 'Option3'), + location GEOMETRY, + data JSON + ) + """ + + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val createTableQuery2 = """ + CREATE TABLE IF NOT EXISTS table2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol LONGBLOB, + textCol TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3'), + setCol SET('Option1', 'Option2', 'Option3'), + location GEOMETRY, + data JSON + ) + """ + + connection.createStatement().execute( + createTableQuery2.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO table1 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol, location, data + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ST_GeomFromText('POINT(1 1)'), ?) + """.trimIndent() + + @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol, location, data + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ST_GeomFromText('POINT(1 1)'), ?) 
+ """.trimIndent() + + connection.prepareStatement(insertData1).use { st -> + // Insert data into table1 + for (i in 1..3) { + st.setBoolean(1, true) + st.setByte(2, i.toByte()) + st.setShort(3, (i * 10).toShort()) + st.setInt(4, i * 100) + st.setInt(5, i * 100) + st.setInt(6, i * 100) + st.setInt(7, i * 100) + st.setInt(8, i * 100) + st.setInt(9, i * 100) + st.setFloat(10, i * 10.0f) + st.setDouble(11, i * 10.0) + st.setBigDecimal(12, BigDecimal(i * 10)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, "textValue$i") + st.setString(27, "mediumtextValue$i") + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "Option$i") + st.setString(31, "{\"key\": \"value\"}") + st.executeUpdate() + } + } + + connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setBoolean(1, false) + st.setByte(2, (i * 2).toByte()) + st.setShort(3, (i * 20).toShort()) + st.setInt(4, i * 200) + st.setInt(5, i * 200) + st.setInt(6, i * 200) + st.setInt(7, i * 200) + st.setInt(8, i * 200) + st.setInt(9, i * 200) + st.setFloat(10, i * 20.0f) + st.setDouble(11, i * 20.0) + st.setBigDecimal(12, BigDecimal(i * 20)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, "textValue$i") + st.setString(27, "mediumtextValue$i") + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "Option$i") + st.setString(31, "{\"key\": \"value\"}") + st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + connection.createStatement().use { st -> st.execute("DROP DATABASE IF EXISTS $TEST_DATABASE_NAME") } + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + df1.rowsCount() shouldBe 3 + + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + df2.rowsCount() shouldBe 3 + + //TODO: add test for JSON column + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + t1.id, + t2.enumCol, + t2.setCol + FROM table1 t1 + JOIN table2 
t2 ON t1.id = t2.id; + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + df.rowsCount() shouldBe 3 + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection) + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1MariaDb::integerCol] > 100 }.rowsCount() shouldBe 2 + table1Df[0][11] shouldBe 10.0 + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table1MariaDb::integerCol] > 400 }.rowsCount() shouldBe 1 + table2Df[0][11] shouldBe 20.0 + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt new file mode 100644 index 0000000000..198e54283b --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt @@ -0,0 +1,286 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import org.postgresql.util.PGobject +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import java.util.UUID +import org.junit.Ignore + +private const val URL = "jdbc:postgresql://localhost:5432/test" +private const val USER_NAME = "postgres" +private const val PASSWORD = "pass" + +@DataSchema +interface Table1 { + val id: Int + val bigintcol: Long + val bigserialcol: Long + val booleancol: Boolean + val boxcol: String + val byteacol: ByteArray + val charactercol: String + val characterncol: String + val charcol: String + val circlecol: String + val datecol: java.sql.Date + val doublecol: Double + val integercol: Int + val intervalcol: String + val jsoncol: String + val jsonbcol: String +} + +@DataSchema +interface Table2 { + val id: Int + val linecol: String + val lsegcol: String + val macaddrcol: String + val moneycol: String + val numericcol: String + val pathcol: String + val pointcol: String + val polygoncol: String + val realcol: Float + val smallintcol: Short + val smallserialcol: Int + val serialcol: Int + val textcol: String + val timecol: String + val timewithzonecol: String + val timestampcol: String + val timestampwithzonecol: String + val uuidcol: String + val xmlcol: String +} + +@DataSchema +interface ViewTable { + val id: Int + val bigintcol: Long + val linecol: String + val numericcol: String +} + +@Ignore +class PostgresTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL, USER_NAME, PASSWORD) + + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + + @Language("SQL") + val createTableStatement = """ + CREATE TABLE IF NOT EXISTS table1 ( + id serial PRIMARY KEY, + bigintCol bigint, + bigserialCol bigserial, + booleanCol boolean, + boxCol box, + byteaCol bytea, + characterCol character, + characterNCol character(10), + charCol char, + circleCol circle, + dateCol date, + doubleCol double 
precision, + integerCol integer, + intervalCol interval, + jsonCol json, + jsonbCol jsonb + ) + """ + connection.createStatement().execute( + createTableStatement.trimIndent() + ) + + @Language("SQL") + val createTableQuery = """ + CREATE TABLE IF NOT EXISTS table2 ( + id serial PRIMARY KEY, + lineCol line, + lsegCol lseg, + macaddrCol macaddr, + moneyCol money, + numericCol numeric, + pathCol path, + pointCol point, + polygonCol polygon, + realCol real, + smallintCol smallint, + smallserialCol smallserial, + serialCol serial, + textCol text, + timeCol time, + timeWithZoneCol time with time zone, + timestampCol timestamp, + timestampWithZoneCol timestamp with time zone, + uuidCol uuid, + xmlCol xml + ) + """ + connection.createStatement().execute( + createTableQuery.trimIndent() + ) + + @Language("SQL") + val insertData1 = """ + INSERT INTO table1 ( + bigintCol, bigserialCol, booleanCol, + boxCol, byteaCol, characterCol, characterNCol, charCol, + circleCol, dateCol, doubleCol, + integerCol, intervalCol, jsonCol, jsonbCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + + @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + lineCol, lsegCol, macaddrCol, moneyCol, numericCol, + pathCol, pointCol, polygonCol, realCol, smallintCol, + smallserialCol, serialCol, textCol, timeCol, + timeWithZoneCol, timestampCol, timestampWithZoneCol, + uuidCol, xmlCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + + connection.prepareStatement(insertData1).use { st -> + // Insert data into table1 + for (i in 1..3) { + st.setLong(1, i * 1000L) + st.setLong(2, 1000000000L + i) + st.setBoolean(3, i % 2 == 1) + st.setObject(4, org.postgresql.geometric.PGbox("(1,1),(2,2)")) + st.setBytes(5, byteArrayOf(1, 2, 3)) + st.setString(6, "A") + st.setString(7, "Hello") + st.setString(8, "A") + st.setObject(9, org.postgresql.geometric.PGcircle("<(1,2),3>")) + st.setDate(10, java.sql.Date.valueOf("2023-08-01")) + st.setDouble(11, 12.34) + st.setInt(12, 12345 * i) + st.setObject(13, org.postgresql.util.PGInterval("1 year")) + + val jsonbObject = PGobject() + jsonbObject.type = "jsonb" + jsonbObject.value = "{\"key\": \"value\"}" + + st.setObject(14, jsonbObject) + st.setObject(15, jsonbObject) + st.executeUpdate() + } + } + + connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setObject(1, org.postgresql.geometric.PGline("{1,2,3}")) + st.setObject(2, org.postgresql.geometric.PGlseg("[(-1,0),(1,0)]")) + + val macaddrObject = PGobject() + macaddrObject.type = "macaddr" + macaddrObject.value = "00:00:00:00:00:0$i" + + st.setObject(3, macaddrObject) + st.setBigDecimal(4, BigDecimal("123.45")) + st.setBigDecimal(5, BigDecimal("12.34")) + st.setObject(6, org.postgresql.geometric.PGpath("((1,2),(3,$i))")) + st.setObject(7, org.postgresql.geometric.PGpoint("(1,2)")) + st.setObject(8, org.postgresql.geometric.PGpolygon("((1,1),(2,2),(3,3))")) + st.setFloat(9, 12.34f) + st.setShort(10, (i * 100).toShort()) + st.setInt(11, 1000 + i) + st.setInt(12, 1000000 + i) + st.setString(13, "Text data $i") + st.setTime(14, java.sql.Time.valueOf("12:34:56")) + + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(16, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(17, java.sql.Timestamp(System.currentTimeMillis())) + + st.setObject(18, UUID.randomUUID(), java.sql.Types.OTHER) + val xmlObject = PGobject() + xmlObject.type = "xml" + xmlObject.value = "data" + + st.setObject(19, xmlObject) 
+ st.executeUpdate() + } + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `read from tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + df1.rowsCount() shouldBe 3 + + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + df2.rowsCount() shouldBe 3 + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + t1.id AS t1_id, + t1.bigintCol, + t2.lineCol, + t2.numericCol + FROM table1 t1 + JOIN table2 t2 ON t1.id = t2.id; + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + df.rowsCount() shouldBe 3 + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection) + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1::integercol] > 12345 }.rowsCount() shouldBe 2 + table1Df[0][1] shouldBe 1000L + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table2::pathcol] == "((1,2),(3,1))" }.rowsCount() shouldBe 1 + table2Df[0][11] shouldBe 1001 + + //TODO: add test for JSON column + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt new file mode 100644 index 0000000000..afb397ca8f --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt @@ -0,0 +1,191 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Test +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import org.junit.Ignore + +private const val DATABASE_URL = "jdbc:sqlite:" + +@DataSchema +interface CustomerSQLite { + val id: Int + val name: String + val age: Int + val salary: Double + val profilePicture: ByteArray +} + +@DataSchema +interface OrderSQLite { + val id: Int + val customerName: String + val orderDate: String + val totalAmount: Double + val orderDetails: ByteArray +} + +@DataSchema +interface CustomerOrderSQLite { + val customerId: Int + val customerName: String + val customerAge: Int + val customerSalary: Double + val customerProfilePicture: ByteArray + val orderId: Int + val orderDate: String + val totalAmount: Double + val orderDetails: ByteArray +} + +@Ignore +class SqliteTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(DATABASE_URL) + + @Language("SQL") + val createCustomersTableQuery = """ + CREATE TABLE Customers ( + id INTEGER AUTO_INCREMENT PRIMARY KEY, + name TEXT, + age INTEGER, + salary REAL, + profilePicture BLOB + ) + """ + + connection.createStatement().execute( + createCustomersTableQuery + ) + + @Language("SQL") + val createOrderTableQuery = """ + CREATE TABLE Orders ( + id INTEGER 
AUTO_INCREMENT PRIMARY KEY, + customerName TEXT, + orderDate TEXT, + totalAmount NUMERIC, + orderDetails BLOB + ) + """ + + connection.createStatement().execute( + createOrderTableQuery + ) + + val profilePicture = "SampleProfilePictureData".toByteArray() + val orderDetails = "OrderDetailsData".toByteArray() + + connection.prepareStatement("INSERT INTO Customers (name, age, salary, profilePicture) VALUES (?, ?, ?, ?)") + .use { + it.setString(1, "John Doe") + it.setInt(2, 30) + it.setDouble(3, 2500.50) + it.setBytes(4, profilePicture) + it.executeUpdate() + } + + connection.prepareStatement("INSERT INTO Customers (name, age, salary, profilePicture) VALUES (?, ?, ?, ?)") + .use { + it.setString(1, "Max Joint") + it.setInt(2, 40) + it.setDouble(3, 1500.50) + it.setBytes(4, profilePicture) + it.executeUpdate() + } + + connection.prepareStatement("INSERT INTO Orders (customerName, orderDate, totalAmount, orderDetails) VALUES (?, ?, ?, ?)") + .use { + it.setString(1, "John Doe") + it.setString(2, "2023-07-21") + it.setDouble(3, 150.75) + it.setBytes(4, orderDetails) + it.executeUpdate() + } + + connection.prepareStatement("INSERT INTO Orders (customerName, orderDate, totalAmount, orderDetails) VALUES (?, ?, ?, ?)") + .use { + it.setString(1, "Max Joint") + it.setString(2, "2023-08-21") + it.setDouble(3, 250.75) + it.setBytes(4, orderDetails) + it.executeUpdate() + } + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + try { + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + } + } + } + + @Test + fun `read from tables`() { + val df = DataFrame.readSqlTable(connection, "Customers").cast() + df.rowsCount() shouldBe 2 + + val df2 = DataFrame.readSqlTable(connection, "Orders").cast() + df2.rowsCount() shouldBe 2 + } + + @Test + fun `read from sql query`() { + @Language("SQL") + val sqlQuery = """ + SELECT + c.id AS customerId, + c.name AS customerName, + c.age AS customerAge, + c.salary AS customerSalary, + c.profilePicture AS customerProfilePicture, + o.id AS orderId, + o.orderDate AS orderDate, + o.totalAmount AS totalAmount, + o.orderDetails AS orderDetails + FROM Customers c + INNER JOIN Orders o ON c.name = o.customerName + """ + + val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() + df.rowsCount() shouldBe 2 + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + schema.columns.entries.size shouldBe 9 + } + + @Test + fun `read from all tables`() { + val dataframes = DataFrame.readAllSqlTables(connection) + + val customerDf = dataframes[0].cast() + + customerDf.rowsCount() shouldBe 2 + customerDf.filter { it[CustomerSQLite::age] > 30 }.rowsCount() shouldBe 1 + customerDf[0][1] shouldBe "John Doe" + + val orderDf = dataframes[1].cast() + + orderDf.rowsCount() shouldBe 2 + orderDf.filter { it[OrderSQLite::totalAmount] > 200 }.rowsCount() shouldBe 1 + orderDf[0][1] shouldBe "John Doe" + } +} diff --git a/dataframe-jdbc/src/test/resources/simplelogger.properties b/dataframe-jdbc/src/test/resources/simplelogger.properties new file mode 100644 index 0000000000..473ff213b7 --- /dev/null +++ b/dataframe-jdbc/src/test/resources/simplelogger.properties @@ -0,0 +1,34 @@ +# SLF4J's SimpleLogger configuration file +# Simple implementation of Logger that sends all enabled log messages, for all defined loggers, to System.err. + +# Default logging detail level for all instances of SimpleLogger. +# Must be one of ("trace", "debug", "info", "warn", or "error"). +# If not specified, defaults to "info". 
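+# The level below is raised to "debug" for the dataframe-jdbc test runs so driver and
+# schema-inference details stay visible; set it back to "info" for quieter output.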
+org.slf4j.simpleLogger.defaultLogLevel=debug + +# Logging detail level for a SimpleLogger instance named "xxxxx". +# Must be one of ("trace", "debug", "info", "warn", or "error"). +# If not specified, the default logging detail level is used. +#org.slf4j.simpleLogger.log.xxxxx= + +# Set to true if you want the current date and time to be included in output messages. +# Default is false, and will output the number of milliseconds elapsed since startup. +org.slf4j.simpleLogger.showDateTime=true + +# The date and time format to be used in the output messages. +# The pattern describing the date and time format is the same that is used in java.text.SimpleDateFormat. +# If the format is not specified or is invalid, the default format is used. +# The default format is yyyy-MM-dd HH:mm:ss:SSS Z. +org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z + +# Set to true if you want to output the current thread name. +# Defaults to true. +org.slf4j.simpleLogger.showThreadName=true + +# Set to true if you want the Logger instance name to be included in output messages. +# Defaults to true. +org.slf4j.simpleLogger.showLogName=true + +# Set to true if you want the last component of the name to be included in output messages. +# Defaults to false. +#org.slf4j.simpleLogger.showShortLogName=false diff --git a/docs/StardustDocs/topics/gettingStartedGradleAdvanced.md b/docs/StardustDocs/topics/gettingStartedGradleAdvanced.md index ddf789272d..861112e0ac 100644 --- a/docs/StardustDocs/topics/gettingStartedGradleAdvanced.md +++ b/docs/StardustDocs/topics/gettingStartedGradleAdvanced.md @@ -125,6 +125,7 @@ dependencies { implementation("org.jetbrains.kotlinx:dataframe-core:%dataFrameVersion%") // Optional formats support implementation("org.jetbrains.kotlinx:dataframe-excel:%dataFrameVersion%") + implementation("org.jetbrains.kotlinx:dataframe-jdbc:%dataFrameVersion%") implementation("org.jetbrains.kotlinx:dataframe-arrow:%dataFrameVersion%") implementation("org.jetbrains.kotlinx:dataframe-openapi:%dataFrameVersion%") } @@ -140,6 +141,7 @@ dependencies { implementation 'org.jetbrains.kotlinx:dataframe-core:%dataFrameVersion%' // Optional formats support implementation 'org.jetbrains.kotlinx:dataframe-excel:%dataFrameVersion%' + implementation 'org.jetbrains.kotlinx:dataframe-jdbc:%dataFrameVersion%' implementation 'org.jetbrains.kotlinx:dataframe-arrow:%dataFrameVersion%' implementation 'org.jetbrains.kotlinx:dataframe-openapi:%dataFrameVersion%' } diff --git a/docs/StardustDocs/topics/read.md b/docs/StardustDocs/topics/read.md index d4999bcb4f..ea445900ba 100644 --- a/docs/StardustDocs/topics/read.md +++ b/docs/StardustDocs/topics/read.md @@ -3,7 +3,7 @@ The Kotlin DataFrame library supports CSV, TSV, JSON, XLS and XLSX, and Apache Arrow input formats. 
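
Database sources added by this change set are not auto-detected by `.read()`; they go through the dedicated JDBC readers instead. A minimal sketch, assuming an in-memory H2 database (the URL, table, and data are illustrative):

```kotlin
import java.sql.DriverManager
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readSqlTable

fun main() {
    // In-memory H2 database; any JDBC-reachable database works the same way.
    DriverManager.getConnection("jdbc:h2:mem:demo").use { connection ->
        connection.createStatement().execute("CREATE TABLE Customer (id INT, name VARCHAR(50))")
        connection.createStatement().execute("INSERT INTO Customer VALUES (1, 'John')")

        // Read the whole table into a DataFrame.
        val df = DataFrame.readSqlTable(connection, "Customer")
        println(df.rowsCount()) // 1
    }
}
```
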
-The `.read()` function automatically detects the input format based on file extension and content:
+The `.read()` function automatically detects the input format based on the file extension and content:
 
 ```kotlin
 DataFrame.read("input.csv")
 ```
@@ -129,7 +129,7 @@ val df = DataFrame.readCSV(
 
-2) Disable type inference for specific column and convert it yourself
+2) Disable type inference for a specific column and convert it yourself
 
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 8b6b20a85d..b16f3a7bcc 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -16,9 +16,16 @@ commonsCompress = "1.21"
 klaxon = "5.5"
 fuel = "2.3.1"
 poi = "5.2.2"
+mariadb = "3.1.4"
+h2db = "2.2.220"
+mysql = "8.0.33"
+postgresql = "42.6.0"
+sqlite = "3.42.0.1"
 kotlinDatetime = "0.4.0"
 kotlinpoet = "1.12.0"
 openapi = "2.1.13"
+kotlinLogging = "5.0.1"
+sl4j = "2.0.7"
 
 junit = "4.13.2"
 kotestAsserions = "4.6.3"
@@ -42,6 +49,12 @@ commonsCompress = { module = "org.apache.commons:commons-compress", version.ref
 klaxon = { module = "com.beust:klaxon", version.ref = "klaxon" }
 fuel = { module = "com.github.kittinunf.fuel:fuel", version.ref = "fuel" }
 poi = { module = "org.apache.poi:poi", version.ref = "poi" }
+mariadb = { group = "org.mariadb.jdbc", name = "mariadb-java-client", version.ref = "mariadb" }
+h2db = { group = "com.h2database", name = "h2", version.ref = "h2db" }
+mysql = { group = "mysql", name = "mysql-connector-java", version.ref = "mysql" }
+postgresql = { group = "org.postgresql", name = "postgresql", version.ref = "postgresql" }
+sqlite = { group = "org.xerial", name = "sqlite-jdbc", version.ref = "sqlite" }
+
 poi-ooxml = { module = "org.apache.poi:poi-ooxml", version.ref = "poi" }
 
 kotlin-datetimeJvm = { module = "org.jetbrains.kotlinx:kotlinx-datetime-jvm", version.ref = "kotlinDatetime" }
@@ -56,6 +69,9 @@ arrow-memory = { group = "org.apache.arrow", name = "arrow-memory-unsafe", versi
 kotlinpoet = { group = "com.squareup", name = "kotlinpoet", version.ref = "kotlinpoet" }
 swagger = { group = "io.swagger.parser.v3", name = "swagger-parser", version.ref = "openapi" }
 
+kotlinLogging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "kotlinLogging" }
+sl4j = { group = "org.slf4j", name = "slf4j-simple", version.ref = "sl4j" }
+
 [plugins]
 jupyter-api = { id = "org.jetbrains.kotlin.jupyter.api", version.ref = "kotlinJupyter" }
 
diff --git a/plugins/dataframe-gradle-plugin/build.gradle.kts b/plugins/dataframe-gradle-plugin/build.gradle.kts
index ba81c9e8e0..30bc3d1d30 100644
--- a/plugins/dataframe-gradle-plugin/build.gradle.kts
+++ b/plugins/dataframe-gradle-plugin/build.gradle.kts
@@ -21,6 +21,7 @@ dependencies {
     implementation(project(":dataframe-arrow"))
     implementation(project(":dataframe-openapi"))
     implementation(project(":dataframe-excel"))
+    implementation(project(":dataframe-jdbc"))
     implementation(kotlin("gradle-plugin-api"))
     implementation(kotlin("gradle-plugin"))
     implementation("com.beust:klaxon:5.5")
@@ -32,6 +33,7 @@ dependencies {
     testImplementation("com.android.tools.build:gradle-api:7.3.1")
     testImplementation("com.android.tools.build:gradle:7.3.1")
     testImplementation("io.ktor:ktor-server-netty:1.6.7")
+    testImplementation(libs.h2db)
     testImplementation(gradleApi())
 }
 
@@ -125,10 +127,12 @@ val integrationTestTask = task("integrationTest") {
     dependsOn(":plugins:symbol-processor:publishToMavenLocal")
     dependsOn(":dataframe-arrow:publishToMavenLocal")
     dependsOn(":dataframe-excel:publishToMavenLocal")
+
dependsOn(":dataframe-jdbc:publishToMavenLocal") dependsOn(":dataframe-openapi:publishToMavenLocal") dependsOn(":publishApiPublicationToMavenLocal") dependsOn(":dataframe-arrow:publishDataframeArrowPublicationToMavenLocal") dependsOn(":dataframe-excel:publishDataframeExcelPublicationToMavenLocal") + dependsOn(":dataframe-jdbc:publishDataframeJDBCPublicationToMavenLocal") dependsOn(":dataframe-openapi:publishDataframeOpenApiPublicationToMavenLocal") dependsOn(":plugins:symbol-processor:publishMavenPublicationToMavenLocal") dependsOn(":core:publishCorePublicationToMavenLocal") diff --git a/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt b/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt index 446cf1d0a5..d6c1e7a299 100644 --- a/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt +++ b/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt @@ -6,6 +6,8 @@ import org.gradle.testkit.runner.TaskOutcome import org.junit.Ignore import org.junit.Test import java.io.File +import java.sql.Connection +import java.sql.DriverManager class SchemaGeneratorPluginIntegrationTest : AbstractDataFramePluginIntegrationTest() { private companion object { @@ -372,6 +374,127 @@ class SchemaGeneratorPluginIntegrationTest : AbstractDataFramePluginIntegrationT result.task(":build")?.outcome shouldBe TaskOutcome.SUCCESS } + @Test + @Ignore + // TODO: test is broken + /* + e: file://test3901867314473689900/src/main/kotlin/Main.kt:12:43 Unresolved reference: readSqlTable + e: file://test3901867314473689900/src/main/kotlin/Main.kt:13:43 Unresolved reference: DatabaseConfiguration + e: file://test3901867314473689900/src/main/kotlin/Main.kt:19:28 Unresolved reference: readSqlTable + e: file://test3901867314473689900/src/main/kotlin/Main.kt:20:21 Unresolved reference: age + e: file://test3901867314473689900/src/main/kotlin/Main.kt:22:29 Unresolved reference: readSqlTable + e: file://test3901867314473689900/src/main/kotlin/Main.kt:23:22 Unresolved reference: age + e: file://test3901867314473689900/src/main/kotlin/Main.kt:25:24 Unresolved reference: DatabaseConfiguration + e: file://test3901867314473689900/src/main/kotlin/Main.kt:26:29 Unresolved reference: readSqlTable + e: file://test3901867314473689900/src/main/kotlin/Main.kt:27:22 Unresolved reference: age + e: file://test3901867314473689900/src/main/kotlin/Main.kt:29:29 Unresolved reference: readSqlTable + e: file://test3901867314473689900/src/main/kotlin/Main.kt:30:22 Unresolved reference: age + */ + fun `preprocessor imports schema from database`() { + val connectionUrl = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false" + DriverManager.getConnection(connectionUrl).use { + val (_, result) = runGradleBuild(":build") { buildDir -> + createTestDatabase(it) + + val kotlin = File(buildDir, "src/main/kotlin").also { it.mkdirs() } + val main = File(kotlin, "Main.kt") + // this is a copy of the code snippet in the + // DataFrameJdbcSymbolProcessorTest.`schema extracted via readFromDB method is resolved` + main.writeText(""" + @file:ImportDataSchema(name = "Customer", path = "$connectionUrl") + + package test + + import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema + import org.jetbrains.kotlinx.dataframe.api.filter + import 
org.jetbrains.kotlinx.dataframe.DataFrame
+                    import org.jetbrains.kotlinx.dataframe.api.cast
+                    import java.sql.Connection
+                    import java.sql.DriverManager
+                    import java.sql.SQLException
+                    import org.jetbrains.kotlinx.dataframe.io.readSqlTable
+                    import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration
+
+                    fun main() {
+                        Class.forName("org.h2.Driver")
+                        val tableName = "Customer"
+                        DriverManager.getConnection("$connectionUrl").use { connection ->
+                            val df = DataFrame.readSqlTable(connection, tableName).cast<Customer>()
+                            df.filter { age > 30 }
+
+                            val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast<Customer>()
+                            df1.filter { age > 30 }
+
+                            val dbConfig = DatabaseConfiguration(url = "$connectionUrl")
+                            val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast<Customer>()
+                            df2.filter { age > 30 }
+
+                            val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast<Customer>()
+                            df3.filter { age > 30 }
+
+                        }
+                    }
+                """.trimIndent())
+
+                """
+                    import org.jetbrains.dataframe.gradle.SchemaGeneratorExtension
+
+                    plugins {
+                        kotlin("jvm") version "$kotlinVersion"
+                        id("org.jetbrains.kotlinx.dataframe")
+                    }
+
+                    repositories {
+                        mavenLocal()
+                        mavenCentral()
+                    }
+
+                    dependencies {
+                        implementation(files("$dataframeJarPath"))
+                    }
+
+                    kotlin.sourceSets.getByName("main").kotlin.srcDir("build/generated/ksp/main/kotlin/")
+                """.trimIndent()
+            }
+            result.task(":build")?.outcome shouldBe TaskOutcome.SUCCESS
+        }
+    }
+
+    private fun createTestDatabase(connection: Connection) {
+        // Create table Customer
+        connection.createStatement().execute(
+            """
+                CREATE TABLE Customer (
+                    id INT PRIMARY KEY,
+                    name VARCHAR(50),
+                    age INT
+                )
+            """.trimIndent()
+        )
+
+        // Create table Sale
+        connection.createStatement().execute(
+            """
+                CREATE TABLE Sale (
+                    id INT PRIMARY KEY,
+                    customerId INT,
+                    amount DECIMAL(10, 2)
+                )
+            """.trimIndent()
+        )
+
+        // add data to the Customer table
+        connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (1, 'John', 40)")
+        connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (2, 'Alice', 25)")
+        connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (3, 'Bob', 47)")
+
+        // add data to the Sale table
+        connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (1, 1, 100.50)")
+        connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (2, 2, 50.00)")
+        connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (3, 1, 75.25)")
+        connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (4, 3, 35.15)")
+    }
+
     @Test
     fun `generated code compiles in explicit api mode`() {
         val (_, result) = runGradleBuild(":build") { buildDir ->
diff --git a/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/GenerateDataSchemaTask.kt b/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/GenerateDataSchemaTask.kt
index fb1a1a9c05..b88c37b05b 100644
--- a/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/GenerateDataSchemaTask.kt
+++ b/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/GenerateDataSchemaTask.kt
@@ -8,6 +8,7 @@ import org.gradle.api.tasks.Input
 import org.gradle.api.tasks.OutputFile
 import org.gradle.api.tasks.TaskAction
 import org.jetbrains.dataframe.impl.codeGen.CodeGenerator
+import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.codeGen.MarkerVisibility
 import
org.jetbrains.kotlinx.dataframe.codeGen.NameNormalizer import org.jetbrains.kotlinx.dataframe.impl.codeGen.CodeGenerationReadResult @@ -16,16 +17,19 @@ import org.jetbrains.kotlinx.dataframe.impl.codeGen.from import org.jetbrains.kotlinx.dataframe.impl.codeGen.toStandaloneSnippet import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlDfReader +import java.io.File +import java.net.URL +import java.nio.file.Paths +import java.sql.DriverManager import org.jetbrains.kotlinx.dataframe.io.ArrowFeather import org.jetbrains.kotlinx.dataframe.io.CSV import org.jetbrains.kotlinx.dataframe.io.Excel import org.jetbrains.kotlinx.dataframe.io.JSON import org.jetbrains.kotlinx.dataframe.io.OpenApi import org.jetbrains.kotlinx.dataframe.io.TSV +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery +import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable import org.jetbrains.kotlinx.dataframe.io.isURL -import java.io.File -import java.net.URL -import java.nio.file.Paths abstract class GenerateDataSchemaTask : DefaultTask() { @@ -38,6 +42,9 @@ abstract class GenerateDataSchemaTask : DefaultTask() { @get:Input abstract val jsonOptions: Property + @get:Input + abstract val jdbcOptions: Property + @get:Input abstract val src: Property @@ -67,66 +74,104 @@ abstract class GenerateDataSchemaTask : DefaultTask() { fun generate() { val csvOptions = csvOptions.get() val jsonOptions = jsonOptions.get() - val url = urlOf(data.get()) + val jdbcOptions = jdbcOptions.get() val schemaFile = dataSchema.get() val escapedPackageName = escapePackageName(packageName.get()) - val formats = listOf( - CSV(delimiter = csvOptions.delimiter), - JSON(typeClashTactic = jsonOptions.typeClashTactic, keyValuePaths = jsonOptions.keyValuePaths), - Excel(), - TSV(), - ArrowFeather(), - OpenApi(), - ) - - // first try without creating dataframe - when (val codeGenResult = CodeGenerator.urlCodeGenReader(url, interfaceName.get(), formats, false)) { - is CodeGenerationReadResult.Success -> { - val readDfMethod = codeGenResult.getReadDfMethod(stringOf(data.get())) - val code = codeGenResult - .code - .toStandaloneSnippet(escapedPackageName, readDfMethod.additionalImports) - - schemaFile.bufferedWriter().use { - it.write(code) - } + val rawUrl = data.get().toString() + + // revisit architecture for an addition of the new data source https://github.com/Kotlin/dataframe/issues/450 + if (rawUrl.startsWith("jdbc")) { + val connection = DriverManager.getConnection(rawUrl, jdbcOptions.user, jdbcOptions.password) + connection.use { + val schema = if(jdbcOptions.sqlQuery.isBlank()) + DataFrame.getSchemaForSqlTable(connection, interfaceName.get()) + else DataFrame.getSchemaForSqlQuery(connection, jdbcOptions.sqlQuery) + + val codeGenerator = CodeGenerator.create(useFqNames = false) + + val additionalImports: List = listOf() + + val delimiters = delimiters.get() + val codeGenResult = codeGenerator.generate( + schema = schema, + name = interfaceName.get(), + fields = true, + extensionProperties = false, + isOpen = true, + visibility = when (schemaVisibility.get()) { + DataSchemaVisibility.INTERNAL -> MarkerVisibility.INTERNAL + DataSchemaVisibility.IMPLICIT_PUBLIC -> MarkerVisibility.IMPLICIT_PUBLIC + DataSchemaVisibility.EXPLICIT_PUBLIC -> MarkerVisibility.EXPLICIT_PUBLIC + else -> MarkerVisibility.IMPLICIT_PUBLIC + }, + readDfMethod = null, + fieldNameNormalizer = NameNormalizer.from(delimiters), + ) + + 
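+                // The JDBC branch passes readDfMethod = null above, so only the DataSchema
+                // interface is written out; no default read method is generated for it yet.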
schemaFile.writeText(codeGenResult.toStandaloneSnippet(escapedPackageName, additionalImports)) return } + } else { + val url = urlOf(data.get()) + + val formats = listOf( + CSV(delimiter = csvOptions.delimiter), + JSON(typeClashTactic = jsonOptions.typeClashTactic, keyValuePaths = jsonOptions.keyValuePaths), + Excel(), + TSV(), + ArrowFeather(), + OpenApi(), + ) + + // first try without creating dataframe + when (val codeGenResult = CodeGenerator.urlCodeGenReader(url, interfaceName.get(), formats, false)) { + is CodeGenerationReadResult.Success -> { + val readDfMethod = codeGenResult.getReadDfMethod(stringOf(data.get())) + val code = codeGenResult + .code + .toStandaloneSnippet(escapedPackageName, readDfMethod.additionalImports) + + schemaFile.bufferedWriter().use { + it.write(code) + } + return + } + + is CodeGenerationReadResult.Error -> + logger.warn("Error while reading types-only from data at $url: ${codeGenResult.reason}") + } - is CodeGenerationReadResult.Error -> - logger.warn("Error while reading types-only from data at $url: ${codeGenResult.reason}") - } + // on error, try with reading dataframe first + val parsedDf = when (val readResult = CodeGenerator.urlDfReader(url, formats)) { + is DfReadResult.Error -> throw Exception( + "Error while reading dataframe from data at $url", + readResult.reason + ) - // on error, try with reading dataframe first - val parsedDf = when (val readResult = CodeGenerator.urlDfReader(url, formats)) { - is DfReadResult.Error -> throw Exception( - "Error while reading dataframe from data at $url", - readResult.reason - ) + is DfReadResult.Success -> readResult + } - is DfReadResult.Success -> readResult + val codeGenerator = CodeGenerator.create(useFqNames = false) + val delimiters = delimiters.get() + val readDfMethod = parsedDf.getReadDfMethod(stringOf(data.get())) + val codeGenResult = codeGenerator.generate( + schema = parsedDf.schema, + name = interfaceName.get(), + fields = true, + extensionProperties = false, + isOpen = true, + visibility = when (schemaVisibility.get()) { + DataSchemaVisibility.INTERNAL -> MarkerVisibility.INTERNAL + DataSchemaVisibility.IMPLICIT_PUBLIC -> MarkerVisibility.IMPLICIT_PUBLIC + DataSchemaVisibility.EXPLICIT_PUBLIC -> MarkerVisibility.EXPLICIT_PUBLIC + else -> MarkerVisibility.IMPLICIT_PUBLIC + }, + readDfMethod = readDfMethod, + fieldNameNormalizer = NameNormalizer.from(delimiters), + ) + schemaFile.writeText(codeGenResult.toStandaloneSnippet(escapedPackageName, readDfMethod.additionalImports)) } - - val codeGenerator = CodeGenerator.create(useFqNames = false) - val delimiters = delimiters.get() - val readDfMethod = parsedDf.getReadDfMethod(stringOf(data.get())) - val codeGenResult = codeGenerator.generate( - schema = parsedDf.schema, - name = interfaceName.get(), - fields = true, - extensionProperties = false, - isOpen = true, - visibility = when (schemaVisibility.get()) { - DataSchemaVisibility.INTERNAL -> MarkerVisibility.INTERNAL - DataSchemaVisibility.IMPLICIT_PUBLIC -> MarkerVisibility.IMPLICIT_PUBLIC - DataSchemaVisibility.EXPLICIT_PUBLIC -> MarkerVisibility.EXPLICIT_PUBLIC - else -> MarkerVisibility.IMPLICIT_PUBLIC - }, - readDfMethod = readDfMethod, - fieldNameNormalizer = NameNormalizer.from(delimiters), - ) - schemaFile.writeText(codeGenResult.toStandaloneSnippet(escapedPackageName, readDfMethod.additionalImports)) } private fun stringOf(data: Any): String = diff --git a/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorExtension.kt 
b/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorExtension.kt index 8a42793150..6911c87ace 100644 --- a/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorExtension.kt +++ b/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorExtension.kt @@ -61,6 +61,7 @@ class Schema( internal var withNormalizationBy: Set? = null, val csvOptions: CsvOptionsDsl = CsvOptionsDsl(), val jsonOptions: JsonOptionsDsl = JsonOptionsDsl(), + val jdbcOptions: JdbcOptionsDsl = JdbcOptionsDsl(), ) { fun setData(file: File) { data = file @@ -90,6 +91,14 @@ class Schema( project.configure(jsonOptions, config) } + fun jdbcOptions(config: JdbcOptionsDsl.() -> Unit) { + jdbcOptions.apply(config) + } + + fun jdbcOptions(config: Closure<*>) { + project.configure(jdbcOptions, config) + } + fun withoutDefaultPath() { defaultPath = false } @@ -122,3 +131,9 @@ data class JsonOptionsDsl( var typeClashTactic: JSON.TypeClashTactic = JSON.TypeClashTactic.ARRAY_AND_VALUE_COLUMNS, var keyValuePaths: List = emptyList(), ) : Serializable + +data class JdbcOptionsDsl( + var user: String = "", // TODO: I'm not sure about the default parameters + var password: String = "", // TODO: I'm not sure about the default parameters + var sqlQuery: String = "" +) : Serializable diff --git a/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPlugin.kt b/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPlugin.kt index 217ed4e3a2..4aa4760cc6 100644 --- a/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPlugin.kt +++ b/plugins/dataframe-gradle-plugin/src/main/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPlugin.kt @@ -114,6 +114,7 @@ class SchemaGeneratorPlugin : Plugin { this.schemaVisibility.set(visibility) this.csvOptions.set(schema.csvOptions) this.jsonOptions.set(schema.jsonOptions) + this.jdbcOptions.set(schema.jdbcOptions) // TODO: probably remove this.defaultPath.set(defaultPath) this.delimiters.set(delimiters) } diff --git a/plugins/dataframe-gradle-plugin/src/test/kotlin/org/jetbrains/dataframe/gradle/DataFrameReadTest.kt b/plugins/dataframe-gradle-plugin/src/test/kotlin/org/jetbrains/dataframe/gradle/DataFrameReadTest.kt index dd7b5d91d0..a07602c592 100644 --- a/plugins/dataframe-gradle-plugin/src/test/kotlin/org/jetbrains/dataframe/gradle/DataFrameReadTest.kt +++ b/plugins/dataframe-gradle-plugin/src/test/kotlin/org/jetbrains/dataframe/gradle/DataFrameReadTest.kt @@ -8,6 +8,7 @@ import io.kotest.assertions.throwables.shouldThrowAny import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.io.read +import org.jetbrains.kotlinx.dataframe.io.readSqlTable import org.junit.Test import java.io.File import java.io.FileNotFoundException @@ -15,10 +16,10 @@ import java.io.IOException import java.net.URL import java.nio.file.Files import java.nio.file.Paths +import java.sql.DriverManager import kotlin.io.path.absolutePathString class DataFrameReadTest { - @Test fun `file that does not exists`() { val temp = Files.createTempDirectory("").toFile() @@ -94,4 +95,49 @@ class DataFrameReadTest { val df = DataFrame.read(temp) df.columnNames() shouldBe listOf("name", "age") } + + @Test + fun `jdbcSample is valid jdbc`() { + DriverManager.getConnection("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false") + 
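+            // H2 URL notes: DB_CLOSE_DELAY=-1 keeps the in-memory database alive until the JVM
+            // exits, and MODE=MySQL enables MySQL compatibility syntax without a running server.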
.use { connection -> + // Create table Customer + connection.createStatement().execute( + """ + CREATE TABLE Customer ( + id INT PRIMARY KEY, + name VARCHAR(50), + age INT + ) + """.trimIndent() + ) + + // Create table Sale + connection.createStatement().execute( + """ + CREATE TABLE Sale ( + id INT PRIMARY KEY, + customerId INT, + amount DECIMAL(10, 2) + ) + """.trimIndent() + ) + + // add data to the Customer table + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (1, 'John', 40)") + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (2, 'Alice', 25)") + connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (3, 'Bob', 47)") + + // add data to the Sale table + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (1, 1, 100.50)") + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (2, 2, 50.00)") + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (3, 1, 75.25)") + connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (4, 3, 35.15)") + + val dfCustomer = DataFrame.readSqlTable(connection, "Customer") + dfCustomer.columnNames() shouldBe listOf("id", "name", "age") + + val dfSale = DataFrame.readSqlTable(connection, "Sale") + dfSale.columnNames() shouldBe listOf("id", "customerId", "amount") + } + } } diff --git a/plugins/symbol-processor/build.gradle b/plugins/symbol-processor/build.gradle index 78e2a73e7a..594292bb6d 100644 --- a/plugins/symbol-processor/build.gradle +++ b/plugins/symbol-processor/build.gradle @@ -17,8 +17,11 @@ dependencies { implementation(project(":dataframe-arrow")) implementation(project(":dataframe-openapi")) implementation(project(":dataframe-excel")) + implementation(project(":dataframe-jdbc")) implementation(libs.ksp.api) implementation(libs.kotlin.reflect) + implementation(libs.h2db) + testImplementation(libs.h2db) testImplementation("org.jetbrains.kotlin:kotlin-test") testImplementation("com.github.tschuchortdev:kotlin-compile-testing:1.5.0") testImplementation("com.github.tschuchortdev:kotlin-compile-testing-ksp:1.5.0") diff --git a/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt b/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt index b98dc4ec91..0f4a5680ea 100644 --- a/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt +++ b/plugins/symbol-processor/src/main/kotlin/org/jetbrains/dataframe/ksp/DataSchemaGenerator.kt @@ -7,9 +7,11 @@ import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.processing.Resolver import com.google.devtools.ksp.symbol.KSFile import org.jetbrains.dataframe.impl.codeGen.CodeGenerator +import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.CsvOptions import org.jetbrains.kotlinx.dataframe.annotations.DataSchemaVisibility import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema +import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions import org.jetbrains.kotlinx.dataframe.annotations.JsonOptions import org.jetbrains.kotlinx.dataframe.api.JsonPath import org.jetbrains.kotlinx.dataframe.codeGen.MarkerVisibility @@ -20,16 +22,11 @@ import org.jetbrains.kotlinx.dataframe.impl.codeGen.from import org.jetbrains.kotlinx.dataframe.impl.codeGen.toStandaloneSnippet import 
org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader
 import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlDfReader
-import org.jetbrains.kotlinx.dataframe.io.ArrowFeather
-import org.jetbrains.kotlinx.dataframe.io.CSV
-import org.jetbrains.kotlinx.dataframe.io.Excel
-import org.jetbrains.kotlinx.dataframe.io.JSON
-import org.jetbrains.kotlinx.dataframe.io.OpenApi
-import org.jetbrains.kotlinx.dataframe.io.TSV
-import org.jetbrains.kotlinx.dataframe.io.isURL
+import org.jetbrains.kotlinx.dataframe.io.*
 import java.io.File
 import java.net.MalformedURLException
 import java.net.URL
+import java.sql.DriverManager
 
 @OptIn(KspExperimental::class)
 class DataSchemaGenerator(
@@ -52,6 +49,8 @@ class DataSchemaGenerator(
         val withDefaultPath: Boolean,
         val csvOptions: CsvOptions,
         val jsonOptions: JsonOptions,
+        val jdbcOptions: JdbcOptions,
+        val isJdbc: Boolean = false,
     )
 
     class CodeGeneratorDataSource(val pathRepresentation: String, val data: URL)
@@ -72,6 +71,23 @@ class DataSchemaGenerator(
                 return null
             }
         } else {
+            // revisit architecture for an addition of the new data source https://github.com/Kotlin/dataframe/issues/450
+            if (path.startsWith("jdbc")) {
+                return ImportDataSchemaStatement(
+                    origin = file,
+                    name = name,
+                    // TODO: make the URL nullable or introduce a data-source hierarchy here
+                    dataSource = CodeGeneratorDataSource(this.path, URL("http://example.com/pages/")),
+                    visibility = visibility.toMarkerVisibility(),
+                    normalizationDelimiters = normalizationDelimiters.toList(),
+                    withDefaultPath = withDefaultPath,
+                    csvOptions = csvOptions,
+                    jsonOptions = jsonOptions,
+                    jdbcOptions = jdbcOptions,
+                    isJdbc = true
+                )
+            }
+
             val resolutionDir: String = resolutionDir ?: run {
                 reportMissingKspArgument(file)
                 return null
@@ -100,6 +116,7 @@ class DataSchemaGenerator(
             withDefaultPath = withDefaultPath,
             csvOptions = csvOptions,
             jsonOptions = jsonOptions,
+            jdbcOptions = jdbcOptions,
         )
     }
 
@@ -138,9 +155,54 @@ class DataSchemaGenerator(
             OpenApi(),
         )
 
-        // first try without creating dataframe
-        when (val codeGenResult =
-            CodeGenerator.urlCodeGenReader(importStatement.dataSource.data, name, formats, false)) {
+        // revisit architecture for an addition of the new data source https://github.com/Kotlin/dataframe/issues/450
+        if (importStatement.isJdbc) {
+            val url = importStatement.dataSource.pathRepresentation
+
+            if (url.contains("h2")) Class.forName("org.h2.Driver")
+
+            val connection = DriverManager.getConnection(
+                url,
+                importStatement.jdbcOptions.user,
+                importStatement.jdbcOptions.password
+            )
+
+            connection.use {
+                val schema = if (importStatement.jdbcOptions.sqlQuery.isBlank())
+                    DataFrame.getSchemaForSqlTable(connection, importStatement.name)
+                else DataFrame.getSchemaForSqlQuery(connection, importStatement.jdbcOptions.sqlQuery)
+
+                val codeGenerator = CodeGenerator.create(useFqNames = false)
+
+                val additionalImports: List<String> = listOf()
+
+                val codeGenResult = codeGenerator.generate(
+                    schema = schema,
+                    name = name,
+                    fields = true,
+                    extensionProperties = false,
+                    isOpen = true,
+                    visibility = importStatement.visibility,
+                    knownMarkers = emptyList(),
+                    readDfMethod = null,
+                    fieldNameNormalizer = NameNormalizer.from(importStatement.normalizationDelimiters.toSet())
+                )
+                val code = codeGenResult.toStandaloneSnippet(packageName, additionalImports)
+                schemaFile.bufferedWriter().use {
+                    it.write(code)
+                }
+                return
+            }
+        }
+
+        // revisit architecture for an addition of the new data source https://github.com/Kotlin/dataframe/issues/450
+        // works for JDBC and OpenAPI only
+        // first try without creating a dataframe
+        when (val codeGenResult = if (importStatement.isJdbc) {
+            CodeGenerator.databaseCodeGenReader(importStatement.dataSource.data, name)
+        } else {
+            CodeGenerator.urlCodeGenReader(importStatement.dataSource.data, name, formats, false)
+        }) {
             is CodeGenerationReadResult.Success -> {
                 val readDfMethod = codeGenResult.getReadDfMethod(
                     pathRepresentation = importStatement
@@ -164,6 +226,7 @@ class DataSchemaGenerator(
             }
         }
 
+        // Usually works for others
         // on error, try with reading dataframe first
         val parsedDf = when (val readResult = CodeGenerator.urlDfReader(importStatement.dataSource.data, formats)) {
             is DfReadResult.Error -> {
@@ -195,3 +258,4 @@ class DataSchemaGenerator(
         }
     }
 }
+
diff --git a/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt b/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt
new file mode 100644
index 0000000000..acfe192c53
--- /dev/null
+++ b/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt
@@ -0,0 +1,211 @@
+package org.jetbrains.dataframe.ksp
+
+import com.tschuchort.compiletesting.SourceFile
+import io.kotest.assertions.asClue
+import io.kotest.inspectors.forAtLeastOne
+import io.kotest.matchers.shouldBe
+import io.kotest.matchers.string.shouldContain
+import org.jetbrains.dataframe.ksp.runner.KotlinCompileTestingCompilationResult
+import org.jetbrains.dataframe.ksp.runner.KspCompilationTestRunner
+import org.jetbrains.dataframe.ksp.runner.TestCompilationParameters
+import org.junit.AfterClass
+import org.junit.Before
+import org.junit.BeforeClass
+import java.sql.Connection
+import java.sql.DriverManager
+import java.sql.SQLException
+import kotlin.test.Test
+
+const val CONNECTION_URL = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false"
+
+@Suppress("unused")
+class DataFrameJdbcSymbolProcessorTest {
+    companion object {
+        private lateinit var connection: Connection
+
+        val imports = """
+            import org.jetbrains.kotlinx.dataframe.annotations.*
+            import org.jetbrains.kotlinx.dataframe.columns.*
+            import org.jetbrains.kotlinx.dataframe.*
+        """.trimIndent()
+
+        const val generatedFile = "HelloJdbc${'$'}Extensions.kt"
+
+        @JvmStatic
+        @BeforeClass
+        fun setupDB() {
+            connection = DriverManager.getConnection(CONNECTION_URL)
+            createTestDatabase(connection)
+        }
+
+        @JvmStatic
+        @AfterClass
+        fun close(): Unit {
+            try {
+                connection.close()
+            } catch (e: SQLException) {
+                e.printStackTrace()
+            }
+        }
+
+        private fun createTestDatabase(connection: Connection) {
+            // Create table Customer
+            connection.createStatement().execute(
+                """
+                CREATE TABLE Customer (
+                    id INT PRIMARY KEY,
+                    name VARCHAR(50),
+                    age INT
+                )
+                """.trimIndent()
+            )
+
+            // Create table Sale
+            connection.createStatement().execute(
+                """
+                CREATE TABLE Sale (
+                    id INT PRIMARY KEY,
+                    customerId INT,
+                    amount DECIMAL(10, 2)
+                )
+                """.trimIndent()
+            )
+
+            // add data to the Customer table
+            connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (1, 'John', 40)")
+            connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (2, 'Alice', 25)")
+            connection.createStatement().execute("INSERT INTO Customer (id, name, age) VALUES (3, 'Bob', 47)")
+
+            // add data to the Sale table
+            connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (1, 1, 100.50)")
+            connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (2, 2, 50.00)")
+            connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (3, 1, 75.25)")
+            connection.createStatement().execute("INSERT INTO Sale (id, customerId, amount) VALUES (4, 3, 35.15)")
+        }
+    }
+
+    @Before
+    fun setup() {
+        KspCompilationTestRunner.compilationDir.deleteRecursively()
+    }
+
+    @Test
+    fun `failed compilation on wrong path`() {
+        val result = KspCompilationTestRunner.compile(
+            TestCompilationParameters(
+                sources = listOf(
+                    SourceFile.kotlin(
+                        "MySources.kt",
+                        """
+                        @file:ImportDataSchema(name = "Customer", path = "123")
+
+                        package test
+
+                        import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
+                        import org.jetbrains.kotlinx.dataframe.api.filter
+                        import org.jetbrains.kotlinx.dataframe.DataFrame
+                        import org.jetbrains.kotlinx.dataframe.api.cast
+                        import java.sql.Connection
+                        import java.sql.DriverManager
+                        import java.sql.SQLException
+                        import org.jetbrains.kotlinx.dataframe.io.readSqlTable
+                        import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration
+                        """.trimIndent()
+                    )
+                )
+            )
+        )
+        result.successfulCompilation shouldBe false
+    }
+
+    @Test
+    fun `schema is imported`() {
+        val result = KspCompilationTestRunner.compile(
+            TestCompilationParameters(
+                sources = listOf(
+                    SourceFile.kotlin(
+                        "MySources.kt",
+                        """
+                        @file:ImportDataSchema(name = "Customer", path = "$CONNECTION_URL")
+
+                        package test
+
+                        import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
+                        import org.jetbrains.kotlinx.dataframe.api.filter
+                        import org.jetbrains.kotlinx.dataframe.DataFrame
+                        import org.jetbrains.kotlinx.dataframe.api.cast
+                        import java.sql.Connection
+                        import java.sql.DriverManager
+                        import java.sql.SQLException
+                        import org.jetbrains.kotlinx.dataframe.io.readSqlTable
+                        import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration
+                        """.trimIndent()
+                    )
+                )
+            )
+        )
+        println(result.kspGeneratedFiles)
+        result.inspectLines("Customer.Generated.kt") {
+            it.forAtLeastOne { it shouldContain "val name: String" }
+        }
+    }
+
+    /**
+     * Test code is copied from h2Test `read from table` test.
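+     * When jdbcOptions.sqlQuery is blank, the generator uses the annotation's `name`
+     * ("Customer") as the table to introspect, so it must match a table created in [createTestDatabase].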
+ */ + @Test + fun `schema extracted via readFromDB method is resolved`() { + val result = KspCompilationTestRunner.compile( + TestCompilationParameters( + sources = listOf( + SourceFile.kotlin( + "MySources.kt", + """ + @file:ImportDataSchema(name = "Customer", path = "$CONNECTION_URL") + + package test + + import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema + import org.jetbrains.kotlinx.dataframe.api.filter + import org.jetbrains.kotlinx.dataframe.DataFrame + import org.jetbrains.kotlinx.dataframe.api.cast + import java.sql.Connection + import java.sql.DriverManager + import java.sql.SQLException + import org.jetbrains.kotlinx.dataframe.io.readSqlTable + import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration + + fun main() { + val tableName = "Customer" + DriverManager.getConnection("$CONNECTION_URL").use { connection -> + val df = DataFrame.readSqlTable(connection, tableName).cast() + df.filter { age > 30 } + + val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() + df1.filter { age > 30 } + + val dbConfig = DatabaseConfiguration(url = "$CONNECTION_URL") + val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() + df2.filter { age > 30 } + + val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() + df3.filter { age > 30 } + + } + } + """.trimIndent() + ) + ) + ) + ) + result.successfulCompilation shouldBe true + } + + private fun KotlinCompileTestingCompilationResult.inspectLines(f: (List) -> Unit) { + inspectLines(generatedFile, f) + } + + private fun KotlinCompileTestingCompilationResult.inspectLines(filename: String, f: (List) -> Unit) { + kspGeneratedFiles.single { it.name == filename }.readLines().asClue(f) + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index e1c2250aed..07f677cd2a 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -10,6 +10,7 @@ include("tests") include("dataframe-arrow") include("dataframe-openapi") include("dataframe-excel") +include("dataframe-jdbc") include("core") include("examples:idea-examples:titanic") diff --git a/tests/build.gradle.kts b/tests/build.gradle.kts index 37432f8896..d1227be119 100644 --- a/tests/build.gradle.kts +++ b/tests/build.gradle.kts @@ -18,6 +18,7 @@ repositories { dependencies { implementation(project(":core")) implementation(project(":dataframe-excel")) + implementation(project(":dataframe-jdbc")) implementation(project(":dataframe-arrow")) testImplementation(libs.junit) testImplementation(libs.kotestAssertions) {
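
Taken together, the tests in this change set exercise a small public surface in `org.jetbrains.kotlinx.dataframe.io`. A condensed sketch of the main entry points, assuming the same kind of in-memory H2 setup used above (table contents and printed counts are illustrative):

```kotlin
import java.sql.DriverManager
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery
import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables
import org.jetbrains.kotlinx.dataframe.io.readSqlQuery
import org.jetbrains.kotlinx.dataframe.io.readSqlTable

fun main() {
    DriverManager.getConnection("jdbc:h2:mem:demo;DB_CLOSE_DELAY=-1").use { connection ->
        connection.createStatement().execute(
            "CREATE TABLE Customer (id INT PRIMARY KEY, name VARCHAR(50), age INT)"
        )
        connection.createStatement().execute("INSERT INTO Customer VALUES (1, 'John', 40)")

        // One table by name.
        val customers = DataFrame.readSqlTable(connection, "Customer")

        // An arbitrary SQL query.
        val adults = DataFrame.readSqlQuery(connection, "SELECT name FROM Customer WHERE age > 30")

        // Every table in the database at once.
        val allTables = DataFrame.readAllSqlTables(connection)

        // Schema only, without materializing any rows.
        val schema = DataFrame.getSchemaForSqlQuery(connection, "SELECT * FROM Customer")

        println(customers.rowsCount()) // 1
        println(adults.rowsCount())    // 1
        println(allTables.size)        // 1
        println(schema.columns.size)   // 3
    }
}
```

The tests also show a `DatabaseConfiguration(url = ...)` overload of `readSqlTable` for when no open `Connection` is at hand, plus an optional row-limit parameter, as in `readSqlTable(connection, tableName, 1)`.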