Skip to content

Commit 345aaae

Browse files
authored
Fix concat of series objects with column projection (#981)
1 parent 43c83ee commit 345aaae

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

dask_expr/_concat.py

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,11 @@
1515
Blockwise,
1616
Expr,
1717
Projection,
18+
ToFrame,
1819
are_co_aligned,
1920
determine_column_projection,
2021
)
22+
from dask_expr._util import _convert_to_list
2123

2224

2325
class Concat(Expr):
@@ -235,6 +237,7 @@ def get_columns_or_name(e: Expr):
235237
return e.columns if e.ndim == 2 else [e.name]
236238

237239
columns = determine_column_projection(self, parent, dependents)
240+
columns = _convert_to_list(columns)
238241
columns_frame = [
239242
[col for col in get_columns_or_name(frame) if col in columns]
240243
for frame in self._frames
@@ -252,18 +255,22 @@ def get_columns_or_name(e: Expr):
252255
for frame, cols in zip(self._frames, columns_frame)
253256
if len(cols) > 0
254257
]
255-
return type(parent)(
256-
type(self)(
257-
self.join,
258-
self.ignore_order,
259-
self._kwargs,
260-
self.axis,
261-
self.ignore_unknown_divisions,
262-
self.interleave_partitions,
263-
*frames,
264-
),
265-
*parent.operands[1:],
258+
result = type(self)(
259+
self.join,
260+
self.ignore_order,
261+
self._kwargs,
262+
self.axis,
263+
self.ignore_unknown_divisions,
264+
self.interleave_partitions,
265+
*frames,
266266
)
267+
if result.columns == _convert_to_list(parent.operand("columns")):
268+
if result.ndim == parent.ndim:
269+
return result
270+
elif result.ndim < parent.ndim:
271+
return ToFrame(result)
272+
273+
return type(parent)(result, *parent.operands[1:])
267274

268275

269276
class StackPartition(Concat):

dask_expr/tests/test_concat.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -339,3 +339,13 @@ def test_concat_series(pdf):
339339
expected = concat([df2.y, df2.x], axis=1)[["x", "y"]]
340340
assert q.optimize(fuse=False)._name == expected.optimize(fuse=False)._name
341341
assert_eq(q, pd.concat([pdf.y, pdf.x, pdf.z], axis=1)[["x", "y"]])
342+
343+
344+
def test_concat_series_and_projection(df, pdf):
345+
result = concat([df.x, df.y], axis=1)["x"]
346+
expected = pd.concat([pdf.x, pdf.y], axis=1)["x"]
347+
assert_eq(result, expected)
348+
349+
result = concat([df.x, df.y], axis=1)[["x"]]
350+
expected = pd.concat([pdf.x, pdf.y], axis=1)[["x"]]
351+
assert_eq(result, expected)

0 commit comments

Comments
 (0)