Skip to content

Commit 14a5730

Browse files
committed
Ensure a stable order of columns after aggregation of GroupBy
1 parent f7cbe37 commit 14a5730

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/typeConversions.kt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,12 +356,16 @@ public fun <T> DataFrame<T>.asColumnGroup(column: ColumnGroupAccessor<T>): Colum
356356

357357
// region as GroupedDataFrame
358358

359-
public fun <T> DataFrame<T>.asGroupBy(groupedColumnName: String): GroupBy<T, T> =
360-
GroupByImpl(this, getFrameColumn(groupedColumnName).castFrameColumn()) { none() }
359+
public fun <T> DataFrame<T>.asGroupBy(groupedColumnName: String): GroupBy<T, T> {
360+
val groups = getFrameColumn(groupedColumnName)
361+
return asGroupBy { groups.cast() }
362+
}
361363

362364
@AccessApiOverload
363-
public fun <T, G> DataFrame<T>.asGroupBy(groupedColumn: ColumnReference<DataFrame<G>>): GroupBy<T, G> =
364-
GroupByImpl(this, getFrameColumn(groupedColumn.name()).castFrameColumn()) { none() }
365+
public fun <T, G> DataFrame<T>.asGroupBy(groupedColumn: ColumnReference<DataFrame<G>>): GroupBy<T, G> {
366+
val groups = getFrameColumn(groupedColumn.name()).castFrameColumn<G>()
367+
return asGroupBy { groups }
368+
}
365369

366370
public fun <T> DataFrame<T>.asGroupBy(): GroupBy<T, T> {
367371
val groupCol = columns().single { it.isFrameColumn() }.asAnyFrameColumn().castFrameColumn<T>()
@@ -370,7 +374,7 @@ public fun <T> DataFrame<T>.asGroupBy(): GroupBy<T, T> {
370374

371375
public fun <T, G> DataFrame<T>.asGroupBy(selector: ColumnSelector<T, DataFrame<G>>): GroupBy<T, G> {
372376
val column = getColumn(selector).asFrameColumn()
373-
return GroupByImpl(this, column) { none() }
377+
return GroupByImpl(this.move { column }.toEnd(), column) { none() }
374378
}
375379

376380
// endregion

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTreeTests.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,28 @@ class DataFrameTreeTests : BaseTest() {
138138
res shouldBe typed
139139
}
140140

141+
@Test
142+
fun asGroupByOverloads() {
143+
val rowsColumn by columnOf(typed[0..3], typed[4..5], typed[6..6])
144+
val df = dataFrameOf(rowsColumn)
145+
val res = df.asGroupBy { rowsColumn }.max()
146+
df.asGroupBy("rowsColumn").max() shouldBe res
147+
df.asGroupBy(rowsColumn).max() shouldBe res
148+
}
149+
150+
@Test
151+
fun moveGroupedColumn() {
152+
val df = dataFrameOf(
153+
"group" to listOf(typed[0..3], typed[4..5], typed[6..6]),
154+
"col" to listOf(1, 2, 3),
155+
)
156+
157+
// We need to match the order of columns in the runtime and in compiler plugin
158+
// GroupBy with the same schema should give the same result after aggregation, no matter how it was created
159+
// We cannot track position of group in original df, so we align `asGroupBy` with `groupBy` and move `group` column to end
160+
df.asGroupBy("group").toDataFrame().columnNames() shouldBe listOf("col", "group")
161+
}
162+
141163
@Test
142164
fun createFrameColumn2() {
143165
val id by column(typed.indices())

0 commit comments

Comments
 (0)