Skip to content

Commit aa3307e

Browse files
authored
Merge pull request #1088 from Kotlin/aggregate-datarow
[Compiler plugin] Support DataFrame.aggregate
2 parents 518c483 + 1555636 commit aa3307e

File tree

5 files changed

+38
-0
lines changed

5 files changed

+38
-0
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/DataFrame.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable
44
import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedBody
55
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
66
import org.jetbrains.kotlinx.dataframe.annotations.HasSchema
7+
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
8+
import org.jetbrains.kotlinx.dataframe.annotations.Refine
79
import org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl
810
import org.jetbrains.kotlinx.dataframe.api.add
911
import org.jetbrains.kotlinx.dataframe.api.cast
@@ -71,6 +73,8 @@ public interface DataFrame<out T> :
7173

7274
// endregion
7375

76+
@Refine
77+
@Interpretable("AggregateRow")
7478
public fun <R> aggregate(body: AggregateGroupedBody<T, R>): DataRow<T>
7579

7680
// region get columns

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ class Aggregate : AbstractSchemaModificationInterpreter() {
7272
}
7373
}
7474

75+
class AggregateRow : AbstractSchemaModificationInterpreter() {
76+
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
77+
val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id)
78+
override fun Arguments.interpret(): PluginDataFrameSchema {
79+
return aggregate(
80+
GroupBy(PluginDataFrameSchema.EMPTY, receiver),
81+
InterpretationErrorReporter.DEFAULT,
82+
body
83+
)
84+
}
85+
}
86+
7587
fun KotlinTypeFacade.aggregate(
7688
groupBy: GroupBy,
7789
reporter: InterpretationErrorReporter,

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslNamedGroup
7070
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke
7171
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddId
7272
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Aggregate
73+
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AggregateRow
7374
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.All0
7475
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.All1
7576
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.All2
@@ -395,6 +396,7 @@ internal inline fun <reified T> String.load(): T {
395396
"ToTop" -> ToTop()
396397
"Update0" -> Update0()
397398
"Aggregate" -> Aggregate()
399+
"AggregateRow" -> AggregateRow()
398400
"DataFrameOf3" -> DataFrameOf3()
399401
"ValueCounts" -> ValueCounts()
400402
"RenameToCamelCase" -> RenameToCamelCase()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
fun box(): String {
7+
val row = dataFrameOf("a" to List(10) { it }).aggregate {
8+
maxOf { a } into "max"
9+
minOf { a } into "min"
10+
}
11+
val i: Int = row.max
12+
val i1: Int = row.min
13+
return "OK"
14+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ public void testAddId() {
2929
runTest("testData/box/addId.kt");
3030
}
3131

32+
@Test
33+
@TestMetadata("aggregateDataFrame.kt")
34+
public void testAggregateDataFrame() {
35+
runTest("testData/box/aggregateDataFrame.kt");
36+
}
37+
3238
@Test
3339
public void testAllFilesPresentInBox() {
3440
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("testData/box"), Pattern.compile("^(.+)\\.kt$"), null, TargetBackend.JVM_IR, true);

0 commit comments

Comments
 (0)