From e970d90938c88176ef1826a1bec35b387fb61c40 Mon Sep 17 00:00:00 2001 From: Nikita Klimenko Date: Mon, 10 Feb 2025 18:50:31 +0200 Subject: [PATCH 1/4] [Compiler plugin] Support GroupBy.count --- .../jetbrains/kotlinx/dataframe/api/count.kt | 6 ++++++ .../dataframe/plugin/impl/api/count.kt | 19 +++++++++++++++++++ .../dataframe/plugin/loadInterpreter.kt | 2 ++ .../testData/box/groupBy_count.kt | 16 ++++++++++++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 6 ++++++ 5 files changed, 49 insertions(+) create mode 100644 plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/count.kt create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_count.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/count.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/count.kt index e8b306338f..6ee9c8820a 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/count.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/count.kt @@ -6,6 +6,8 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.Predicate import org.jetbrains.kotlinx.dataframe.RowFilter +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateValue // region DataColumn @@ -37,9 +39,13 @@ public fun DataFrame.count(predicate: RowFilter): Int = rows().count { // region GroupBy +@Refine +@Interpretable("GroupByCount0") public fun Grouped.count(resultName: String = "count"): DataFrame = aggregateValue(resultName) { count() default 0 } +@Refine +@Interpretable("GroupByCount0") public fun Grouped.count(resultName: String = "count", predicate: RowFilter): DataFrame = aggregateValue(resultName) { count(predicate) default 0 } diff --git 
a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/count.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/count.kt new file mode 100644 index 0000000000..ac11e670d0 --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/count.kt @@ -0,0 +1,19 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.Present +import org.jetbrains.kotlinx.dataframe.plugin.impl.add +import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy +import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore + +class GroupByCount0 : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.resultName: String by arg(defaultValue = Present("count")) + val Arguments.predicate by ignore() + + override fun Arguments.interpret(): PluginDataFrameSchema { + return receiver.keys.add(resultName, session.builtinTypes.intType.type, context = this) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index cc74c4a82e..7e8affd8eb 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -89,6 +89,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd +import 
org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0 @@ -295,6 +296,7 @@ internal inline fun String.load(): T { "MergeBy0" -> MergeBy0() "MergeBy1" -> MergeBy1() "ReorderColumnsByName" -> ReorderColumnsByName() + "GroupByCount0" -> GroupByCount0() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_count.kt b/plugins/kotlin-dataframe/testData/box/groupBy_count.kt new file mode 100644 index 0000000000..a2e98adec5 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_count.kt @@ -0,0 +1,16 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.count() + val i: Int = df.count[0] + + val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.count { a > 1 } + val i1: Int = df1.count[0] + + val df2 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.count("myCol") { a > 1 } + val i2: Int = df2.myCol[0] + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 256a201833..147704afdc 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -226,6 +226,12 @@ public void testGroupBy_DataRow() { runTest("testData/box/groupBy_DataRow.kt"); } + @Test + @TestMetadata("groupBy_count.kt") + public void testGroupBy_count() { + 
runTest("testData/box/groupBy_count.kt"); + } + @Test @TestMetadata("groupBy_extractSchema.kt") public void testGroupBy_extractSchema() { From 4c76e1729491ae7451c92333fdb96bffc51458d1 Mon Sep 17 00:00:00 2001 From: Nikita Klimenko Date: Mon, 10 Feb 2025 18:55:25 +0200 Subject: [PATCH 2/4] [Compiler plugin] Support groupBy.[first | last | maxBy | minBy].into(columnName) --- .../jetbrains/kotlinx/dataframe/api/first.kt | 3 ++ .../jetbrains/kotlinx/dataframe/api/into.kt | 2 ++ .../jetbrains/kotlinx/dataframe/api/last.kt | 3 ++ .../jetbrains/kotlinx/dataframe/api/max.kt | 2 ++ .../jetbrains/kotlinx/dataframe/api/min.kt | 2 ++ .../dataframe/plugin/impl/SimpleCol.kt | 2 +- .../plugin/impl/api/ReducedGroupBy.kt | 35 +++++++++++++++++++ .../dataframe/plugin/impl/api/groupBy.kt | 8 +++-- .../kotlinx/dataframe/plugin/interpret.kt | 2 +- .../dataframe/plugin/loadInterpreter.kt | 6 ++++ .../testData/box/reducedGroupBy.kt | 16 +++++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 6 ++++ 12 files changed, 82 insertions(+), 5 deletions(-) create mode 100644 plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/ReducedGroupBy.kt create mode 100644 plugins/kotlin-dataframe/testData/box/reducedGroupBy.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/first.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/first.kt index 89880a0016..eee17154f4 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/first.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/first.kt @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowFilter import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnPath import
org.jetbrains.kotlinx.dataframe.columns.ColumnReference @@ -55,8 +56,10 @@ public fun DataFrame.firstOrNull(predicate: RowFilter): DataRow? = // region GroupBy +@Interpretable("GroupByReducePredicate") public fun GroupBy.first(): ReducedGroupBy = reduce { firstOrNull() } +@Interpretable("GroupByReducePredicate") public fun GroupBy.first(predicate: RowFilter): ReducedGroupBy = reduce { firstOrNull(predicate) } // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt index 0c76319186..c6b297708e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt @@ -79,6 +79,8 @@ public inline fun ReducedGroupBy.into( noinline expression: RowExpression, ): DataFrame = into(column.columnName, expression) +@Refine +@Interpretable("GroupByReduceInto") public fun ReducedGroupBy.into(columnName: String): DataFrame = into(columnName) { this } @AccessApiOverload diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/last.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/last.kt index 622ceadd52..f7461958a6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/last.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/last.kt @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowFilter import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnPath import org.jetbrains.kotlinx.dataframe.columns.ColumnReference @@ -56,8 +57,10 @@ public fun DataFrame.last(): DataRow { // region GroupBy +@Interpretable("GroupByReducePredicate") public fun GroupBy.last(): 
ReducedGroupBy = reduce { lastOrNull() } +@Interpretable("GroupByReducePredicate") public fun GroupBy.last(predicate: RowFilter): ReducedGroupBy = reduce { lastOrNull(predicate) } // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt index e64394fe23..84432017e4 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.values @@ -168,6 +169,7 @@ public fun > Grouped.maxOf( expression: RowExpression, ): DataFrame = Aggregators.max.aggregateOfDelegated(this, name) { maxOfOrNull(expression) } +@Interpretable("GroupByReduceExpression") public fun > GroupBy.maxBy(rowExpression: RowExpression): ReducedGroupBy = reduce { maxByOrNull(rowExpression) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt index 8b421a305b..74a949ee31 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt @@ -8,6 +8,7 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable import 
org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.values @@ -168,6 +169,7 @@ public fun > Grouped.minOf( expression: RowExpression, ): DataFrame = Aggregators.min.aggregateOfDelegated(this, name) { minOfOrNull(expression) } +@Interpretable("GroupByReduceExpression") public fun > GroupBy.minBy(rowExpression: RowExpression): ReducedGroupBy = reduce { minByOrNull(rowExpression) } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt index 1a181adbdc..783e7412f1 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/SimpleCol.kt @@ -125,7 +125,7 @@ fun KotlinTypeFacade.simpleColumnOf(name: String, type: ConeKotlinType): SimpleC } } -private fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol { +internal fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol { return when (column) { is SimpleColumnGroup -> { SimpleColumnGroup(column.name, column.columns().map { makeNullable(it) }) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/ReducedGroupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/ReducedGroupBy.kt new file mode 100644 index 0000000000..d72f59f0c4 --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/ReducedGroupBy.kt @@ -0,0 +1,35 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import 
org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup +import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy +import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore +import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable + +class GroupByReducePredicate : AbstractInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.predicate by ignore() + override fun Arguments.interpret(): GroupBy { + return receiver + } +} + +class GroupByReduceExpression : AbstractInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.rowExpression by ignore() + override fun Arguments.interpret(): GroupBy { + return receiver + } +} + +class GroupByReduceInto : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.columnName: String by arg() + override fun Arguments.interpret(): PluginDataFrameSchema { + val group = makeNullable(SimpleColumnGroup(columnName, receiver.groups.columns())) + return PluginDataFrameSchema(receiver.keys.columns() + group) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index b5bc6dc9e8..eea04ef476 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -1,14 +1,12 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api -import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter -import org.jetbrains.kotlinx.dataframe.plugin.interpret -import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import 
org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.types.ConeKotlinType import org.jetbrains.kotlin.fir.types.resolvedType +import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter @@ -23,8 +21,11 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.add import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy +import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.type +import org.jetbrains.kotlinx.dataframe.plugin.interpret +import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) @@ -173,6 +174,7 @@ class GroupByToDataFrame : AbstractSchemaModificationInterpreter() { class GroupByAdd : AbstractInterpreter() { val Arguments.receiver: GroupBy by groupBy() val Arguments.name: String by arg() + val Arguments.infer by ignore() val Arguments.type: TypeApproximation by type(name("expression")) override fun Arguments.interpret(): GroupBy { diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt index f67f5fc3a2..a79bd61636 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt @@ -192,7 +192,7 @@ fun KotlinTypeFacade.interpret( assert(expectedReturnType.toString() == 
GroupBy::class.qualifiedName!!) { "'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType" } - + // ok for ReducedGroupBy too val resolvedType = it.expression.resolvedType.fullyExpandedType(session) val keys = pluginDataFrameSchema(resolvedType.typeArguments[0]) val groups = pluginDataFrameSchema(resolvedType.typeArguments[1]) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 7e8affd8eb..5c76f7f9f9 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -91,6 +91,9 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MergeId @@ -297,6 +300,9 @@ internal inline fun String.load(): T { "MergeBy1" -> MergeBy1() "ReorderColumnsByName" -> ReorderColumnsByName() "GroupByCount0" -> GroupByCount0() + "GroupByReducePredicate" -> GroupByReducePredicate() + "GroupByReduceExpression" -> GroupByReduceExpression() + "GroupByReduceInto" -> GroupByReduceInto() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/reducedGroupBy.kt b/plugins/kotlin-dataframe/testData/box/reducedGroupBy.kt new file mode 100644 index 
0000000000..f7048d5b87 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/reducedGroupBy.kt @@ -0,0 +1,16 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val groupBy = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() } + groupBy.maxBy { id }.into("group").compareSchemas() + groupBy.maxBy { id }.into("group").compareSchemas() + groupBy.first { id == 1 }.into("group").compareSchemas() + groupBy.first().into("group").compareSchemas() + groupBy.last { id == 1 }.into("group").compareSchemas() + groupBy.last().into("group").compareSchemas() + groupBy.minBy { id == 1 }.into("group").compareSchemas() + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 147704afdc..edef7a8c3f 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -472,6 +472,12 @@ public void testRead_localFile() { runTest("testData/box/read_localFile.kt"); } + @Test + @TestMetadata("reducedGroupBy.kt") + public void testReducedGroupBy() { + runTest("testData/box/reducedGroupBy.kt"); + } + @Test @TestMetadata("remove.kt") public void testRemove() { From 7124479f8f0a817d89fecad060e1a765b197a4a5 Mon Sep 17 00:00:00 2001 From: Nikita Klimenko Date: Mon, 10 Feb 2025 20:36:02 +0200 Subject: [PATCH 3/4] Move ReducedGroupBy.concat --- core/api/core.api | 2 +- .../kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt | 9 +++++++++ .../kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt | 5 ----- 3 files changed, 10 insertions(+), 6 
deletions(-) diff --git a/core/api/core.api b/core/api/core.api index 332e1ed34f..fb13bb3605 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -3530,6 +3530,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConcatKt { public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/DataRow;[Lorg/jetbrains/kotlinx/dataframe/DataRow;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun concatRows (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun concatT (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; } @@ -4831,7 +4832,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/InsertKt { } public final class org/jetbrains/kotlinx/dataframe/api/IntoKt { - public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun into (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun into (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun into (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;Lorg/jetbrains/kotlinx/dataframe/columns/ColumnAccessor;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt index ced0b784a2..b475b50fdb 100644 --- 
a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt @@ -42,6 +42,15 @@ public fun GroupBy.concat(): DataFrame = groups.concat() // endregion +// region ReducedGroupBy + +public fun ReducedGroupBy.concat(): DataFrame = + groupBy.groups.values() + .map { reducer(it, it) } + .concat() + +// endregion + // region Iterable public fun Iterable>.concat(): DataFrame = concatImpl(asList()) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt index c6b297708e..dc5fd1002e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/into.kt @@ -89,9 +89,4 @@ public fun ReducedGroupBy.into(column: ColumnAccessor): Dat @AccessApiOverload public fun ReducedGroupBy.into(column: KProperty): DataFrame = into(column) { this } -public fun ReducedGroupBy.concat(): DataFrame = - groupBy.groups.values() - .map { reducer(it, it) } - .concat() - // endregion From 981a47c8abb110102556665a3711c299d2875a6a Mon Sep 17 00:00:00 2001 From: Nikita Klimenko Date: Mon, 10 Feb 2025 20:55:37 +0200 Subject: [PATCH 4/4] [Compiler plugin] Support GroupBy.[minOf | maxOf] --- .../jetbrains/kotlinx/dataframe/api/max.kt | 3 +++ .../jetbrains/kotlinx/dataframe/api/min.kt | 3 +++ .../dataframe/plugin/impl/api/groupBy.kt | 15 +++++++++++++ .../dataframe/plugin/loadInterpreter.kt | 4 ++++ .../testData/box/groupBy_maxOfMinOf.kt | 22 +++++++++++++++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 6 +++++ 6 files changed, 53 insertions(+) create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt index 84432017e4..a265a20f9e 100644 --- 
a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.values @@ -164,6 +165,8 @@ public fun > Grouped.max( public fun > Grouped.max(vararg columns: KProperty, name: String? = null): DataFrame = max(name) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMaxOf") public fun > Grouped.maxOf( name: String? = null, expression: RowExpression, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt index 74a949ee31..d1cae852aa 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.values @@ -164,6 +165,8 @@ public fun > Grouped.min( public fun > Grouped.min(vararg columns: KProperty, name: String? 
= null): DataFrame = min(name) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMinOf") public fun > Grouped.minOf( name: String? = null, expression: RowExpression, diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index eea04ef476..72ae4bd346 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -22,6 +22,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximat import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore +import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.type import org.jetbrains.kotlinx.dataframe.plugin.interpret @@ -181,3 +182,17 @@ class GroupByAdd : AbstractInterpreter() { return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this)) } } + +abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? 
by arg(defaultValue = Present(null)) + val Arguments.expression by type() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, expression.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +class GroupByMaxOf : GroupByAggregator(defaultName = "max") +class GroupByMinOf : GroupByAggregator(defaultName = "min") diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 5c76f7f9f9..2f091af117 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -91,6 +91,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate @@ -303,6 +305,8 @@ internal inline fun String.load(): T { "GroupByReducePredicate" -> GroupByReducePredicate() "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() + "GroupByMaxOf" -> GroupByMaxOf() + "GroupByMinOf" -> GroupByMinOf() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt b/plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt new file mode 100644 index 0000000000..7653b7c669 --- /dev/null 
+++ b/plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt @@ -0,0 +1,22 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.maxOf { 123 } + val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.minOf { 123 } + + val max = df.max[0] + val min = df1.min[0] + + df.compareSchemas() + df1.compareSchemas() + + val df2 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.maxOf("myMax") { 123 } + val df3 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.minOf("myMin") { 123 } + + df2.myMax + df3.myMin + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index edef7a8c3f..77424539f7 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -238,6 +238,12 @@ public void testGroupBy_extractSchema() { runTest("testData/box/groupBy_extractSchema.kt"); } + @Test + @TestMetadata("groupBy_maxOfMinOf.kt") + public void testGroupBy_maxOfMinOf() { + runTest("testData/box/groupBy_maxOfMinOf.kt"); + } + @Test @TestMetadata("groupBy_refine.kt") public void testGroupBy_refine() {