diff --git a/README.md b/README.md index c528ccc759..8de27b8bb2 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ This table shows the mapping between main library component versions and minimum | 0.11.0 | 8 | 1.8.20 | 0.11.0-358 | 3.0.0 | 11.0.0 | | 0.11.1 | 8 | 1.8.20 | 0.11.0-358 | 3.0.0 | 11.0.0 | | 0.12.0 | 8 | 1.9.0 | 0.11.0-358 | 3.0.0 | 11.0.0 | +| 0.12.1 | 8 | 1.9.0 | 0.11.0-358 | 3.0.0 | 11.0.0 | ## Usage example diff --git a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/rename.kt b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/rename.kt index 2e7e8ece2c..b6130abf36 100644 --- a/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/rename.kt +++ b/core/generated-sources/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/rename.kt @@ -22,14 +22,14 @@ internal fun RenameClause.renameImpl(newNames: Array): internal fun RenameClause.renameImpl(transform: (ColumnWithPath) -> String): DataFrame { // get all selected columns and their paths val selectedColumnsWithPath = df.getColumnsWithPaths(columns) - .associateBy { it.data } + .associateBy { it.path } // gather a tree of all columns where the nodes will be renamed val tree = df.getColumnsWithPaths { all().rec() }.collectTree() // perform rename in nodes - tree.allChildrenNotNull().forEach { node -> + tree.allChildrenNotNull().map { it to it.pathFromRoot() }.forEach { (node, originalPath) -> // Check if the current node/column is a selected column and, if so, get its ColumnWithPath - val column = selectedColumnsWithPath[node.data] ?: return@forEach + val column = selectedColumnsWithPath[originalPath] ?: return@forEach // Use the found selected ColumnWithPath to query for the new name val newColumnName = transform(column) node.name = newColumnName diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index dafcc375ce..c27351c42c 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import kotlin.reflect.KType /** * The `DbType` class represents a database type used for reading dataframe from the database. @@ -22,19 +23,10 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { */ public abstract val driverClassName: String - /** - * Converts the data from the given [ResultSet] into the specified [TableColumnMetadata] type. - * - * @param rs The [ResultSet] containing the data to be converted. - * @param tableColumnMetadata The [TableColumnMetadata] representing the target type of the conversion. - * @return The converted data as an instance of [Any]. - */ - public abstract fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? - /** * Returns a [ColumnSchema] produced from [tableColumnMetadata]. */ - public abstract fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema + public abstract fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? /** * Checks if the given table name is a system table for the specified database type. @@ -52,4 +44,12 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { * @return the TableMetadata object representing the table metadata. */ public abstract fun buildTableMetadata(tables: ResultSet): TableMetadata + + /** + * Converts SQL data type to a Kotlin data type. + * + * @param [tableColumnMetadata] The metadata of the table column. + * @return The corresponding Kotlin data type, or null if no mapping is found. + */ + public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt index 38a4970e59..cf967806f8 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt @@ -4,10 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import java.util.Locale -import org.jetbrains.kotlinx.dataframe.DataRow -import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.io.TableMetadata -import kotlin.reflect.typeOf +import kotlin.reflect.KType /** * Represents the H2 database type. @@ -21,71 +19,8 @@ public object H2 : DbType("h2") { override val driverClassName: String get() = "org.h2.Driver" - override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { - val name = tableColumnMetadata.name - return when (tableColumnMetadata.sqlTypeName) { - "CHARACTER", "CHAR" -> rs.getString(name) - "CHARACTER VARYING", "CHAR VARYING", "VARCHAR" -> rs.getString(name) - "CHARACTER LARGE OBJECT", "CHAR LARGE OBJECT", "CLOB" -> rs.getString(name) - "MEDIUMTEXT" -> rs.getString(name) - "VARCHAR_IGNORECASE" -> rs.getString(name) - "BINARY" -> rs.getBytes(name) - "BINARY VARYING", "VARBINARY" -> rs.getBytes(name) - "BINARY LARGE OBJECT", "BLOB" -> rs.getBytes(name) - "BOOLEAN" -> rs.getBoolean(name) - "TINYINT" -> rs.getByte(name) - "SMALLINT" -> rs.getShort(name) - "INTEGER", "INT" -> rs.getInt(name) - "BIGINT" -> rs.getLong(name) - "NUMERIC", "DECIMAL", "DEC" -> rs.getFloat(name) // not a BigDecimal - "REAL", "FLOAT" -> rs.getFloat(name) - "DOUBLE PRECISION" -> rs.getDouble(name) - "DECFLOAT" -> rs.getDouble(name) - "DATE" -> rs.getDate(name).toString() - "TIME" -> rs.getTime(name).toString() - "TIME WITH TIME ZONE" -> rs.getTime(name).toString() - "TIMESTAMP" -> rs.getTimestamp(name).toString() - "TIMESTAMP WITH TIME ZONE" -> rs.getTimestamp(name).toString() - "INTERVAL" -> rs.getObject(name).toString() - "JAVA_OBJECT" -> rs.getObject(name) - "ENUM" -> rs.getString(name) - "JSON" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462 - "UUID" -> rs.getString(name) - else -> throw IllegalArgumentException("Unsupported H2 type: ${tableColumnMetadata.sqlTypeName}") - } - } - - override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { - return when (tableColumnMetadata.sqlTypeName) { - "CHARACTER", "CHAR" -> ColumnSchema.Value(typeOf()) - "CHARACTER VARYING", "CHAR VARYING", "VARCHAR" -> ColumnSchema.Value(typeOf()) - "CHARACTER LARGE OBJECT", "CHAR LARGE OBJECT", "CLOB" -> ColumnSchema.Value(typeOf()) - "MEDIUMTEXT" -> ColumnSchema.Value(typeOf()) - "VARCHAR_IGNORECASE" -> ColumnSchema.Value(typeOf()) - "BINARY" -> ColumnSchema.Value(typeOf()) - "BINARY VARYING", "VARBINARY" -> ColumnSchema.Value(typeOf()) - "BINARY LARGE OBJECT", "BLOB" -> ColumnSchema.Value(typeOf()) - "BOOLEAN" -> ColumnSchema.Value(typeOf()) - "TINYINT" -> ColumnSchema.Value(typeOf()) - "SMALLINT" -> ColumnSchema.Value(typeOf()) - "INTEGER", "INT" -> ColumnSchema.Value(typeOf()) - "BIGINT" -> ColumnSchema.Value(typeOf()) - "NUMERIC", "DECIMAL", "DEC" -> ColumnSchema.Value(typeOf()) - "REAL", "FLOAT" -> ColumnSchema.Value(typeOf()) - "DOUBLE PRECISION" -> ColumnSchema.Value(typeOf()) - "DECFLOAT" -> ColumnSchema.Value(typeOf()) - "DATE" -> ColumnSchema.Value(typeOf()) - "TIME" -> ColumnSchema.Value(typeOf()) - "TIME WITH TIME ZONE" -> ColumnSchema.Value(typeOf()) - "TIMESTAMP" -> ColumnSchema.Value(typeOf()) - "TIMESTAMP WITH TIME ZONE" -> ColumnSchema.Value(typeOf()) - "INTERVAL" -> ColumnSchema.Value(typeOf()) - "JAVA_OBJECT" -> ColumnSchema.Value(typeOf()) - "ENUM" -> ColumnSchema.Value(typeOf()) - "JSON" -> ColumnSchema.Value(typeOf()) // TODO: https://github.com/Kotlin/dataframe/issues/462 - "UUID" -> ColumnSchema.Value(typeOf()) - else -> throw IllegalArgumentException("Unsupported H2 type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") - } + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { + return null } override fun isSystemTable(tableMetadata: TableMetadata): Boolean { @@ -99,4 +34,8 @@ public object H2 : DbType("h2") { tables.getString("TABLE_SCHEM"), tables.getString("TABLE_CAT")) } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { + return null + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt index fd06297453..03716edd82 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MariaDb.kt @@ -4,7 +4,7 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import org.jetbrains.kotlinx.dataframe.io.TableMetadata -import kotlin.reflect.typeOf +import kotlin.reflect.KType /** * Represents the MariaDb database type. @@ -16,73 +16,8 @@ public object MariaDb : DbType("mariadb") { override val driverClassName: String get() = "org.mariadb.jdbc.Driver" - override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { - val name = tableColumnMetadata.name - return when (tableColumnMetadata.sqlTypeName) { - "BIT" -> rs.getBytes(name) - "TINYINT" -> rs.getInt(name) - "SMALLINT" -> rs.getInt(name) - "MEDIUMINT"-> rs.getInt(name) - "MEDIUMINT UNSIGNED" -> rs.getLong(name) - "INTEGER", "INT" -> rs.getInt(name) - "INTEGER UNSIGNED", "INT UNSIGNED" -> rs.getLong(name) - "BIGINT" -> rs.getLong(name) - "FLOAT" -> rs.getFloat(name) - "DOUBLE" -> rs.getDouble(name) - "DECIMAL" -> rs.getBigDecimal(name) - "DATE" -> rs.getDate(name).toString() - "DATETIME" -> rs.getTimestamp(name).toString() - "TIMESTAMP" -> rs.getTimestamp(name).toString() - "TIME"-> rs.getTime(name).toString() - "YEAR" -> rs.getDate(name).toString() - "VARCHAR", "CHAR" -> rs.getString(name) - "BINARY" -> rs.getBytes(name) - "VARBINARY" -> rs.getBytes(name) - "TINYBLOB"-> rs.getBytes(name) - "BLOB"-> rs.getBytes(name) - "MEDIUMBLOB" -> rs.getBytes(name) - "LONGBLOB" -> rs.getBytes(name) - "TEXT" -> rs.getString(name) - "MEDIUMTEXT" -> rs.getString(name) - "LONGTEXT" -> rs.getString(name) - "ENUM" -> rs.getString(name) - "SET" -> rs.getString(name) - else -> throw IllegalArgumentException("Unsupported MariaDB type: ${tableColumnMetadata.sqlTypeName}") - } - } - - override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { - return when (tableColumnMetadata.sqlTypeName) { - "BIT" -> ColumnSchema.Value(typeOf()) - "TINYINT" -> ColumnSchema.Value(typeOf()) - "SMALLINT" -> ColumnSchema.Value(typeOf()) - "MEDIUMINT"-> ColumnSchema.Value(typeOf()) - "MEDIUMINT UNSIGNED" -> ColumnSchema.Value(typeOf()) - "INTEGER", "INT" -> ColumnSchema.Value(typeOf()) - "INTEGER UNSIGNED", "INT UNSIGNED" -> ColumnSchema.Value(typeOf()) - "BIGINT" -> ColumnSchema.Value(typeOf()) - "FLOAT" -> ColumnSchema.Value(typeOf()) - "DOUBLE" -> ColumnSchema.Value(typeOf()) - "DECIMAL" -> ColumnSchema.Value(typeOf()) - "DATE" -> ColumnSchema.Value(typeOf()) - "DATETIME" -> ColumnSchema.Value(typeOf()) - "TIMESTAMP" -> ColumnSchema.Value(typeOf()) - "TIME"-> ColumnSchema.Value(typeOf()) - "YEAR" -> ColumnSchema.Value(typeOf()) - "VARCHAR", "CHAR" -> ColumnSchema.Value(typeOf()) - "BINARY" -> ColumnSchema.Value(typeOf()) - "VARBINARY" -> ColumnSchema.Value(typeOf()) - "TINYBLOB"-> ColumnSchema.Value(typeOf()) - "BLOB"-> ColumnSchema.Value(typeOf()) - "MEDIUMBLOB" -> ColumnSchema.Value(typeOf()) - "LONGBLOB" -> ColumnSchema.Value(typeOf()) - "TEXT" -> ColumnSchema.Value(typeOf()) - "MEDIUMTEXT" -> ColumnSchema.Value(typeOf()) - "LONGTEXT" -> ColumnSchema.Value(typeOf()) - "ENUM" -> ColumnSchema.Value(typeOf()) - "SET" -> ColumnSchema.Value(typeOf()) - else -> throw IllegalArgumentException("Unsupported MariaDB type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") - } + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { + return null } override fun isSystemTable(tableMetadata: TableMetadata): Boolean { @@ -95,4 +30,8 @@ public object MariaDb : DbType("mariadb") { tables.getString("table_schem"), tables.getString("table_cat")) } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { + return null + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt index 95d9335a8e..40b1f4a3dc 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MySql.kt @@ -4,10 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import java.util.Locale -import org.jetbrains.kotlinx.dataframe.DataRow -import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.io.TableMetadata -import kotlin.reflect.typeOf +import kotlin.reflect.KType /** * Represents the MySql database type. @@ -19,79 +17,8 @@ public object MySql : DbType("mysql") { override val driverClassName: String get() = "com.mysql.jdbc.Driver" - override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { - val name = tableColumnMetadata.name - return when (tableColumnMetadata.sqlTypeName) { - "BIT" -> rs.getBytes(name) - "TINYINT" -> rs.getInt(name) - "SMALLINT" -> rs.getInt(name) - "MEDIUMINT"-> rs.getInt(name) - "MEDIUMINT UNSIGNED" -> rs.getLong(name) - "INTEGER", "INT" -> rs.getInt(name) - "INTEGER UNSIGNED", "INT UNSIGNED" -> rs.getLong(name) - "BIGINT" -> rs.getLong(name) - "FLOAT" -> rs.getFloat(name) - "DOUBLE" -> rs.getDouble(name) - "DECIMAL" -> rs.getBigDecimal(name) - "DATE" -> rs.getDate(name).toString() - "DATETIME" -> rs.getTimestamp(name).toString() - "TIMESTAMP" -> rs.getTimestamp(name).toString() - "TIME"-> rs.getTime(name).toString() - "YEAR" -> rs.getDate(name).toString() - "VARCHAR", "CHAR" -> rs.getString(name) - "BINARY" -> rs.getBytes(name) - "VARBINARY" -> rs.getBytes(name) - "TINYBLOB"-> rs.getBytes(name) - "BLOB"-> rs.getBytes(name) - "MEDIUMBLOB" -> rs.getBytes(name) - "LONGBLOB" -> rs.getBytes(name) - "TEXT" -> rs.getString(name) - "MEDIUMTEXT" -> rs.getString(name) - "LONGTEXT" -> rs.getString(name) - "ENUM" -> rs.getString(name) - "SET" -> rs.getString(name) - // special mysql types - "JSON" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462 - "GEOMETRY" -> rs.getBytes(name) - else -> throw IllegalArgumentException("Unsupported MySQL type: ${tableColumnMetadata.sqlTypeName}") - } - } - - override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { - return when (tableColumnMetadata.sqlTypeName) { - "BIT" -> ColumnSchema.Value(typeOf()) - "TINYINT" -> ColumnSchema.Value(typeOf()) - "SMALLINT" -> ColumnSchema.Value(typeOf()) - "MEDIUMINT"-> ColumnSchema.Value(typeOf()) - "MEDIUMINT UNSIGNED" -> ColumnSchema.Value(typeOf()) - "INTEGER", "INT" -> ColumnSchema.Value(typeOf()) - "INTEGER UNSIGNED", "INT UNSIGNED" -> ColumnSchema.Value(typeOf()) - "BIGINT" -> ColumnSchema.Value(typeOf()) - "FLOAT" -> ColumnSchema.Value(typeOf()) - "DOUBLE" -> ColumnSchema.Value(typeOf()) - "DECIMAL" -> ColumnSchema.Value(typeOf()) - "DATE" -> ColumnSchema.Value(typeOf()) - "DATETIME" -> ColumnSchema.Value(typeOf()) - "TIMESTAMP" -> ColumnSchema.Value(typeOf()) - "TIME"-> ColumnSchema.Value(typeOf()) - "YEAR" -> ColumnSchema.Value(typeOf()) - "VARCHAR", "CHAR" -> ColumnSchema.Value(typeOf()) - "BINARY" -> ColumnSchema.Value(typeOf()) - "VARBINARY" -> ColumnSchema.Value(typeOf()) - "TINYBLOB"-> ColumnSchema.Value(typeOf()) - "BLOB"-> ColumnSchema.Value(typeOf()) - "MEDIUMBLOB" -> ColumnSchema.Value(typeOf()) - "LONGBLOB" -> ColumnSchema.Value(typeOf()) - "TEXT" -> ColumnSchema.Value(typeOf()) - "MEDIUMTEXT" -> ColumnSchema.Value(typeOf()) - "LONGTEXT" -> ColumnSchema.Value(typeOf()) - "ENUM" -> ColumnSchema.Value(typeOf()) - "SET" -> ColumnSchema.Value(typeOf()) - // special mysql types - "JSON" -> ColumnSchema.Value(typeOf>>()) // TODO: https://github.com/Kotlin/dataframe/issues/462 - "GEOMETRY" -> ColumnSchema.Value(typeOf()) - else -> throw IllegalArgumentException("Unsupported MySQL type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") - } + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { + return null } override fun isSystemTable(tableMetadata: TableMetadata): Boolean { @@ -116,4 +43,8 @@ public object MySql : DbType("mysql") { tables.getString("table_schem"), tables.getString("table_cat")) } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { + return null + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt index 3bc6e09f22..b52af3d5a3 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/PostgreSql.kt @@ -4,9 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import java.util.Locale -import org.jetbrains.kotlinx.dataframe.DataRow -import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import kotlin.reflect.KType import kotlin.reflect.typeOf /** @@ -19,77 +18,8 @@ public object PostgreSql : DbType("postgresql") { override val driverClassName: String get() = "org.postgresql.Driver" - override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { - val name = tableColumnMetadata.name - return when (tableColumnMetadata.sqlTypeName) { - "serial" -> rs.getInt(name) - "int8", "bigint", "bigserial" -> rs.getLong(name) - "bool" -> rs.getBoolean(name) - "box" -> rs.getString(name) - "bytea" -> rs.getBytes(name) - "character", "bpchar" -> rs.getString(name) - "circle" -> rs.getString(name) - "date" -> rs.getDate(name).toString() - "float8", "double precision" -> rs.getDouble(name) - "int4", "integer" -> rs.getInt(name) - "interval" -> rs.getString(name) - "json", "jsonb" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462 - "line" -> rs.getString(name) - "lseg" -> rs.getString(name) - "macaddr" -> rs.getString(name) - "money" -> rs.getString(name) - "numeric" -> rs.getString(name) - "path" -> rs.getString(name) - "point" -> rs.getString(name) - "polygon" -> rs.getString(name) - "float4", "real" -> rs.getFloat(name) - "int2", "smallint" -> rs.getShort(name) - "smallserial" -> rs.getInt(name) - "text" -> rs.getString(name) - "time" -> rs.getString(name) - "timetz", "time with time zone" -> rs.getString(name) - "timestamp" -> rs.getString(name) - "timestamptz", "timestamp with time zone" -> rs.getString(name) - "uuid" -> rs.getString(name) - "xml" -> rs.getString(name) - else -> throw IllegalArgumentException("Unsupported PostgreSQL type: ${tableColumnMetadata.sqlTypeName}") - } - } - - override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { - return when (tableColumnMetadata.sqlTypeName) { - "serial" -> ColumnSchema.Value(typeOf()) - "int8", "bigint", "bigserial" -> ColumnSchema.Value(typeOf()) - "bool" -> ColumnSchema.Value(typeOf()) - "box" -> ColumnSchema.Value(typeOf()) - "bytea" -> ColumnSchema.Value(typeOf()) - "character", "bpchar" -> ColumnSchema.Value(typeOf()) - "circle" -> ColumnSchema.Value(typeOf()) - "date" -> ColumnSchema.Value(typeOf()) - "float8", "double precision" -> ColumnSchema.Value(typeOf()) - "int4", "integer" -> ColumnSchema.Value(typeOf()) - "interval" -> ColumnSchema.Value(typeOf()) - "json", "jsonb" -> ColumnSchema.Value(typeOf>>()) // TODO: https://github.com/Kotlin/dataframe/issues/462 - "line" -> ColumnSchema.Value(typeOf()) - "lseg" -> ColumnSchema.Value(typeOf()) - "macaddr" -> ColumnSchema.Value(typeOf()) - "money" -> ColumnSchema.Value(typeOf()) - "numeric" -> ColumnSchema.Value(typeOf()) - "path" -> ColumnSchema.Value(typeOf()) - "point" -> ColumnSchema.Value(typeOf()) - "polygon" -> ColumnSchema.Value(typeOf()) - "float4", "real" -> ColumnSchema.Value(typeOf()) - "int2", "smallint" -> ColumnSchema.Value(typeOf()) - "smallserial" -> ColumnSchema.Value(typeOf()) - "text" -> ColumnSchema.Value(typeOf()) - "time" -> ColumnSchema.Value(typeOf()) - "timetz", "time with time zone" -> ColumnSchema.Value(typeOf()) - "timestamp" -> ColumnSchema.Value(typeOf()) - "timestamptz", "timestamp with time zone" -> ColumnSchema.Value(typeOf()) - "uuid" -> ColumnSchema.Value(typeOf()) - "xml" -> ColumnSchema.Value(typeOf()) - else -> throw IllegalArgumentException("Unsupported PostgreSQL type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") - } + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { + return null } override fun isSystemTable(tableMetadata: TableMetadata): Boolean { @@ -103,4 +33,10 @@ public object PostgreSql : DbType("postgresql") { tables.getString("table_schem"), tables.getString("table_cat")) } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { + if(tableColumnMetadata.sqlTypeName == "money") // because of https://github.com/pgjdbc/pgjdbc/issues/425 + return typeOf() + return null + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt index 251b8bcb22..38757416b9 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/Sqlite.kt @@ -4,7 +4,7 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import org.jetbrains.kotlinx.dataframe.io.TableMetadata -import kotlin.reflect.typeOf +import kotlin.reflect.KType /** * Represents the Sqlite database type. @@ -16,27 +16,8 @@ public object Sqlite : DbType("sqlite") { override val driverClassName: String get() = "org.sqlite.JDBC" - override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? { - val name = tableColumnMetadata.name - return when (tableColumnMetadata.sqlTypeName) { - "INTEGER", "INTEGER AUTO_INCREMENT" -> rs.getInt(name) - "TEXT" -> rs.getString(name) - "REAL" -> rs.getDouble(name) - "NUMERIC" -> rs.getDouble(name) - "BLOB" -> rs.getBytes(name) - else -> throw IllegalArgumentException("Unsupported SQLite type: ${tableColumnMetadata.sqlTypeName}") - } - } - - override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema { - return when (tableColumnMetadata.sqlTypeName) { - "INTEGER", "INTEGER AUTO_INCREMENT" -> ColumnSchema.Value(typeOf()) - "TEXT" -> ColumnSchema.Value(typeOf()) - "REAL" -> ColumnSchema.Value(typeOf()) - "NUMERIC" -> ColumnSchema.Value(typeOf()) - "BLOB" -> ColumnSchema.Value(typeOf()) - else -> throw IllegalArgumentException("Unsupported SQLite type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}") - } + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { + return null } override fun isSystemTable(tableMetadata: TableMetadata): Boolean { @@ -49,4 +30,8 @@ public object Sqlite : DbType("sqlite") { tables.getString("TABLE_SCHEM"), tables.getString("TABLE_CAT")) } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { + return null + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index fca51a6c88..3a439638fa 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -1,18 +1,35 @@ package org.jetbrains.kotlinx.dataframe.io import io.github.oshai.kotlinlogging.KotlinLogging +import java.math.BigDecimal import java.sql.Connection import java.sql.DatabaseMetaData import java.sql.DriverManager import java.sql.ResultSet import java.sql.ResultSetMetaData +import java.sql.Time +import java.sql.Timestamp +import java.sql.Types +import java.sql.RowId +import java.sql.Ref +import java.sql.Clob +import java.sql.Blob +import java.sql.NClob +import java.sql.SQLXML +import java.util.Date import org.jetbrains.kotlinx.dataframe.AnyFrame +import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.api.toDataFrame import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl import org.jetbrains.kotlinx.dataframe.io.db.DbType import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema +import kotlin.reflect.KType +import kotlin.reflect.full.createType +import kotlin.reflect.full.isSupertypeOf +import kotlin.reflect.full.starProjectedType private val logger = KotlinLogging.logger {} @@ -26,6 +43,21 @@ private val logger = KotlinLogging.logger {} */ private const val DEFAULT_LIMIT = Int.MIN_VALUE +/** + * Constant variable indicating the start of an SQL read query. + * The value of this variable is "SELECT". + */ +private const val START_OF_READ_SQL_QUERY = "SELECT" + +/** + * Constant representing the separator used to separate multiple SQL queries. + * + * This separator is used when multiple SQL queries need to be executed together. + * Each query should be separated by this separator to indicate the end of one query + * and the start of the next query. + */ +private const val MULTIPLE_SQL_QUERY_SEPARATOR = ";" + /** * Represents a column in a database table to keep all required meta-information. * @@ -33,8 +65,15 @@ private const val DEFAULT_LIMIT = Int.MIN_VALUE * @property [sqlTypeName] the SQL data type of the column. * @property [jdbcType] the JDBC data type of the column produced from [java.sql.Types]. * @property [size] the size of the column. + * @property [isNullable] true if column could contain nulls. */ -public data class TableColumnMetadata(val name: String, val sqlTypeName: String, val jdbcType: Int, val size: Int) +public data class TableColumnMetadata( + val name: String, + val sqlTypeName: String, + val jdbcType: Int, + val size: Int, + val isNullable: Boolean = false +) /** * Represents a table metadata to store information about a database table, @@ -58,19 +97,6 @@ public data class TableMetadata(val name: String, val schemaName: String?, val c */ public data class DatabaseConfiguration(val url: String, val user: String = "", val password: String = "") -/** - * Reads data from an SQL table and converts it into a DataFrame. - * - * @param [dbConfig] the configuration for the database, including URL, user, and password. - * @param [tableName] the name of the table to read data from. - * @return the DataFrame containing the data from the SQL table. - */ -public fun DataFrame.Companion.readSqlTable(dbConfig: DatabaseConfiguration, tableName: String): AnyFrame { - DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlTable(connection, tableName, DEFAULT_LIMIT) - } -} - /** * Reads data from an SQL table and converts it into a DataFrame. * @@ -79,25 +105,16 @@ public fun DataFrame.Companion.readSqlTable(dbConfig: DatabaseConfiguration, tab * @param [limit] the maximum number of rows to retrieve from the table. * @return the DataFrame containing the data from the SQL table. */ -public fun DataFrame.Companion.readSqlTable(dbConfig: DatabaseConfiguration, tableName: String, limit: Int): AnyFrame { +public fun DataFrame.Companion.readSqlTable( + dbConfig: DatabaseConfiguration, + tableName: String, + limit: Int = DEFAULT_LIMIT +): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> return readSqlTable(connection, tableName, limit) } } -/** - * Reads data from an SQL table and converts it into a DataFrame. - * - * @param [connection] the database connection to read tables from. - * @param [tableName] the name of the table to read data from. - * @return the DataFrame containing the data from the SQL table. - * - * @see DriverManager.getConnection - */ -public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: String): AnyFrame { - return readSqlTable(connection, tableName, DEFAULT_LIMIT) -} - /** * Reads data from an SQL table and converts it into a DataFrame. * @@ -108,7 +125,11 @@ public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: S * * @see DriverManager.getConnection */ -public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: String, limit: Int): AnyFrame { +public fun DataFrame.Companion.readSqlTable( + connection: Connection, + tableName: String, + limit: Int = DEFAULT_LIMIT +): AnyFrame { var preparedQuery = "SELECT * FROM $tableName" if (limit > 0) preparedQuery += " LIMIT $limit" @@ -117,13 +138,12 @@ public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: S connection.createStatement().use { st -> logger.debug { "Connection with url:${url} is established successfully." } - val tableColumns = getTableColumnsMetadata(connection, tableName) st.executeQuery( preparedQuery ).use { rs -> - val data = fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit) - return data.toDataFrame() + val tableColumns = getTableColumnsMetadata(rs) + return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit) } } } @@ -131,25 +151,19 @@ public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: S /** * Converts the result of an SQL query to the DataFrame. * - * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. - * @param [sqlQuery] the SQL query to execute. - * @return the DataFrame containing the result of the SQL query. - */ -public fun DataFrame.Companion.readSqlQuery(dbConfig: DatabaseConfiguration, sqlQuery: String): AnyFrame { - DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlQuery(connection, sqlQuery, DEFAULT_LIMIT) - } -} - -/** - * Converts the result of an SQL query to the DataFrame. + * NOTE: SQL query should start from SELECT and contain one query for reading data without any manipulation. + * It should not contain `;` symbol. * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. * @return the DataFrame containing the result of the SQL query. */ -public fun DataFrame.Companion.readSqlQuery(dbConfig: DatabaseConfiguration, sqlQuery: String, limit: Int): AnyFrame { +public fun DataFrame.Companion.readSqlQuery( + dbConfig: DatabaseConfiguration, + sqlQuery: String, + limit: Int = DEFAULT_LIMIT +): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> return readSqlQuery(connection, sqlQuery, limit) } @@ -158,18 +172,8 @@ public fun DataFrame.Companion.readSqlQuery(dbConfig: DatabaseConfiguration, sql /** * Converts the result of an SQL query to the DataFrame. * - * @param [connection] the database connection to execute the SQL query. - * @param [sqlQuery] the SQL query to execute. - * @return the DataFrame containing the result of the SQL query. - * - * @see DriverManager.getConnection - */ -public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: String): AnyFrame { - return readSqlQuery(connection, sqlQuery, DEFAULT_LIMIT) -} - -/** - * Converts the result of an SQL query to the DataFrame. + * NOTE: SQL query should start from SELECT and contain one query for reading data without any manipulation. + * It should not contain `;` symbol. * * @param [connection] the database connection to execute the SQL query. * @param [sqlQuery] the SQL query to execute. @@ -178,7 +182,13 @@ public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: St * * @see DriverManager.getConnection */ -public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: String, limit: Int): AnyFrame { +public fun DataFrame.Companion.readSqlQuery( + connection: Connection, + sqlQuery: String, + limit: Int = DEFAULT_LIMIT +): AnyFrame { + require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " } + val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) @@ -190,77 +200,55 @@ public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: St connection.createStatement().use { st -> st.executeQuery(internalSqlQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - val data = fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT) - - logger.debug { "SQL query executed successfully. Converting data to DataFrame." } - - return data.toDataFrame() + return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT) } } } -/** - * Reads the data from a [ResultSet] and converts it into a DataFrame. - * - * @param [resultSet] the ResultSet containing the data to read. - * @param [dbType] the type of database that the ResultSet belongs to. - * @return the DataFrame generated from the ResultSet data. - */ -public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, dbType: DbType): AnyFrame { - return readResultSet(resultSet, dbType, DEFAULT_LIMIT) -} +/** SQL-query is accepted only if it starts from SELECT */ +private fun isValid(sqlQuery: String): Boolean { + val normalizedSqlQuery = sqlQuery.trim().uppercase() -/** - * Reads the data from a ResultSet and converts it into a DataFrame. - * - * @param [resultSet] the ResultSet containing the data to read. - * @param [dbType] the type of database that the ResultSet belongs to. - * @param [limit] the maximum number of rows to read from the ResultSet. - * @return the DataFrame generated from the ResultSet data. - */ -public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, dbType: DbType, limit: Int): AnyFrame { - val tableColumns = getTableColumnsMetadata(resultSet) - val data = fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit) - return data.toDataFrame() + return normalizedSqlQuery.startsWith(START_OF_READ_SQL_QUERY) && + !normalizedSqlQuery.contains(MULTIPLE_SQL_QUERY_SEPARATOR) } /** - * Reads the data from a ResultSet and converts it into a DataFrame. + * Reads the data from a [ResultSet] and converts it into a DataFrame. * - * @param [resultSet] the ResultSet containing the data to read. - * @param [connection] the connection to the database (it's required to extract the database type). - * @return the DataFrame generated from the ResultSet data. + * @param [resultSet] the [ResultSet] containing the data to read. + * @param [dbType] the type of database that the [ResultSet] belongs to. + * @param [limit] the maximum number of rows to read from the [ResultSet]. + * @return the DataFrame generated from the [ResultSet] data. */ -public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, connection: Connection): AnyFrame { - return readResultSet(resultSet, connection, DEFAULT_LIMIT) +public fun DataFrame.Companion.readResultSet( + resultSet: ResultSet, + dbType: DbType, + limit: Int = DEFAULT_LIMIT +): AnyFrame { + val tableColumns = getTableColumnsMetadata(resultSet) + return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit) } /** - * Reads the data from a ResultSet and converts it into a DataFrame. + * Reads the data from a [ResultSet] and converts it into a DataFrame. * - * @param [resultSet] the ResultSet containing the data to read. + * @param [resultSet] the [ResultSet] containing the data to read. * @param [connection] the connection to the database (it's required to extract the database type). - * @param [limit] the maximum number of rows to read from the ResultSet. - * @return the DataFrame generated from the ResultSet data. + * @param [limit] the maximum number of rows to read from the [ResultSet]. + * @return the DataFrame generated from the [ResultSet] data. */ -public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, connection: Connection, limit: Int): AnyFrame { +public fun DataFrame.Companion.readResultSet( + resultSet: ResultSet, + connection: Connection, + limit: Int = DEFAULT_LIMIT +): AnyFrame { val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) return readResultSet(resultSet, dbType, limit) } -/** - * Reads all non-system tables from a database and returns them as a list of data frames - * using the provided database configuration. - * - * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. - * @return a list of [AnyFrame] objects representing the non-system tables from the database. - */ -public fun DataFrame.Companion.readAllSqlTables(dbConfig: DatabaseConfiguration): List { - return readAllSqlTables(dbConfig, DEFAULT_LIMIT) -} - /** * Reads all tables from the given database using the provided database configuration and limit. * @@ -268,24 +256,16 @@ public fun DataFrame.Companion.readAllSqlTables(dbConfig: DatabaseConfiguration) * @param [limit] the maximum number of rows to read from each table. * @return a list of [AnyFrame] objects representing the non-system tables from the database. */ -public fun DataFrame.Companion.readAllSqlTables(dbConfig: DatabaseConfiguration, limit: Int): List { +public fun DataFrame.Companion.readAllSqlTables( + dbConfig: DatabaseConfiguration, + catalogue: String? = null, + limit: Int = DEFAULT_LIMIT +): List { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readAllSqlTables(connection, limit) + return readAllSqlTables(connection, catalogue, limit) } } -/** - * Reads all non-system tables from a database and returns them as a list of data frames. - * - * @param [connection] the database connection to read tables from. - * @return a list of [AnyFrame] objects representing the non-system tables from the database. - * - * @see DriverManager.getConnection - */ -public fun DataFrame.Companion.readAllSqlTables(connection: Connection): List { - return readAllSqlTables(connection, DEFAULT_LIMIT) -} - /** * Reads all non-system tables from a database and returns them as a list of data frames. * @@ -295,13 +275,17 @@ public fun DataFrame.Companion.readAllSqlTables(connection: Connection): List { +public fun DataFrame.Companion.readAllSqlTables( + connection: Connection, + catalogue: String? = null, + limit: Int = DEFAULT_LIMIT +): List { val metaData = connection.metaData val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) // exclude a system and other tables without data, but it looks like it supported badly for many databases - val tables = metaData.getTables(null, null, null, arrayOf("TABLE")) + val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE")) val dataFrames = mutableListOf() @@ -309,10 +293,18 @@ public fun DataFrame.Companion.readAllSqlTables(connection: Connection, limit: I val table = dbType.buildTableMetadata(tables) if (!dbType.isSystemTable(table)) { // we filter her second time because of specific logic with SQLite and possible issues with future databases - logger.debug { "Reading table: ${table.name}" } - val dataFrame = readSqlTable(connection, table.name, limit) + // val tableName = if (table.catalogue != null) table.catalogue + "." + table.name else table.name + val tableName = if (catalogue != null) catalogue + "." + table.name else table.name + + // TODO: both cases is schema specified or not in URL + // in h2 database name is recognized as a schema name https://www.h2database.com/html/features.html#database_url + // https://stackoverflow.com/questions/20896935/spring-hibernate-h2-database-schema-not-found + // could be Dialect/Database specific + logger.debug { "Reading table: $tableName" } + + val dataFrame = readSqlTable(connection, tableName, limit) dataFrames += dataFrame - logger.debug { "Finished reading table: ${table.name}" } + logger.debug { "Finished reading table: $tableName" } } } @@ -324,9 +316,12 @@ public fun DataFrame.Companion.readAllSqlTables(connection: Connection, limit: I * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [tableName] the name of the SQL table for which to retrieve the schema. - * @return the DataFrameSchema object representing the schema of the SQL table + * @return the [DataFrameSchema] object representing the schema of the SQL table */ -public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DatabaseConfiguration, tableName: String): DataFrameSchema { +public fun DataFrame.Companion.getSchemaForSqlTable( + dbConfig: DatabaseConfiguration, + tableName: String +): DataFrameSchema { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> return getSchemaForSqlTable(connection, tableName) } @@ -348,12 +343,15 @@ public fun DataFrame.Companion.getSchemaForSqlTable( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - connection.createStatement().use { - logger.debug { "Connection with url:${connection.metaData.url} is established successfully." } - - val tableColumns = getTableColumnsMetadata(connection, tableName) + val preparedQuery = "SELECT * FROM $tableName LIMIT 1" - return buildSchemaByTableColumns(tableColumns, dbType) + connection.createStatement().use { st -> + st.executeQuery( + preparedQuery + ).use { rs -> + val tableColumns = getTableColumnsMetadata(rs) + return buildSchemaByTableColumns(tableColumns, dbType) + } } } @@ -364,7 +362,10 @@ public fun DataFrame.Companion.getSchemaForSqlTable( * @param [sqlQuery] the SQL query to execute and retrieve the schema from. * @return the schema of the SQL query as a [DataFrameSchema] object. */ -public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DatabaseConfiguration, sqlQuery: String): DataFrameSchema { +public fun DataFrame.Companion.getSchemaForSqlQuery( + dbConfig: DatabaseConfiguration, + sqlQuery: String +): DataFrameSchema { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> return getSchemaForSqlQuery(connection, sqlQuery) } @@ -392,13 +393,13 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQ } /** - * Retrieves the schema from ResultSet. + * Retrieves the schema from [ResultSet]. * * NOTE: This function will not close connection and result set and not retrieve data from the result set. * - * @param [resultSet] the ResultSet obtained from executing a database query. - * @param [dbType] the type of database that the ResultSet belongs to. - * @return the schema of the ResultSet as a [DataFrameSchema] object. + * @param [resultSet] the [ResultSet] obtained from executing a database query. + * @param [dbType] the type of database that the [ResultSet] belongs to. + * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbType: DbType): DataFrameSchema { val tableColumns = getTableColumnsMetadata(resultSet) @@ -406,14 +407,14 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp } /** - * Retrieves the schema from ResultSet. + * Retrieves the schema from [ResultSet]. * * NOTE: [connection] is required to extract the database type. * This function will not close connection and result set and not retrieve data from the result set. * - * @param [resultSet] the ResultSet obtained from executing a database query. + * @param [resultSet] the [ResultSet] obtained from executing a database query. * @param [connection] the connection to the database (it's required to extract the database type). - * @return the schema of the ResultSet as a [DataFrameSchema] object. + * @return the schema of the [ResultSet] as a [DataFrameSchema] object. */ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema { val url = connection.metaData.url @@ -427,7 +428,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne * Retrieves the schema of all non-system tables in the database using the provided database configuration. * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. - * @return a list of DataFrameSchema objects representing the schema of each non-system table. + * @return a list of [DataFrameSchema] objects representing the schema of each non-system table. */ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> @@ -439,7 +440,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig * Retrieves the schema of all non-system tables in the database using the provided database connection. * * @param [connection] the database connection. - * @return a list of DataFrameSchema objects representing the schema of each non-system table. + * @return a list of [DataFrameSchema] objects representing the schema of each non-system table. */ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): List { val metaData = connection.metaData @@ -470,126 +471,217 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): * @param [tableColumns] a mutable map containing the table columns, where the key represents the column name * and the value represents the metadata of the column * @param [dbType] the type of database. - * @return a DataFrameSchema object representing the schema built from the table columns. + * @return a [DataFrameSchema] object representing the schema built from the table columns. */ -private fun buildSchemaByTableColumns(tableColumns: MutableMap, dbType: DbType): DataFrameSchema { - val schemaColumns = tableColumns.map { - Pair(it.key, dbType.toColumnSchema(it.value)) - }.toMap() +private fun buildSchemaByTableColumns(tableColumns: MutableList, dbType: DbType): DataFrameSchema { + val schemaColumns = tableColumns.associate { + Pair(it.name, generateColumnSchemaValue(dbType, it)) + } return DataFrameSchemaImpl( columns = schemaColumns ) } +private fun generateColumnSchemaValue( + dbType: DbType, + tableColumnMetadata: TableColumnMetadata +): ColumnSchema = dbType.convertSqlTypeToColumnSchemaValue(tableColumnMetadata) ?: ColumnSchema.Value( + makeCommonSqlToKTypeMapping(tableColumnMetadata) +) + + /** * Retrieves the metadata of the columns in the result set. * - * @param [rs] the result set - * @return a mutable map of column names to [TableColumnMetadata] objects, - * where each TableColumnMetadata object contains information such as the column type, - * JDBC type, size, and name. + * @param rs the result set + * @return a mutable list of [TableColumnMetadata] objects, + * where each TableColumnMetadata object contains information such as the column type, + * JDBC type, size, and name. */ -private fun getTableColumnsMetadata(rs: ResultSet): MutableMap { +private fun getTableColumnsMetadata(rs: ResultSet): MutableList { val metaData: ResultSetMetaData = rs.metaData val numberOfColumns: Int = metaData.columnCount - - val tableColumns = mutableMapOf() + val tableColumns = mutableListOf() + val columnNameCounter = mutableMapOf() + val databaseMetaData: DatabaseMetaData = rs.statement.connection.metaData + val catalog: String? = rs.statement.connection.catalog.takeUnless { it.isNullOrBlank() } + val schema: String? = rs.statement.connection.schema.takeUnless { it.isNullOrBlank() } for (i in 1 until numberOfColumns + 1) { - val name = metaData.getColumnName(i) + val columnResultSet: ResultSet = + databaseMetaData.getColumns(catalog, schema, metaData.getTableName(i), metaData.getColumnName(i)) + val isNullable = if (columnResultSet.next()) { + columnResultSet.getString("IS_NULLABLE") == "YES" + } else { + true // we assume that it's nullable by default + } + + val name = manageColumnNameDuplication(columnNameCounter, metaData.getColumnName(i)) val size = metaData.getColumnDisplaySize(i) val type = metaData.getColumnTypeName(i) val jdbcType = metaData.getColumnType(i) - // TODO: - // add strategy for multiple columns handling (throw exception, ignore, - // create columns with additional indexes in name) - // column names should be unique - check(!tableColumns.containsKey(name)) { "Multiple columns with name $name from table ${metaData.getTableName(i)}. Rename columns to make it unique." } - - tableColumns += Pair(name, TableColumnMetadata(name, type, jdbcType, size)) - + tableColumns += TableColumnMetadata(name, type, jdbcType, size, isNullable) } return tableColumns } /** - * Retrieves the metadata of columns for a given table. + * Manages the duplication of column names by appending a unique identifier to the original name if necessary. * - * @param [connection] the database connection - * @param [tableName] the name of the table - * @return a mutable map of column names to [TableColumnMetadata] objects, - * where each TableColumnMetadata object contains information such as the column type, - * JDBC type, size, and name. + * @param columnNameCounter a mutable map that keeps track of the count for each column name. + * @param originalName the original name of the column to be managed. + * @return the modified column name that is free from duplication. */ -private fun getTableColumnsMetadata(connection: Connection, tableName: String): MutableMap { - val dbMetaData: DatabaseMetaData = connection.metaData - val columns: ResultSet = dbMetaData.getColumns(null, null, tableName, null) - val tableColumns = mutableMapOf() - - while (columns.next()) { - val name = columns.getString("COLUMN_NAME") - val type = columns.getString("TYPE_NAME") - val jdbcType = columns.getInt("DATA_TYPE") - val size = columns.getInt("COLUMN_SIZE") - tableColumns += Pair(name, TableColumnMetadata(name, type, jdbcType, size)) +private fun manageColumnNameDuplication(columnNameCounter: MutableMap, originalName: String): String { + var name = originalName + val count = columnNameCounter[originalName] + + if (count != null) { + var incrementedCount = count + 1 + while (columnNameCounter.containsKey("${originalName}_$incrementedCount")) { + incrementedCount++ + } + columnNameCounter[originalName] = incrementedCount + name = "${originalName}_$incrementedCount" + } else { + columnNameCounter[originalName] = 0 } - return tableColumns + + return name } /** * Fetches and converts data from a ResultSet into a mutable map. * - * @param [tableColumns] a map containing the column metadata for the table. + * @param [tableColumns] a list containing the column metadata for the table. * @param [rs] the ResultSet object containing the data to be fetched and converted. * @param [dbType] the type of the database. * @param [limit] the maximum number of rows to fetch and convert. * @return A mutable map containing the fetched and converted data. */ private fun fetchAndConvertDataFromResultSet( - tableColumns: MutableMap, + tableColumns: MutableList, rs: ResultSet, dbType: DbType, limit: Int -): MutableMap> { - // map - val data = mutableMapOf>() +): AnyFrame { + val data = List(tableColumns.size) { mutableListOf() } - // init data - tableColumns.forEach { (columnName, _) -> - data[columnName] = mutableListOf() + val kotlinTypesForSqlColumns = mutableMapOf() + List(tableColumns.size) { index -> + kotlinTypesForSqlColumns[index] = generateKType(dbType, tableColumns[index]) } var counter = 0 if (limit > 0) { - while (rs.next() && counter < limit) { - handleRow(tableColumns, data, dbType, rs) + while (counter < limit && rs.next()) { + extractNewRowFromResultSetAndAddToData(tableColumns, data, rs, kotlinTypesForSqlColumns) counter++ // if (counter % 1000 == 0) logger.debug { "Loaded $counter rows." } // TODO: https://github.com/Kotlin/dataframe/issues/455 } } else { while (rs.next()) { - handleRow(tableColumns, data, dbType, rs) + extractNewRowFromResultSetAndAddToData(tableColumns, data, rs, kotlinTypesForSqlColumns) counter++ // if (counter % 1000 == 0) logger.debug { "Loaded $counter rows." } // TODO: https://github.com/Kotlin/dataframe/issues/455 } } - return data + val dataFrame = data.mapIndexed { index, values -> + DataColumn.createValueColumn( + name = tableColumns[index].name, + values = values, + type = kotlinTypesForSqlColumns[index]!! + ) + }.toDataFrame() + + logger.debug { "DataFrame with ${dataFrame.rowsCount()} rows and ${dataFrame.columnsCount()} columns created as a result of SQL query." } + + return dataFrame } -private fun handleRow( - tableColumns: MutableMap, - data: MutableMap>, - dbType: DbType, - rs: ResultSet +private fun extractNewRowFromResultSetAndAddToData( + tableColumns: MutableList, + data: List>, + rs: ResultSet, + kotlinTypesForSqlColumns: MutableMap ) { - tableColumns.forEach { (columnName, jdbcColumn) -> - data[columnName] = (data[columnName]!! + dbType.convertDataFromResultSet(rs, jdbcColumn)).toMutableList() + repeat(tableColumns.size) { i -> + data[i].add( + try { + rs.getObject(i + 1) + } catch (_: Throwable) { + val kType = kotlinTypesForSqlColumns[i]!! + if (kType.isSupertypeOf(String::class.starProjectedType)) rs.getString(i + 1) else rs.getString(i + 1) // TODO: expand for all the types like in generateKType function + } + ) } } +/** + * Generates a KType based on the given database type and table column metadata. + * + * @param dbType The database type. + * @param tableColumnMetadata The table column metadata. + * + * @return The generated KType. + */ +private fun generateKType(dbType: DbType, tableColumnMetadata: TableColumnMetadata): KType { + return dbType.convertSqlTypeToKType(tableColumnMetadata) ?: makeCommonSqlToKTypeMapping(tableColumnMetadata) +} +/** + * Creates a mapping between common SQL types and their corresponding KTypes. + * + * @param tableColumnMetadata The metadata of the table column. + * @return The KType associated with the SQL type, or a default type if no mapping is found. + */ +private fun makeCommonSqlToKTypeMapping(tableColumnMetadata: TableColumnMetadata): KType { + val jdbcTypeToKTypeMapping = mapOf( + Types.BIT to Boolean::class, + Types.TINYINT to Byte::class, + Types.SMALLINT to Short::class, + Types.INTEGER to Int::class, + Types.BIGINT to Long::class, + Types.FLOAT to Float::class, + Types.REAL to Float::class, + Types.DOUBLE to Double::class, + Types.NUMERIC to BigDecimal::class, + Types.DECIMAL to BigDecimal::class, + Types.CHAR to Char::class, + Types.VARCHAR to String::class, + Types.LONGVARCHAR to String::class, + Types.DATE to Date::class, + Types.TIME to Time::class, + Types.TIMESTAMP to Timestamp::class, + Types.BINARY to ByteArray::class, + Types.VARBINARY to ByteArray::class, + Types.LONGVARBINARY to ByteArray::class, + Types.NULL to String::class, + Types.OTHER to Any::class, + Types.JAVA_OBJECT to Any::class, + Types.DISTINCT to Any::class, + Types.STRUCT to Any::class, + Types.ARRAY to Array::class, + Types.BLOB to Blob::class, + Types.CLOB to Clob::class, + Types.REF to Ref::class, + Types.DATALINK to Any::class, + Types.BOOLEAN to Boolean::class, + Types.ROWID to RowId::class, + Types.NCHAR to Char::class, + Types.NVARCHAR to String::class, + Types.LONGNVARCHAR to String::class, + Types.NCLOB to NClob::class, + Types.SQLXML to SQLXML::class, + Types.REF_CURSOR to Ref::class, + Types.TIME_WITH_TIMEZONE to Time::class, + Types.TIMESTAMP_WITH_TIMEZONE to Timestamp::class + ) + val kClass = jdbcTypeToKTypeMapping[tableColumnMetadata.jdbcType] ?: String::class + return kClass.createType(nullable = tableColumnMetadata.isNullable) +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt index f20141c886..ae9b8b040c 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt @@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.dataframe.io.db.H2 import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Test +import java.math.BigDecimal import java.sql.Connection import java.sql.DriverManager import java.sql.ResultSet @@ -22,53 +23,53 @@ private const val URL = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_ @DataSchema interface Customer { - val id: Int - val name: String - val age: Int + val id: Int? + val name: String? + val age: Int? } @DataSchema interface Sale { - val id: Int - val customerId: Int + val id: Int? + val customerId: Int? val amount: Double } @DataSchema interface CustomerSales { - val customerName: String - val totalSalesAmount: Double + val customerName: String? + val totalSalesAmount: Double? } @DataSchema interface TestTableData { - val characterCol: String - val characterVaryingCol: String - val characterLargeObjectCol: String - val mediumTextCol: String - val varcharIgnoreCaseCol: String - val binaryCol: ByteArray - val binaryVaryingCol: ByteArray - val binaryLargeObjectCol: ByteArray - val booleanCol: Boolean - val tinyIntCol: Byte - val smallIntCol: Short - val integerCol: Int - val bigIntCol: Long - val numericCol: Double - val realCol: Float - val doublePrecisionCol: Double - val decFloatCol: Double - val dateCol: String - val timeCol: String - val timeWithTimeZoneCol: String - val timestampCol: String - val timestampWithTimeZoneCol: String - val intervalCol: String + val characterCol: String? + val characterVaryingCol: String? + val characterLargeObjectCol: String? + val mediumTextCol: String? + val varcharIgnoreCaseCol: String? + val binaryCol: ByteArray? + val binaryVaryingCol: ByteArray? + val binaryLargeObjectCol: ByteArray? + val booleanCol: Boolean? + val tinyIntCol: Byte? + val smallIntCol: Short? + val integerCol: Int? + val bigIntCol: Long? + val numericCol: Double? + val realCol: Float? + val doublePrecisionCol: Double? + val decFloatCol: Double? + val dateCol: String? + val timeCol: String? + val timeWithTimeZoneCol: String? + val timestampCol: String? + val timestampWithTimeZoneCol: String? + val intervalCol: String? val javaObjectCol: Any? - val enumCol: String - val jsonCol: String - val uuidCol: String + val enumCol: String? + val jsonCol: String? + val uuidCol: String? } class JdbcTest { @@ -82,7 +83,7 @@ class JdbcTest { DriverManager.getConnection(URL) - // Crate table Customer + // Create table Customer @Language("SQL") val createCustomerTableQuery = """ CREATE TABLE Customer ( @@ -100,7 +101,7 @@ class JdbcTest { CREATE TABLE Sale ( id INT PRIMARY KEY, customerId INT, - amount DECIMAL(10, 2) + amount DECIMAL(10, 2) NOT NULL ) """ @@ -132,6 +133,28 @@ class JdbcTest { } } + @Test + fun `read from empty table`() { + @Language("SQL") + val createTableQuery = """ + CREATE TABLE EmptyTestTable ( + characterCol CHAR(10), + characterVaryingCol VARCHAR(20) + ) + """ + + connection.createStatement().execute(createTableQuery.trimIndent()) + + val tableName = "EmptyTestTable" + + val df = DataFrame.readSqlTable(connection, tableName) + df.rowsCount() shouldBe 0 + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) + dataSchema.columns.size shouldBe 2 + dataSchema.columns["characterCol"]!!.type shouldBe typeOf() + } + @Test fun `read from huge table`() { @Language("SQL") @@ -214,7 +237,7 @@ class JdbcTest { val df = DataFrame.readSqlTable(connection, "TestTable").cast() df.rowsCount() shouldBe 3 - df.filter { it[TestTableData::integerCol] > 1000}.rowsCount() shouldBe 2 + df.filter { it[TestTableData::integerCol]!! > 1000 }.rowsCount() shouldBe 2 } @Test @@ -223,35 +246,56 @@ class JdbcTest { val df = DataFrame.readSqlTable(connection, tableName).cast() df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 df[0][1] shouldBe "John" val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() df1.rowsCount() shouldBe 1 - df1.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 df1[0][1] shouldBe "John" val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() val dbConfig = DatabaseConfiguration(url = URL) val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 df2[0][1] shouldBe "John" val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() df3.rowsCount() shouldBe 1 - df3.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df3.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 df3[0][1] shouldBe "John" val dataSchema1 = DataFrame.getSchemaForSqlTable(dbConfig, tableName) dataSchema1.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema1.columns["name"]!!.type shouldBe typeOf() + } + + // to cover a reported case from https://github.com/Kotlin/dataframe/issues/494 + @Test + fun `repeated read from table with limit`() { + val tableName = "Customer" + + for (i in 1..10) { + val df1 = DataFrame.readSqlTable(connection, tableName, 2).cast() + + df1.rowsCount() shouldBe 2 + df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 + df1[0][1] shouldBe "John" + + val dbConfig = DatabaseConfiguration(url = URL) + val df2 = DataFrame.readSqlTable(dbConfig, tableName, 2).cast() + + df2.rowsCount() shouldBe 2 + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 + df2[0][1] shouldBe "John" + } } @Test @@ -264,29 +308,29 @@ class JdbcTest { val df = DataFrame.readResultSet(rs, H2).cast() df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 df[0][1] shouldBe "John" rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2, 1).cast() + val df1 = DataFrame.readResultSet(rs, H2, 1).cast() df1.rowsCount() shouldBe 1 - df1.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 df1[0][1] shouldBe "John" rs.beforeFirst() val dataSchema = DataFrame.getSchemaForResultSet(rs, H2) dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() rs.beforeFirst() val df2 = DataFrame.readResultSet(rs, connection).cast() df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 df2[0][1] shouldBe "John" rs.beforeFirst() @@ -294,14 +338,43 @@ class JdbcTest { val df3 = DataFrame.readResultSet(rs, connection, 1).cast() df3.rowsCount() shouldBe 1 - df3.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + df3.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 df3[0][1] shouldBe "John" rs.beforeFirst() val dataSchema1 = DataFrame.getSchemaForResultSet(rs, connection) dataSchema1.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema1.columns["name"]!!.type shouldBe typeOf() + } + } + } + + // to cover a reported case from https://github.com/Kotlin/dataframe/issues/494 + @Test + fun `repeated read from ResultSet with limit`() { + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> + @Language("SQL") + val selectStatement = "SELECT * FROM Customer" + + st.executeQuery(selectStatement).use { rs -> + for (i in 1..10) { + rs.beforeFirst() + + val df1 = DataFrame.readResultSet(rs, H2, 2).cast() + + df1.rowsCount() shouldBe 2 + df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 + df1[0][1] shouldBe "John" + + rs.beforeFirst() + + val df2 = DataFrame.readResultSet(rs, connection, 2).cast() + + df2.rowsCount() shouldBe 2 + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 + df2[0][1] shouldBe "John" + } } } } @@ -313,6 +386,88 @@ class JdbcTest { } } + // to cover a reported case from https://github.com/Kotlin/dataframe/issues/498 + @Test + fun `read from incorrect SQL query`() { + @Language("SQL") + val createSQL = """ + CREATE TABLE Orders ( + order_id INT PRIMARY KEY, + customer_id INT, + order_date DATE, + total_amount DECIMAL(10, 2)) + """ + + + @Language("SQL") + val dropSQL = """ + DROP TABLE Customer + """ + + @Language("SQL") + val alterSQL = """ + ALTER TABLE Customer + ADD COLUMN email VARCHAR(100) + """ + + @Language("SQL") + val deleteSQL = """ + DELETE FROM Customer + WHERE id = 1 + """ + + @Language("SQL") + val repeatedSQL = """ + SELECT * FROM Customer + WHERE id = 1; + SELECT * FROM Customer + WHERE id = 1; + """ + + shouldThrow { + DataFrame.readSqlQuery(connection, createSQL) + } + + shouldThrow { + DataFrame.readSqlQuery(connection, dropSQL) + } + + shouldThrow { + DataFrame.readSqlQuery(connection, alterSQL) + } + + shouldThrow { + DataFrame.readSqlQuery(connection, deleteSQL) + } + + shouldThrow { + DataFrame.readSqlQuery(connection, repeatedSQL) + } + } + + @Test + fun `read from table with name from reserved SQL keywords`() { + // Create table Sale + @Language("SQL") + val createAlterTableQuery = """ + CREATE TABLE "ALTER" ( + id INT PRIMARY KEY, + description TEXT + ) + """ + + connection.createStatement().execute( + createAlterTableQuery + ) + + @Language("SQL") + val selectFromWeirdTableSQL = """ + SELECT * from "ALTER" + """ + + DataFrame.readSqlQuery(connection, selectFromWeirdTableSQL).rowsCount() shouldBe 0 + } + @Test fun `read from non-existing jdbc url`() { shouldThrow { @@ -334,39 +489,39 @@ class JdbcTest { val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() df.rowsCount() shouldBe 2 - df.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 df[0][0] shouldBe "John" val df1 = DataFrame.readSqlQuery(connection, sqlQuery, 1).cast() df1.rowsCount() shouldBe 1 - df1.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df1.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 df1[0][0] shouldBe "John" val dataSchema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) dataSchema.columns.size shouldBe 2 - dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() val dbConfig = DatabaseConfiguration(url = URL) val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery).cast() df2.rowsCount() shouldBe 2 - df2.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df2.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 df2[0][0] shouldBe "John" val df3 = DataFrame.readSqlQuery(dbConfig, sqlQuery, 1).cast() df3.rowsCount() shouldBe 1 - df3.filter { it[CustomerSales::totalSalesAmount] > 100 }.rowsCount() shouldBe 1 + df3.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 df3[0][0] shouldBe "John" val dataSchema1 = DataFrame.getSchemaForSqlQuery(dbConfig, sqlQuery) dataSchema1.columns.size shouldBe 2 - dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema1.columns["name"]!!.type shouldBe typeOf() } @Test - fun `read from sql query with repeated columns` () { + fun `read from sql query with two repeated columns`() { @Language("SQL") val sqlQuery = """ SELECT c1.name, c2.name @@ -374,9 +529,26 @@ class JdbcTest { INNER JOIN Customer c2 ON c1.id = c2.id """.trimIndent() - shouldThrow { - DataFrame.readSqlQuery(connection, sqlQuery) - } + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + schema.columns.size shouldBe 2 + schema.columns.toList()[0].first shouldBe "name" + schema.columns.toList()[1].first shouldBe "name_1" + } + + @Test + fun `read from sql query with three repeated columns`() { + @Language("SQL") + val sqlQuery = """ + SELECT c1.name as name, c2.name as name_1, c1.name as name_1 + FROM Customer c1 + INNER JOIN Customer c2 ON c1.id = c2.id + """.trimIndent() + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + schema.columns.size shouldBe 3 + schema.columns.toList()[0].first shouldBe "name" + schema.columns.toList()[1].first shouldBe "name_1" + schema.columns.toList()[2].first shouldBe "name_2" } @Test @@ -386,38 +558,38 @@ class JdbcTest { val customerDf = dataframes[0].cast() customerDf.rowsCount() shouldBe 4 - customerDf.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + customerDf.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 customerDf[0][1] shouldBe "John" val saleDf = dataframes[1].cast() saleDf.rowsCount() shouldBe 4 saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 - saleDf[0][2] shouldBe 100.5f + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - val dataframes1 = DataFrame.readAllSqlTables(connection, 1) + val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1) val customerDf1 = dataframes1[0].cast() customerDf1.rowsCount() shouldBe 1 - customerDf1.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + customerDf1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 customerDf1[0][1] shouldBe "John" val saleDf1 = dataframes1[1].cast() saleDf1.rowsCount() shouldBe 1 saleDf1.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 - saleDf1[0][2] shouldBe 100.5f + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 val dataSchemas = DataFrame.getSchemaForAllSqlTables(connection) val customerDataSchema = dataSchemas[0] customerDataSchema.columns.size shouldBe 3 - customerDataSchema.columns["name"]!!.type shouldBe typeOf() + customerDataSchema.columns["name"]!!.type shouldBe typeOf() val saleDataSchema = dataSchemas[1] saleDataSchema.columns.size shouldBe 3 - saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + saleDataSchema.columns["amount"]!!.type shouldBe typeOf() val dbConfig = DatabaseConfiguration(url = URL) val dataframes2 = DataFrame.readAllSqlTables(dbConfig) @@ -425,38 +597,38 @@ class JdbcTest { val customerDf2 = dataframes2[0].cast() customerDf2.rowsCount() shouldBe 4 - customerDf2.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 2 + customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 customerDf2[0][1] shouldBe "John" val saleDf2 = dataframes2[1].cast() saleDf2.rowsCount() shouldBe 4 saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 - saleDf2[0][2] shouldBe 100.5f + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - val dataframes3 = DataFrame.readAllSqlTables(dbConfig, 1) + val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1) val customerDf3 = dataframes3[0].cast() customerDf3.rowsCount() shouldBe 1 - customerDf3.filter { it[Customer::age] > 30 }.rowsCount() shouldBe 1 + customerDf3.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 customerDf3[0][1] shouldBe "John" val saleDf3 = dataframes3[1].cast() saleDf3.rowsCount() shouldBe 1 saleDf3.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 - saleDf3[0][2] shouldBe 100.5f + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig) val customerDataSchema1 = dataSchemas1[0] customerDataSchema1.columns.size shouldBe 3 - customerDataSchema1.columns["name"]!!.type shouldBe typeOf() + customerDataSchema1.columns["name"]!!.type shouldBe typeOf() val saleDataSchema1 = dataSchemas1[1] saleDataSchema1.columns.size shouldBe 3 - saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() + saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt index 0a7bda75bf..f5690d3e27 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/imdbTest.kt @@ -1,32 +1,35 @@ package org.jetbrains.kotlinx.dataframe.io +import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema import org.jetbrains.kotlinx.dataframe.api.cast -import org.jetbrains.kotlinx.dataframe.api.print import org.junit.Test import java.sql.DriverManager import java.util.Properties +import org.jetbrains.kotlinx.dataframe.api.filter import org.junit.Ignore +import kotlin.reflect.typeOf -private const val URL = "jdbc:mariadb://localhost:3306/imdb" +private const val URL = "jdbc:mariadb://localhost:3307/imdb" +private const val URL2 = "jdbc:mariadb://localhost:3307" private const val USER_NAME = "root" private const val PASSWORD = "pass" @DataSchema interface ActorKDF { val id: Int - val firstName: String - val lastName: String - val gender: String + val firstName: String? + val lastName: String? + val gender: String? } @DataSchema interface RankedMoviesWithGenres { - val name: String - val year: Int - val rank: Float - val genres: String + val name: String? + val year: Int? + val rank: Float? + val genres: String? } @Ignore @@ -40,12 +43,39 @@ class ImdbTestTest { // generate kdf schemas by database metadata (as interfaces or extensions) // for gradle or as classes under the hood in KNB + val tableName = "actors" + DriverManager.getConnection(URL, props).use { connection -> - val df = DataFrame.readSqlTable(connection, "actors", 100).cast() - df.print() + val df = DataFrame.readSqlTable(connection, tableName, 100).cast() + val result = df.filter { it[ActorKDF::id] in 11..19 } + result[0][1] shouldBe "Víctor" + + val schema = DataFrame.getSchemaForSqlTable(connection, tableName) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["first_name"]!!.type shouldBe typeOf() } } + @Test + fun `read table with schema name in table name`() { + val props = Properties() + props.setProperty("user", USER_NAME) + props.setProperty("password", PASSWORD) + + // generate kdf schemas by database metadata (as interfaces or extensions) + // for gradle or as classes under the hood in KNB + val imdbTableName = "imdb.actors" + + DriverManager.getConnection(URL2, props).use { connection -> + val df = DataFrame.readSqlTable(connection, imdbTableName, 100).cast() + val result = df.filter { it[ActorKDF::id] in 11..19 } + result[0][1] shouldBe "Víctor" + + val schema = DataFrame.getSchemaForSqlTable(connection, imdbTableName) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["first_name"]!!.type shouldBe typeOf() + } + } @Test fun `read sql query`() { @@ -54,6 +84,7 @@ class ImdbTestTest { "from movies join movies_directors on movie_id = movies.id\n" + " join directors on directors.id=director_id left join movies_genres on movies.id = movies_genres.movie_id \n" + "where directors.first_name = \"Quentin\" and directors.last_name = \"Tarantino\"\n" + + "and movies.name is not null and movies.name is not null\n" + "group by name, year, rank\n" + "order by year" val props = Properties() @@ -65,12 +96,13 @@ class ImdbTestTest { DriverManager.getConnection(URL, props).use { connection -> val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() - //df.filter { year > 2000 }.print() - df.print() + val result = + df.filter { it[RankedMoviesWithGenres::year] != null && it[RankedMoviesWithGenres::year]!! > 2000 } + result[0][1] shouldBe 2003 val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) - schema.print() + schema.columns["name"]!!.type shouldBe typeOf() + schema.columns["year"]!!.type shouldBe typeOf() } - } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt index abf1d3ebd8..c21ef9e549 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mariadbTest.kt @@ -6,9 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter -import org.jetbrains.kotlinx.dataframe.api.print import org.junit.AfterClass -import org.junit.Assert.assertEquals import org.junit.BeforeClass import org.junit.Test import java.math.BigDecimal @@ -16,6 +14,7 @@ import java.sql.Connection import java.sql.DriverManager import java.sql.SQLException import org.junit.Ignore +import kotlin.reflect.typeOf private const val URL = "jdbc:mariadb://localhost:3307" private const val USER_NAME = "root" @@ -54,16 +53,58 @@ interface Table1MariaDb { val mediumtextCol: String val longtextCol: String val enumCol: String - val setCol: String + val setCol: Char + val jsonCol: String } @DataSchema interface Table2MariaDb { + val id: Int + val bitCol: Boolean? + val tinyintCol: Int? + val smallintCol: Int? + val mediumintCol: Int? + val mediumintUnsignedCol: Long? + val integerCol: Int? + val intCol: Int? + val integerUnsignedCol: Long? + val bigintCol: Long? + val floatCol: Float? + val doubleCol: Double? + val decimalCol: Double? + val dateCol: String? + val datetimeCol: String? + val timestampCol: String? + val timeCol: String? + val yearCol: String? + val varcharCol: String? + val charCol: String? + val binaryCol: ByteArray? + val varbinaryCol: ByteArray? + val tinyblobCol: ByteArray? + val blobCol: ByteArray? + val mediumblobCol: ByteArray? + val longblobCol: ByteArray? + val textCol: String? + val mediumtextCol: String? + val longtextCol: String? + val enumCol: String? + val setCol: Char? + val jsonCol: String? +} + +@DataSchema +interface Table3MariaDb { val id: Int val enumCol: String - val setCol: String + val setCol: Char? } +private const val JSON_STRING = + "{\"details\": {\"foodType\": \"Pizza\", \"menu\": \"https://www.loumalnatis.com/our-menu\"}, \n" + + " \t\"favorites\": [{\"description\": \"Pepperoni deep dish\", \"price\": 18.75}, \n" + + "{\"description\": \"The Lou\", \"price\": 24.75}]}" + @Ignore class MariadbTest { companion object { @@ -95,36 +136,38 @@ class MariadbTest { val createTableQuery = """ CREATE TABLE IF NOT EXISTS table1 ( id INT AUTO_INCREMENT PRIMARY KEY, - bitCol BIT, - tinyintCol TINYINT, - smallintCol SMALLINT, - mediumintCol MEDIUMINT, - mediumintUnsignedCol MEDIUMINT UNSIGNED, - integerCol INTEGER, - intCol INT, - integerUnsignedCol INTEGER UNSIGNED, - bigintCol BIGINT, - floatCol FLOAT, - doubleCol DOUBLE, - decimalCol DECIMAL, - dateCol DATE, - datetimeCol DATETIME, - timestampCol TIMESTAMP, - timeCol TIME, - yearCol YEAR, - varcharCol VARCHAR(255), - charCol CHAR(10), - binaryCol BINARY(64), - varbinaryCol VARBINARY(128), - tinyblobCol TINYBLOB, - blobCol BLOB, - mediumblobCol MEDIUMBLOB, - longblobCol LONGBLOB, - textCol TEXT, - mediumtextCol MEDIUMTEXT, - longtextCol LONGTEXT, - enumCol ENUM('Value1', 'Value2', 'Value3'), - setCol SET('Option1', 'Option2', 'Option3') + bitCol BIT NOT NULL, + tinyintCol TINYINT NOT NULL, + smallintCol SMALLINT NOT NULL, + mediumintCol MEDIUMINT NOT NULL, + mediumintUnsignedCol MEDIUMINT UNSIGNED NOT NULL, + integerCol INTEGER NOT NULL, + intCol INT NOT NULL, + integerUnsignedCol INTEGER UNSIGNED NOT NULL, + bigintCol BIGINT NOT NULL, + floatCol FLOAT NOT NULL, + doubleCol DOUBLE NOT NULL, + decimalCol DECIMAL NOT NULL, + dateCol DATE NOT NULL, + datetimeCol DATETIME NOT NULL, + timestampCol TIMESTAMP NOT NULL, + timeCol TIME NOT NULL, + yearCol YEAR NOT NULL, + varcharCol VARCHAR(255) NOT NULL, + charCol CHAR(10) NOT NULL, + binaryCol BINARY(64) NOT NULL, + varbinaryCol VARBINARY(128) NOT NULL, + tinyblobCol TINYBLOB NOT NULL, + blobCol BLOB NOT NULL, + mediumblobCol MEDIUMBLOB NOT NULL , + longblobCol LONGBLOB NOT NULL, + textCol TEXT NOT NULL, + mediumtextCol MEDIUMTEXT NOT NULL, + longtextCol LONGTEXT NOT NULL, + enumCol ENUM('Value1', 'Value2', 'Value3') NOT NULL, + setCol SET('Option1', 'Option2', 'Option3') NOT NULL, + jsonCol JSON NOT NULL + CHECK (JSON_VALID(jsonCol)) ) """ connection.createStatement().execute( @@ -177,8 +220,8 @@ class MariadbTest { bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, - mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol, jsonCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """.trimIndent() @@ -225,6 +268,7 @@ class MariadbTest { st.setString(28, "longtextValue$i") st.setString(29, "Value$i") st.setString(30, "Option$i") + st.setString(31, JSON_STRING) st.executeUpdate() } @@ -258,8 +302,8 @@ class MariadbTest { st.setBytes(23, "blobValue".toByteArray()) st.setBytes(24, "mediumblobValue".toByteArray()) st.setBytes(25, "longblobValue".toByteArray()) - st.setString(26, "textValue$i") - st.setString(27, "mediumtextValue$i") + st.setString(26, null) + st.setString(27, null) st.setString(28, "longtextValue$i") st.setString(29, "Value$i") st.setString(30, "Option$i") @@ -285,12 +329,20 @@ class MariadbTest { @Test fun `basic test for reading sql tables`() { val df1 = DataFrame.readSqlTable(connection, "table1").cast() - df1.print() - assertEquals(3, df1.rowsCount()) + val result = df1.filter { it[Table1MariaDb::id] == 1 } + result[0][26] shouldBe "textValue1" + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["textCol"]!!.type shouldBe typeOf() - val df2 = DataFrame.readSqlTable(connection, "table2").cast() - df2.print() - assertEquals(3, df2.rowsCount()) + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + val result2 = df2.filter { it[Table2MariaDb::id] == 1 } + result2[0][26] shouldBe null + + val schema2 = DataFrame.getSchemaForSqlTable(connection, "table2") + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["textCol"]!!.type shouldBe typeOf() } @Test @@ -299,30 +351,40 @@ class MariadbTest { val sqlQuery = """ SELECT t1.id, - t2.enumCol, + t1.enumCol, t2.setCol FROM table1 t1 - JOIN table2 t2 ON t1.id = t2.id; + JOIN table2 t2 ON t1.id = t2.id """.trimIndent() - val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() - df.rowsCount() shouldBe 3 + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[Table3MariaDb::id] == 1 } + result[0][2] shouldBe "Option1" + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["enumCol"]!!.type shouldBe typeOf() + schema.columns["setCol"]!!.type shouldBe typeOf() } @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection) + val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000) val table1Df = dataframes[0].cast() table1Df.rowsCount() shouldBe 3 table1Df.filter { it[Table1MariaDb::integerCol] > 100 }.rowsCount() shouldBe 2 table1Df[0][11] shouldBe 10.0 + table1Df[0][26] shouldBe "textValue1" + table1Df[0][31] shouldBe JSON_STRING // TODO: https://github.com/Kotlin/dataframe/issues/462 - val table2Df = dataframes[1].cast() + val table2Df = dataframes[1].cast() table2Df.rowsCount() shouldBe 3 - table2Df.filter { it[Table1MariaDb::integerCol] > 400 }.rowsCount() shouldBe 1 + table2Df.filter { it[Table2MariaDb::integerCol] != null && it[Table2MariaDb::integerCol]!! > 400 } + .rowsCount() shouldBe 1 table2Df[0][11] shouldBe 20.0 + table2Df[0][26] shouldBe null } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt index bb717fdc81..059d7fe7ef 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mysqlTest.kt @@ -14,6 +14,7 @@ import java.sql.Connection import java.sql.DriverManager import java.sql.SQLException import org.junit.Ignore +import kotlin.reflect.typeOf private const val URL = "jdbc:mysql://localhost:3306" private const val USER_NAME = "root" @@ -52,14 +53,50 @@ interface Table1MySql { val mediumtextCol: String val longtextCol: String val enumCol: String - val setCol: String + val setCol: Char } @DataSchema interface Table2MySql { + val id: Int + val bitCol: Boolean? + val tinyintCol: Int? + val smallintCol: Int? + val mediumintCol: Int? + val mediumintUnsignedCol: Long? + val integerCol: Int? + val intCol: Int? + val integerUnsignedCol: Long? + val bigintCol: Long? + val floatCol: Float? + val doubleCol: Double? + val decimalCol: Double? + val dateCol: String? + val datetimeCol: String? + val timestampCol: String? + val timeCol: String? + val yearCol: String? + val varcharCol: String? + val charCol: String? + val binaryCol: ByteArray? + val varbinaryCol: ByteArray? + val tinyblobCol: ByteArray? + val blobCol: ByteArray? + val mediumblobCol: ByteArray? + val longblobCol: ByteArray? + val textCol: String? + val mediumtextCol: String? + val longtextCol: String? + val enumCol: String? + val setCol: Char? + val jsonCol: String? +} + +@DataSchema +interface Table3MySql { val id: Int val enumCol: String - val setCol: String + val setCol: Char? } @Ignore @@ -93,38 +130,39 @@ class MySqlTest { val createTableQuery = """ CREATE TABLE IF NOT EXISTS table1 ( id INT AUTO_INCREMENT PRIMARY KEY, - bitCol BIT, - tinyintCol TINYINT, - smallintCol SMALLINT, - mediumintCol MEDIUMINT, - mediumintUnsignedCol MEDIUMINT UNSIGNED, - integerCol INTEGER, - intCol INT, - integerUnsignedCol INTEGER UNSIGNED, - bigintCol BIGINT, - floatCol FLOAT, - doubleCol DOUBLE, - decimalCol DECIMAL, - dateCol DATE, - datetimeCol DATETIME, - timestampCol TIMESTAMP, - timeCol TIME, - yearCol YEAR, - varcharCol VARCHAR(255), - charCol CHAR(10), - binaryCol BINARY(64), - varbinaryCol VARBINARY(128), - tinyblobCol TINYBLOB, - blobCol BLOB, - mediumblobCol MEDIUMBLOB, - longblobCol LONGBLOB, - textCol TEXT, - mediumtextCol MEDIUMTEXT, - longtextCol LONGTEXT, - enumCol ENUM('Value1', 'Value2', 'Value3'), - setCol SET('Option1', 'Option2', 'Option3'), + bitCol BIT NOT NULL, + tinyintCol TINYINT NOT NULL, + smallintCol SMALLINT NOT NULL, + mediumintCol MEDIUMINT NOT NULL, + mediumintUnsignedCol MEDIUMINT UNSIGNED NOT NULL, + integerCol INTEGER NOT NULL, + intCol INT NOT NULL, + integerUnsignedCol INTEGER UNSIGNED NOT NULL, + bigintCol BIGINT NOT NULL, + floatCol FLOAT NOT NULL, + doubleCol DOUBLE NOT NULL, + decimalCol DECIMAL NOT NULL, + dateCol DATE NOT NULL, + datetimeCol DATETIME NOT NULL, + timestampCol TIMESTAMP NOT NULL, + timeCol TIME NOT NULL, + yearCol YEAR NOT NULL, + varcharCol VARCHAR(255) NOT NULL, + charCol CHAR(10) NOT NULL, + binaryCol BINARY(64) NOT NULL, + varbinaryCol VARBINARY(128) NOT NULL, + tinyblobCol TINYBLOB NOT NULL, + blobCol BLOB NOT NULL, + mediumblobCol MEDIUMBLOB NOT NULL , + longblobCol LONGBLOB NOT NULL, + textCol TEXT NOT NULL, + mediumtextCol MEDIUMTEXT NOT NULL, + longtextCol LONGTEXT NOT NULL, + enumCol ENUM('Value1', 'Value2', 'Value3') NOT NULL, + setCol SET('Option1', 'Option2', 'Option3') NOT NULL, location GEOMETRY, data JSON + CHECK (JSON_VALID(data)) ) """ @@ -168,6 +206,7 @@ class MySqlTest { setCol SET('Option1', 'Option2', 'Option3'), location GEOMETRY, data JSON + CHECK (JSON_VALID(data)) ) """ @@ -261,8 +300,8 @@ class MySqlTest { st.setBytes(23, "blobValue".toByteArray()) st.setBytes(24, "mediumblobValue".toByteArray()) st.setBytes(25, "longblobValue".toByteArray()) - st.setString(26, "textValue$i") - st.setString(27, "mediumtextValue$i") + st.setString(26, null) + st.setString(27, null) st.setString(28, "longtextValue$i") st.setString(29, "Value$i") st.setString(30, "Option$i") @@ -288,13 +327,21 @@ class MySqlTest { @Test fun `basic test for reading sql tables`() { - val df1 = DataFrame.readSqlTable(connection, "table1").cast() - df1.rowsCount() shouldBe 3 + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + val result = df1.filter { it[Table1MySql::id] == 1 } + result[0][26] shouldBe "textValue1" - val df2 = DataFrame.readSqlTable(connection, "table2").cast() - df2.rowsCount() shouldBe 3 + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["textCol"]!!.type shouldBe typeOf() - //TODO: add test for JSON column + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + val result2 = df2.filter { it[Table2MySql::id] == 1 } + result2[0][26] shouldBe null + + val schema2 = DataFrame.getSchemaForSqlTable(connection, "table2") + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["textCol"]!!.type shouldBe typeOf() } @Test @@ -303,14 +350,20 @@ class MySqlTest { val sqlQuery = """ SELECT t1.id, - t2.enumCol, + t1.enumCol, t2.setCol FROM table1 t1 - JOIN table2 t2 ON t1.id = t2.id; + JOIN table2 t2 ON t1.id = t2.id """.trimIndent() - val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() - df.rowsCount() shouldBe 3 + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[Table3MySql::id] == 1 } + result[0][2] shouldBe "Option1" + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["enumCol"]!!.type shouldBe typeOf() + schema.columns["setCol"]!!.type shouldBe typeOf() } @Test @@ -320,13 +373,16 @@ class MySqlTest { val table1Df = dataframes[0].cast() table1Df.rowsCount() shouldBe 3 - table1Df.filter { it[Table1MariaDb::integerCol] > 100 }.rowsCount() shouldBe 2 + table1Df.filter { it[Table1MySql::integerCol] > 100 }.rowsCount() shouldBe 2 table1Df[0][11] shouldBe 10.0 + table1Df[0][26] shouldBe "textValue1" - val table2Df = dataframes[1].cast() + val table2Df = dataframes[1].cast() table2Df.rowsCount() shouldBe 3 - table2Df.filter { it[Table1MariaDb::integerCol] > 400 }.rowsCount() shouldBe 1 + table2Df.filter { it[Table2MySql::integerCol] != null && it[Table2MySql::integerCol]!! > 400 } + .rowsCount() shouldBe 1 table2Df[0][11] shouldBe 20.0 + table2Df[0][26] shouldBe null } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt index 198e54283b..4c916c777a 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/postgresTest.kt @@ -16,6 +16,7 @@ import java.sql.DriverManager import java.sql.SQLException import java.util.UUID import org.junit.Ignore +import kotlin.reflect.typeOf private const val URL = "jdbc:postgresql://localhost:5432/test" private const val USER_NAME = "postgres" @@ -35,7 +36,7 @@ interface Table1 { val circlecol: String val datecol: java.sql.Date val doublecol: Double - val integercol: Int + val integercol: Int? val intervalcol: String val jsoncol: String val jsonbcol: String @@ -44,19 +45,19 @@ interface Table1 { @DataSchema interface Table2 { val id: Int - val linecol: String + val linecol: org.postgresql.geometric.PGline val lsegcol: String val macaddrcol: String val moneycol: String val numericcol: String - val pathcol: String + val pathcol: org.postgresql.geometric.PGpath val pointcol: String val polygoncol: String val realcol: Float val smallintcol: Short val smallserialcol: Int val serialcol: Int - val textcol: String + val textcol: String? val timecol: String val timewithzonecol: String val timestampcol: String @@ -70,7 +71,7 @@ interface ViewTable { val id: Int val bigintcol: Long val linecol: String - val numericcol: String + val textCol: String? } @Ignore @@ -90,21 +91,21 @@ class PostgresTest { val createTableStatement = """ CREATE TABLE IF NOT EXISTS table1 ( id serial PRIMARY KEY, - bigintCol bigint, - bigserialCol bigserial, - booleanCol boolean, - boxCol box, - byteaCol bytea, - characterCol character, - characterNCol character(10), - charCol char, - circleCol circle, - dateCol date, - doubleCol double precision, + bigintCol bigint not null, + bigserialCol bigserial not null, + booleanCol boolean not null, + boxCol box not null, + byteaCol bytea not null, + characterCol character not null, + characterNCol character(10) not null, + charCol char not null, + circleCol circle not null, + dateCol date not null, + doubleCol double precision not null, integerCol integer, - intervalCol interval, - jsonCol json, - jsonbCol jsonb + intervalCol interval not null, + jsonCol json not null, + jsonbCol jsonb not null ) """ connection.createStatement().execute( @@ -115,25 +116,25 @@ class PostgresTest { val createTableQuery = """ CREATE TABLE IF NOT EXISTS table2 ( id serial PRIMARY KEY, - lineCol line, - lsegCol lseg, - macaddrCol macaddr, - moneyCol money, - numericCol numeric, - pathCol path, - pointCol point, - polygonCol polygon, - realCol real, - smallintCol smallint, - smallserialCol smallserial, - serialCol serial, + lineCol line not null, + lsegCol lseg not null, + macaddrCol macaddr not null, + moneyCol money not null, + numericCol numeric not null, + pathCol path not null, + pointCol point not null, + polygonCol polygon not null, + realCol real not null, + smallintCol smallint not null, + smallserialCol smallserial not null, + serialCol serial not null, textCol text, - timeCol time, - timeWithZoneCol time with time zone, - timestampCol timestamp, - timestampWithZoneCol timestamp with time zone, - uuidCol uuid, - xmlCol xml + timeCol time not null, + timeWithZoneCol time with time zone not null, + timestampCol timestamp not null, + timestampWithZoneCol timestamp with time zone not null, + uuidCol uuid not null, + xmlCol xml not null ) """ connection.createStatement().execute( @@ -208,7 +209,7 @@ class PostgresTest { st.setShort(10, (i * 100).toShort()) st.setInt(11, 1000 + i) st.setInt(12, 1000000 + i) - st.setString(13, "Text data $i") + st.setString(13, null) st.setTime(14, java.sql.Time.valueOf("12:34:56")) st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) @@ -241,11 +242,27 @@ class PostgresTest { @Test fun `read from tables`() { - val df1 = DataFrame.readSqlTable(connection, "table1").cast() - df1.rowsCount() shouldBe 3 - - val df2 = DataFrame.readSqlTable(connection, "table2").cast() - df2.rowsCount() shouldBe 3 + val tableName1 = "table1" + val df1 = DataFrame.readSqlTable(connection, tableName1).cast() + val result = df1.filter { it[Table1::id] == 1 } + result[0][12] shouldBe 12345 + + val schema = DataFrame.getSchemaForSqlTable(connection, tableName1) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["integercol"]!!.type shouldBe typeOf() + schema.columns["circlecol"]!!.type shouldBe typeOf() + + val tableName2 = "table2" + val df2 = DataFrame.readSqlTable(connection, tableName2).cast() + val result2 = df2.filter { it[Table2::id] == 1 } + result2[0][11] shouldBe 1001 + result2[0][13] shouldBe null + + val schema2 = DataFrame.getSchemaForSqlTable(connection, tableName2) + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["pathcol"]!!.type shouldBe typeOf() // TODO: https://github.com/Kotlin/dataframe/issues/537 + schema2.columns["textcol"]!!.type shouldBe typeOf() + schema2.columns["linecol"]!!.type shouldBe typeOf() // TODO: https://github.com/Kotlin/dataframe/issues/537 } @Test @@ -253,16 +270,22 @@ class PostgresTest { @Language("SQL") val sqlQuery = """ SELECT - t1.id AS t1_id, + t1.id, t1.bigintCol, t2.lineCol, - t2.numericCol + t2.textCol FROM table1 t1 - JOIN table2 t2 ON t1.id = t2.id; + JOIN table2 t2 ON t1.id = t2.id """.trimIndent() - val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() - df.rowsCount() shouldBe 3 + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[ViewTable::id] == 1 } + result[0][3] shouldBe null + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["bigintcol"]!!.type shouldBe typeOf() + schema.columns["textcol"]!!.type shouldBe typeOf() } @Test @@ -272,15 +295,14 @@ class PostgresTest { val table1Df = dataframes[0].cast() table1Df.rowsCount() shouldBe 3 - table1Df.filter { it[Table1::integercol] > 12345 }.rowsCount() shouldBe 2 + table1Df.filter { it[Table1::integercol] != null && it[Table1::integercol]!! > 12345 }.rowsCount() shouldBe 2 table1Df[0][1] shouldBe 1000L val table2Df = dataframes[1].cast() table2Df.rowsCount() shouldBe 3 - table2Df.filter { it[Table2::pathcol] == "((1,2),(3,1))" }.rowsCount() shouldBe 1 + table2Df.filter { it[Table2::pathcol] == org.postgresql.geometric.PGpath("((1,2),(3,1))") } + .rowsCount() shouldBe 1 table2Df[0][11] shouldBe 1001 - - //TODO: add test for JSON column } } diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt index afb397ca8f..0cf2bf0757 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/sqliteTest.kt @@ -13,38 +13,39 @@ import java.sql.Connection import java.sql.DriverManager import java.sql.SQLException import org.junit.Ignore +import kotlin.reflect.typeOf private const val DATABASE_URL = "jdbc:sqlite:" @DataSchema interface CustomerSQLite { - val id: Int - val name: String - val age: Int + val id: Int? + val name: String? + val age: Int? val salary: Double - val profilePicture: ByteArray + val profilePicture: ByteArray? } @DataSchema interface OrderSQLite { - val id: Int - val customerName: String - val orderDate: String + val id: Int? + val customerName: String? + val orderDate: String? val totalAmount: Double - val orderDetails: ByteArray + val orderDetails: ByteArray? } @DataSchema interface CustomerOrderSQLite { - val customerId: Int - val customerName: String - val customerAge: Int + val customerId: Int? + val customerName: String? + val customerAge: Int? val customerSalary: Double - val customerProfilePicture: ByteArray - val orderId: Int - val orderDate: String + val customerProfilePicture: ByteArray? + val orderId: Int? + val orderDate: String? val totalAmount: Double - val orderDetails: ByteArray + val orderDetails: ByteArray? } @Ignore @@ -60,10 +61,10 @@ class SqliteTest { @Language("SQL") val createCustomersTableQuery = """ CREATE TABLE Customers ( - id INTEGER AUTO_INCREMENT PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT, age INTEGER, - salary REAL, + salary REAL NOT NULL, profilePicture BLOB ) """ @@ -75,10 +76,10 @@ class SqliteTest { @Language("SQL") val createOrderTableQuery = """ CREATE TABLE Orders ( - id INTEGER AUTO_INCREMENT PRIMARY KEY, + id INTEGER PRIMARY KEY, customerName TEXT, orderDate TEXT, - totalAmount NUMERIC, + totalAmount NUMERIC NOT NULL, orderDetails BLOB ) """ @@ -101,7 +102,7 @@ class SqliteTest { connection.prepareStatement("INSERT INTO Customers (name, age, salary, profilePicture) VALUES (?, ?, ?, ?)") .use { - it.setString(1, "Max Joint") + it.setString(1, null) it.setInt(2, 40) it.setDouble(3, 1500.50) it.setBytes(4, profilePicture) @@ -110,7 +111,7 @@ class SqliteTest { connection.prepareStatement("INSERT INTO Orders (customerName, orderDate, totalAmount, orderDetails) VALUES (?, ?, ?, ?)") .use { - it.setString(1, "John Doe") + it.setString(1, null) it.setString(2, "2023-07-21") it.setDouble(3, 150.75) it.setBytes(4, orderDetails) @@ -119,7 +120,7 @@ class SqliteTest { connection.prepareStatement("INSERT INTO Orders (customerName, orderDate, totalAmount, orderDetails) VALUES (?, ?, ?, ?)") .use { - it.setString(1, "Max Joint") + it.setString(1, "John Doe") it.setString(2, "2023-08-21") it.setDouble(3, 250.75) it.setBytes(4, orderDetails) @@ -140,11 +141,25 @@ class SqliteTest { @Test fun `read from tables`() { - val df = DataFrame.readSqlTable(connection, "Customers").cast() - df.rowsCount() shouldBe 2 - - val df2 = DataFrame.readSqlTable(connection, "Orders").cast() - df2.rowsCount() shouldBe 2 + val customerTableName = "Customers" + val df = DataFrame.readSqlTable(connection, customerTableName).cast() + val result = df.filter { it[CustomerSQLite::name] == "John Doe" } + result[0][2] shouldBe 30 + + val schema = DataFrame.getSchemaForSqlTable(connection, customerTableName) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["name"]!!.type shouldBe typeOf() + schema.columns["salary"]!!.type shouldBe typeOf() + + val orderTableName = "Orders" + val df2 = DataFrame.readSqlTable(connection, orderTableName).cast() + val result2 = df2.filter { it[OrderSQLite::totalAmount] > 10 } + result2[0][2] shouldBe "2023-07-21" + + val schema2 = DataFrame.getSchemaForSqlTable(connection, orderTableName) + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["customerName"]!!.type shouldBe typeOf() + schema2.columns["totalAmount"]!!.type shouldBe typeOf() } @Test @@ -166,10 +181,14 @@ class SqliteTest { """ val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() - df.rowsCount() shouldBe 2 - - val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) - schema.columns.entries.size shouldBe 9 + val result = df.filter { it[CustomerOrderSQLite::customerSalary] > 1 } + result[0][3] shouldBe 2500.5 + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["customerId"]!!.type shouldBe typeOf() + schema.columns["customerName"]!!.type shouldBe typeOf() + schema.columns["customerAge"]!!.type shouldBe typeOf() + schema.columns["totalAmount"]!!.type shouldBe typeOf() } @Test @@ -179,13 +198,13 @@ class SqliteTest { val customerDf = dataframes[0].cast() customerDf.rowsCount() shouldBe 2 - customerDf.filter { it[CustomerSQLite::age] > 30 }.rowsCount() shouldBe 1 + customerDf.filter { it[CustomerSQLite::age] != null && it[CustomerSQLite::age]!! > 30 }.rowsCount() shouldBe 1 customerDf[0][1] shouldBe "John Doe" val orderDf = dataframes[1].cast() orderDf.rowsCount() shouldBe 2 orderDf.filter { it[OrderSQLite::totalAmount] > 200 }.rowsCount() shouldBe 1 - orderDf[0][1] shouldBe "John Doe" + orderDf[0][1] shouldBe null } } diff --git a/docs/StardustDocs/topics/readSqlDatabases.md b/docs/StardustDocs/topics/readSqlDatabases.md index 37467905f4..a79f4a455c 100644 --- a/docs/StardustDocs/topics/readSqlDatabases.md +++ b/docs/StardustDocs/topics/readSqlDatabases.md @@ -4,15 +4,15 @@ These functions allow you to interact with an SQL database using a Kotlin DataFr There are two main blocks of available functionality: * reading data from the database - * reading specific tables - * executing SQL queries - * reading from ResultSet - * reading entire tables (all non-system tables) + * function ```readSqlTable``` reads specific database table + * function ```readSqlQuery``` executes SQL query + * function ```readResultSet``` reads from created earlier ResultSet + * function ```readAllSqlTables``` reads all tables (all non-system tables) * schema retrieval - * for specific tables - * for result of executing SQL queries - * for rows reading through the given ResultSet - * for all non-system tables + * ```getSchemaForSqlTable``` for specific tables + * ```getSchemaForSqlQuery``` for result of executing SQL queries + * ```getSchemaForResultSet``` for rows reading through the given ResultSet + * ```getSchemaForAllSqlTables``` for all non-system tables ## Getting started with reading from SQL database @@ -23,6 +23,32 @@ In the first, you need to add a dependency implementation("org.jetbrains.kotlinx:dataframe-jdbc:$dataframe_version") ``` +after that, you need to add a dependency for a JDBC driver for the used database, for example + +For MariaDB: + +```kotlin +implementation("org.mariadb.jdbc:mariadb-java-client:3.1.4") +``` + +For PostgreSQL: + +```kotlin +implementation("org.postgresql:postgresql:42.6.0") +``` + +For MySQL: + +```kotlin +implementation("mysql:mysql-connector-java:8.0.33") +``` + +For SQLite: + +```kotlin +implementation("org.xerial:sqlite-jdbc:3.42.0.1") +``` + In the second, be sure that you can establish a connection to the database. For this, usually, you need to have three things: a URL to a database, a username and a password. @@ -250,7 +276,7 @@ connection.close() These functions read all data from all tables in the connected database. Variants with a limit parameter restrict how many rows will be read from each table. -**readAllSqlTables(connection: Connection): List** +**readAllSqlTables(connection: Connection): List\** Retrieves data from all the non-system tables in the SQL database and returns them as a list of AnyFrame objects. @@ -265,7 +291,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS val dataframes = DataFrame.readAllSqlTables(dbConfig) ``` -**readAllSqlTables(connection: Connection, limit: Int): List** +**readAllSqlTables(connection: Connection, limit: Int): List\** A variant of the previous function, but with an added `limit: Int` parameter that allows setting the maximum number of records to be read from each table. @@ -280,7 +306,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS val dataframes = DataFrame.readAllSqlTables(dbConfig, 100) ``` -**readAllSqlTables(connection: Connection): List** +**readAllSqlTables(connection: Connection): List\** Another variant, where instead of `dbConfig: DatabaseConfiguration` we use a JDBC connection: `Connection` object. @@ -295,7 +321,7 @@ val dataframes = DataFrame.readAllSqlTables(connection) connection.close() ``` -**readAllSqlTables(connection: Connection, limit: Int): List** +**readAllSqlTables(connection: Connection, limit: Int): List\** A variant of the previous function, but with an added `limit: Int` parameter that allows setting the maximum number of records to be read from each table. @@ -428,7 +454,7 @@ connection.close() These functions return a list of all [`DataFrameSchema`](schema.md) from all the non-system tables in the SQL database. They can be called with either a database configuration or a connection. -**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List** +**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List\** This function retrieves the schema of all tables from an SQL database and returns them as a list of [`DataFrameSchema`](schema.md). @@ -444,7 +470,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS val schemas = DataFrame.getSchemaForAllSqlTables(dbConfig) ``` -**getSchemaForAllSqlTables(connection: Connection): List** This function retrieves the schema of all tables using a JDBC connection: `Connection` object and returns them as a list of [`DataFrameSchema`](schema.md). diff --git a/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt b/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt index 381f682225..5faa5ffffe 100644 --- a/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt +++ b/plugins/dataframe-gradle-plugin/src/integrationTest/kotlin/org/jetbrains/dataframe/gradle/SchemaGeneratorPluginIntegrationTest.kt @@ -411,39 +411,39 @@ class SchemaGeneratorPluginIntegrationTest : AbstractDataFramePluginIntegrationT // DataFrameJdbcSymbolProcessorTest.`schema extracted via readFromDB method is resolved` main.writeText( """ - @file:ImportDataSchema(name = "Customer", path = "$connectionUrl") - - package test - - import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema - import org.jetbrains.kotlinx.dataframe.api.filter - import org.jetbrains.kotlinx.dataframe.DataFrame - import org.jetbrains.kotlinx.dataframe.api.cast - import java.sql.Connection - import java.sql.DriverManager - import java.sql.SQLException - import org.jetbrains.kotlinx.dataframe.io.readSqlTable - import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration - - fun main() { - Class.forName("org.h2.Driver") - val tableName = "Customer" - DriverManager.getConnection("$connectionUrl").use { connection -> - val df = DataFrame.readSqlTable(connection, tableName).cast() - df.filter { age > 30 } - - val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() - df1.filter { age > 30 } - - val dbConfig = DatabaseConfiguration(url = "$connectionUrl") - val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() - df2.filter { age > 30 } - - val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() - df3.filter { age > 30 } - - } + @file:ImportDataSchema(name = "Customer", path = "$connectionUrl") + + package test + + import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema + import org.jetbrains.kotlinx.dataframe.api.filter + import org.jetbrains.kotlinx.dataframe.DataFrame + import org.jetbrains.kotlinx.dataframe.api.cast + import java.sql.Connection + import java.sql.DriverManager + import java.sql.SQLException + import org.jetbrains.kotlinx.dataframe.io.readSqlTable + import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration + + fun main() { + Class.forName("org.h2.Driver") + val tableName = "Customer" + DriverManager.getConnection("$connectionUrl").use { connection -> + val df = DataFrame.readSqlTable(connection, tableName).cast() + df.filter { age != null && age > 30 } + + val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() + df1.filter { age != null && age > 30 } + + val dbConfig = DatabaseConfiguration(url = "$connectionUrl") + val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() + df2.filter { age != null && age > 30 } + + val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() + df3.filter { age != null && age > 30 } + } + } """.trimIndent() ) diff --git a/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt b/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt index 9be7ec101c..1a9ad70f78 100644 --- a/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt +++ b/plugins/symbol-processor/src/test/kotlin/org/jetbrains/dataframe/ksp/DataFrameJdbcSymbolProcessorTest.kt @@ -167,43 +167,43 @@ class DataFrameJdbcSymbolProcessorTest { SourceFile.kotlin( "MySources.kt", """ - @file:ImportDataSchema( - "Customer", - "$CONNECTION_URL", - jdbcOptions = JdbcOptions("", "", tableName = "Customer") - ) + @file:ImportDataSchema( + "Customer", + "$CONNECTION_URL", + jdbcOptions = JdbcOptions("", "", tableName = "Customer") + ) + + package test + + import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema + import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions + import org.jetbrains.kotlinx.dataframe.api.filter + import org.jetbrains.kotlinx.dataframe.DataFrame + import org.jetbrains.kotlinx.dataframe.api.cast + import java.sql.Connection + import java.sql.DriverManager + import java.sql.SQLException + import org.jetbrains.kotlinx.dataframe.io.readSqlTable + import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration + + fun main() { + val tableName = "Customer" + DriverManager.getConnection("$CONNECTION_URL").use { connection -> + val df = DataFrame.readSqlTable(connection, tableName).cast() + df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 } + + val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() + df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 } - package test + val dbConfig = DatabaseConfiguration(url = "$CONNECTION_URL") + val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() + df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 } - import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema - import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions - import org.jetbrains.kotlinx.dataframe.api.filter - import org.jetbrains.kotlinx.dataframe.DataFrame - import org.jetbrains.kotlinx.dataframe.api.cast - import java.sql.Connection - import java.sql.DriverManager - import java.sql.SQLException - import org.jetbrains.kotlinx.dataframe.io.readSqlTable - import org.jetbrains.kotlinx.dataframe.io.DatabaseConfiguration - - fun main() { - val tableName = "Customer" - DriverManager.getConnection("$CONNECTION_URL").use { connection -> - val df = DataFrame.readSqlTable(connection, tableName).cast() - df.filter { age > 30 } - - val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() - df1.filter { age > 30 } - - val dbConfig = DatabaseConfiguration(url = "$CONNECTION_URL") - val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() - df2.filter { age > 30 } - - val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() - df3.filter { age > 30 } - - } - } + val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() + df3.filter { it[Customer::age] != null && it[Customer::age]!! > 30 } + + } + } """.trimIndent() ) )