Skip to content

Commit e3a8d9f

Browse files
committed
add tests
1 parent d1238f3 commit e3a8d9f

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

bigframes/core/groupby/__init__.py

Lines changed: 21 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -340,19 +340,29 @@ def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
340340
for f in func
341341
]
342342

343-
aggregated_columns = pd.MultiIndex.from_tuples(
344-
[
345-
self._block.col_id_to_label[col_id]
343+
if self._block.column_labels.nlevels > 1:
344+
# Restructure MultiIndex for proper format: (idx1, idx2, func)
345+
# rather than ((idx1, idx2), func).
346+
aggregated_columns = pd.MultiIndex.from_tuples(
347+
[
348+
self._block.col_id_to_label[col_id]
349+
for col_id in self._aggregated_columns()
350+
],
351+
names=[*self._block.column_labels.names],
352+
).to_frame(index=False)
353+
354+
column_labels = [
355+
tuple(col_id) + (f,)
356+
for col_id in aggregated_columns.to_numpy()
357+
for f in func
358+
]
359+
else:
360+
column_labels = [
361+
(self._block.col_id_to_label[col_id], f)
346362
for col_id in self._aggregated_columns()
347-
],
348-
names=[*self._block.column_labels.names],
349-
).to_frame(index=False)
363+
for f in func
364+
]
350365

351-
column_labels = [
352-
tuple(col_id) + (f,)
353-
for col_id in aggregated_columns.to_numpy()
354-
for f in func
355-
]
356366
agg_block, _ = self._block.aggregate(
357367
by_column_ids=self._by_col_ids,
358368
aggregations=aggregations,

tests/system/small/test_groupby.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,23 @@ def test_dataframe_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
144144
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
145145

146146

147+
def test_dataframe_groupby_agg_list_w_column_multi_index(
148+
scalars_df_index, scalars_pandas_df_index
149+
):
150+
columns = ["int64_too", "string_col", "bool_col"]
151+
multi_columns = pd.MultiIndex.from_tuples(zip(["a", "b", "a"], columns))
152+
bf_df = scalars_df_index[columns].copy()
153+
bf_df.columns = multi_columns
154+
pd_df = scalars_pandas_df_index[columns].copy()
155+
pd_df.columns = multi_columns
156+
157+
bf_result = bf_df.groupby(level=0).agg(["count", "min"])
158+
pd_result = pd_df.groupby(level=0).agg(["count", "min"])
159+
160+
bf_result_computed = bf_result.to_pandas()
161+
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
162+
163+
147164
@pytest.mark.parametrize(
148165
("as_index"),
149166
[

0 commit comments

Comments
 (0)