diff --git a/core/api/core.api b/core/api/core.api index 98a4dce3ae..36eb4607ae 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1387,6 +1387,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt { public static final fun columnGroupTyped (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Lorg/jetbrains/kotlinx/dataframe/columns/ColumnPath;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnAccessor; public static final fun columnOf (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun columnOf (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn; + public static final fun columnOf ([Lkotlin/Pair;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup; public static final fun columnOf ([Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn; public static final fun columnOf ([Lorg/jetbrains/kotlinx/dataframe/columns/BaseColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn; public static final fun dataFrameOf (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; @@ -1397,6 +1398,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt { public static final fun dataFrameOf ([Lkotlin/Pair;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun dataFrameOf ([Lorg/jetbrains/kotlinx/dataframe/columns/BaseColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun dataFrameOf ([Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)Lorg/jetbrains/kotlinx/dataframe/api/DataFrameBuilder; + public static final fun dataFrameOfColumns ([Lkotlin/Pair;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun emptyDataFrame ()Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun frameColumn ()Lorg/jetbrains/kotlinx/dataframe/api/ColumnDelegate; public static final fun frameColumn (Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnAccessor; diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/constructors.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/constructors.kt index 6b1248c44b..e56038470d 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/constructors.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/constructors.kt @@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload import org.jetbrains.kotlinx.dataframe.annotations.Interpretable import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor +import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnPath import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.FrameColumn @@ -269,6 +270,15 @@ public inline fun column(values: Iterable): DataColumn = allColsMakesColGroup = true, ).forceResolve() +@Refine +@Interpretable("ColumnOfPairs") +public fun columnOf(vararg columns: Pair): ColumnGroup<*> = + dataFrameOf( + columns.map { (name, col) -> + col.rename(name) + }, + ).asColumnGroup() + // endregion // region create DataFrame @@ -290,6 +300,12 @@ public fun dataFrameOf(columns: Iterable): DataFrame<*> { return DataFrameImpl(cols, nrow) } +@Refine +@JvmName("dataFrameOfColumns") +@Interpretable("DataFrameOfPairs") +public fun dataFrameOf(vararg columns: Pair): DataFrame<*> = + dataFrameOf(columns.map { (name, col) -> col.rename(name) }) + public fun dataFrameOf(vararg header: ColumnReference<*>): DataFrameBuilder = DataFrameBuilder(header.map { it.name() }) public fun dataFrameOf(vararg columns: AnyBaseCol): DataFrame<*> = dataFrameOf(columns.asIterable()) diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Create.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Create.kt index 4c533349a9..f99d2dfeeb 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Create.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/samples/api/Create.kt @@ -261,6 +261,21 @@ class Create : TestBase() { // SampleEnd } + @Test + @TransformDataFrameExpressions + fun createNestedDataFrameInplace() { + // SampleStart + // DataFrame with 2 columns and 3 rows + val df = dataFrameOf( + "name" to columnOf( + "firstName" to columnOf("Alice", "Bob", "Charlie"), + "lastName" to columnOf("Cooper", "Dylan", "Daniels"), + ), + "age" to columnOf(15, 20, 100), + ) + // SampleEnd + } + @Test @TransformDataFrameExpressions fun createDataFrameWithFill() { diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt index 13e410bcb9..d566ab8469 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt @@ -1934,6 +1934,21 @@ class DataFrameTests : BaseTest() { df.columns().forEach { col -> col.forEachIndexed { row, value -> value shouldBe row + 1 } } } + @Test + fun `create nested dataframe inplace`() { + val df = dataFrameOf( + "a" to columnOf("1"), + "b" to columnOf( + "c" to columnOf("2"), + ), + "d" to columnOf(dataFrameOf("a")(123)), + "gr" to listOf("1").toDataFrame().asColumnGroup(), + ) + + df.columnNames() shouldBe listOf("a", "b", "d", "gr") + df.getColumnGroup("gr")["value"].values() shouldBe listOf("1") + } + @Test fun `get typed column by name`() { val col = df.getColumn("name").cast() diff --git a/docs/StardustDocs/topics/createDataFrame.md b/docs/StardustDocs/topics/createDataFrame.md index 78092bb2b1..8942d5b36e 100644 --- a/docs/StardustDocs/topics/createDataFrame.md +++ b/docs/StardustDocs/topics/createDataFrame.md @@ -44,6 +44,23 @@ val df = dataFrameOf( +Create DataFrame with nested columns inplace: + + + +```kotlin +// DataFrame with 2 columns and 3 rows +val df = dataFrameOf( + "name" to columnOf( + "firstName" to columnOf("Alice", "Bob", "Charlie"), + "lastName" to columnOf("Cooper", "Dylan", "Daniels"), + ), + "age" to columnOf(15, 20, 100), +) +``` + + + ```kotlin diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt index 6d7eab7167..09ac5a164f 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/FunctionCallTransformer.kt @@ -101,7 +101,13 @@ class FunctionCallTransformer( fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? } - private val transformers = listOf(GroupByCallTransformer(), DataFrameCallTransformer(), DataRowCallTransformer()) + // also update [ReturnTypeBasedReceiverInjector.SCHEMA_TYPES] + private val transformers = listOf( + GroupByCallTransformer(), + DataFrameCallTransformer(), + DataRowCallTransformer(), + ColumnGroupCallTransformer(), + ) override fun intercept(callInfo: CallInfo, symbol: FirNamedFunctionSymbol): CallReturnType? { val callSiteAnnotations = (callInfo.callSite as? FirAnnotationContainer)?.annotations ?: emptyList() @@ -194,6 +200,8 @@ class FunctionCallTransformer( inner class DataRowCallTransformer : CallTransformer by DataSchemaLikeCallTransformer(Names.DATA_ROW_CLASS_ID) + inner class ColumnGroupCallTransformer : CallTransformer by DataSchemaLikeCallTransformer(Names.COLUM_GROUP_CLASS_ID) + inner class GroupByCallTransformer : CallTransformer { override fun interceptOrNull( callInfo: CallInfo, diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt index b522951c31..1624d200a0 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/ReturnTypeBasedReceiverInjector.kt @@ -16,10 +16,19 @@ import org.jetbrains.kotlin.fir.types.toRegularClassSymbol import org.jetbrains.kotlinx.dataframe.plugin.utils.Names class ReturnTypeBasedReceiverInjector(session: FirSession) : FirExpressionResolutionExtension(session) { + companion object { + private val SCHEMA_TYPES = setOf( + Names.DF_CLASS_ID, + Names.GROUP_BY_CLASS_ID, + Names.DATA_ROW_CLASS_ID, + Names.COLUM_GROUP_CLASS_ID, + ) + } + @OptIn(SymbolInternals::class) override fun addNewImplicitReceivers(functionCall: FirFunctionCall): List { val callReturnType = functionCall.resolvedType - return if (callReturnType.classId in setOf(Names.DF_CLASS_ID, Names.GROUP_BY_CLASS_ID, Names.DATA_ROW_CLASS_ID)) { + return if (callReturnType.classId in SCHEMA_TYPES) { val typeArguments = callReturnType.typeArguments typeArguments .mapNotNull { diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/dataFrameOf.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/dataFrameOf.kt index 1eeeca8c83..30bd5befd9 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/dataFrameOf.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/dataFrameOf.kt @@ -4,6 +4,9 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression +import org.jetbrains.kotlin.fir.plugin.createConeType +import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.classId import org.jetbrains.kotlin.fir.types.commonSuperTypeOrNull import org.jetbrains.kotlin.fir.types.resolvedType import org.jetbrains.kotlin.fir.types.type @@ -15,6 +18,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.impl.api.withValuesImpl +import org.jetbrains.kotlinx.dataframe.plugin.utils.Names class DataFrameOf0 : AbstractInterpreter() { val Arguments.header: List by arg() @@ -53,3 +57,30 @@ class DataFrameOf3 : AbstractSchemaModificationInterpreter() { return PluginDataFrameSchema(res) } } + +abstract class SchemaConstructor : AbstractSchemaModificationInterpreter() { + val Arguments.columns: List>> by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val res = columns.map { + val it = it.value + val name = (it.first as? FirLiteralExpression)?.value as? String + val resolvedType = (it.second as? FirExpression)?.resolvedType + val type: ConeKotlinType? = when (resolvedType?.classId) { + Names.COLUM_GROUP_CLASS_ID -> Names.DATA_ROW_CLASS_ID.createConeType(session, arrayOf(resolvedType.typeArguments[0])) + Names.FRAME_COLUMN_CLASS_ID -> Names.DF_CLASS_ID.createConeType(session, arrayOf(resolvedType.typeArguments[0])) + Names.DATA_COLUMN_CLASS_ID -> resolvedType.typeArguments[0] as? ConeKotlinType + Names.BASE_COLUMN_CLASS_ID -> resolvedType.typeArguments[0] as? ConeKotlinType + else -> null + } + if (name == null || type == null) return PluginDataFrameSchema(emptyList()) + simpleColumnOf(name, type) + } + return PluginDataFrameSchema(res) + } +} + +class DataFrameOfPairs : SchemaConstructor() + +class ColumnOfPairs : SchemaConstructor() + diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 2d914d3e5e..4ac873c8a1 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -96,11 +96,13 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf2 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnOfPairs import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnRange import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOfPairs import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameUnfold import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameXs import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop0 @@ -409,6 +411,8 @@ internal inline fun String.load(): T { "toDataFrameDefault" -> ToDataFrameDefault() "ToDataFrameDslStringInvoke" -> ToDataFrameDslStringInvoke() "DataFrameOf0" -> DataFrameOf0() + "DataFrameOfPairs" -> DataFrameOfPairs() + "ColumnOfPairs" -> ColumnOfPairs() "DataFrameBuilderInvoke0" -> DataFrameBuilderInvoke0() "ToDataFrameColumn" -> ToDataFrameColumn() "FillNulls0" -> FillNulls0() diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt index e97c105718..249d7e2e93 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/utils/Names.kt @@ -31,11 +31,18 @@ object Names { val COLUM_GROUP_CLASS_ID: ClassId get() = ClassId(FqName("org.jetbrains.kotlinx.dataframe.columns"), Name.identifier("ColumnGroup")) + val FRAME_COLUMN_CLASS_ID: ClassId + get() = ClassId(FqName("org.jetbrains.kotlinx.dataframe.columns"), Name.identifier("FrameColumn")) val DATA_COLUMN_CLASS_ID: ClassId get() = ClassId( FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe")), Name.identifier("DataColumn") ) + val BASE_COLUMN_CLASS_ID: ClassId + get() = ClassId( + FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe", "columns")), + Name.identifier("BaseColumn") + ) val COLUMNS_CONTAINER_CLASS_ID: ClassId get() = ClassId( FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe")), diff --git a/plugins/kotlin-dataframe/testData/box/columnOf_nested.kt b/plugins/kotlin-dataframe/testData/box/columnOf_nested.kt new file mode 100644 index 0000000000..28b102a83b --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/columnOf_nested.kt @@ -0,0 +1,15 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val group = columnOf( + "c" to columnOf("2"), + "d" to columnOf(123), + ) + val str: DataColumn = group.c + val i: DataColumn = group.d + + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/dataFrameOf_nested.kt b/plugins/kotlin-dataframe/testData/box/dataFrameOf_nested.kt new file mode 100644 index 0000000000..e0235a1c85 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/dataFrameOf_nested.kt @@ -0,0 +1,20 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf( + "a" to columnOf("1"), + "b" to columnOf( + "c" to columnOf("2"), + ), + "d" to columnOf(dataFrameOf("a")(123)), + "gr" to listOf("1").toDataFrame().asColumnGroup(), + ) + val str: DataColumn = df.a + val str1: DataColumn = df.b.c + val i: DataColumn = df.d[0].a + val str2: DataColumn = df.gr.value + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 77a97df70b..795f95b5bb 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -76,6 +76,12 @@ public void testColumnName_invalidSymbol() { runTest("testData/box/columnName_invalidSymbol.kt"); } + @Test + @TestMetadata("columnOf_nested.kt") + public void testColumnOf_nested() { + runTest("testData/box/columnOf_nested.kt"); + } + @Test @TestMetadata("columnWithStarProjection.kt") public void testColumnWithStarProjection() { @@ -118,6 +124,12 @@ public void testDataFrameOf() { runTest("testData/box/dataFrameOf.kt"); } + @Test + @TestMetadata("dataFrameOf_nested.kt") + public void testDataFrameOf_nested() { + runTest("testData/box/dataFrameOf_nested.kt"); + } + @Test @TestMetadata("dataFrameOf_to.kt") public void testDataFrameOf_to() {