
Commit e3289b7

fix: Fix issues with chunked arrow data (#1700)
1 parent: edaac89

File tree

2 files changed: +40, -4 lines

bigframes/core/local_data.py

Lines changed: 20 additions & 4 deletions
@@ -86,7 +86,7 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
         columns: list[pa.ChunkedArray] = []
         fields: list[schemata.SchemaItem] = []
         for name, arr in zip(table.column_names, table.columns):
-            new_arr, bf_type = _adapt_arrow_array(arr)
+            new_arr, bf_type = _adapt_chunked_array(arr)
             columns.append(new_arr)
             fields.append(schemata.SchemaItem(name, bf_type))
 
@@ -279,10 +279,26 @@ def _adapt_pandas_series(
         raise e
 
 
-def _adapt_arrow_array(
-    array: Union[pa.ChunkedArray, pa.Array]
-) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
+def _adapt_chunked_array(
+    chunked_array: pa.ChunkedArray,
+) -> tuple[pa.ChunkedArray, bigframes.dtypes.Dtype]:
+    if len(chunked_array.chunks) == 0:
+        return _adapt_arrow_array(chunked_array.combine_chunks())
+    dtype = None
+    arrays = []
+    for chunk in chunked_array.chunks:
+        array, arr_dtype = _adapt_arrow_array(chunk)
+        arrays.append(array)
+        dtype = dtype or arr_dtype
+    assert dtype is not None
+    return pa.chunked_array(arrays), dtype
+
+
+def _adapt_arrow_array(array: pa.Array) -> tuple[pa.Array, bigframes.dtypes.Dtype]:
     """Normalize the array to managed storage types. Preverse shapes, only transforms values."""
+    if array.offset != 0:  # Offset arrays don't have all operations implemented
+        return _adapt_arrow_array(pa.concat_arrays([array]))
+
     if pa.types.is_struct(array.type):
         assert isinstance(array, pa.StructArray)
         assert isinstance(array.type, pa.StructType)
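For readers unfamiliar with the two pyarrow shapes this commit handles: a pa.ChunkedArray stores a column as a sequence of chunk arrays, and slicing an array or table is zero-copy, so the result is a view with a non-zero offset into the original buffers. The new _adapt_chunked_array adapts each chunk separately, and _adapt_arrow_array copies offset arrays into fresh buffers via pa.concat_arrays. A minimal standalone sketch of those pyarrow behaviors (not part of the commit; the values are made up):

import pyarrow as pa

# A column built from several batches is a ChunkedArray with multiple chunks.
chunked = pa.chunked_array([[1, 2], [3, 4, 5]])
print(chunked.num_chunks)   # 2
print(len(chunked.chunks))  # 2

# Slicing an Array is zero-copy: the result is a view with a non-zero offset.
arr = pa.array([10, 20, 30, 40, 50])
view = arr.slice(2, 2)
print(view.offset)          # 2

# Concatenating copies the values into a fresh, offset-free array,
# which is the same trick _adapt_arrow_array uses for offset inputs.
flat = pa.concat_arrays([view])
print(flat.offset)          # 0
print(flat.to_pylist())     # [30, 40]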

tests/unit/test_local_data.py

Lines changed: 20 additions & 0 deletions
@@ -44,3 +44,23 @@ def test_local_data_well_formed_round_trip():
     local_entry = local_data.ManagedArrowTable.from_pandas(pd_data)
     result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns)
     pandas.testing.assert_frame_equal(pd_data_normalized, result, check_dtype=False)
+
+
+def test_local_data_well_formed_round_trip_chunked():
+    pa_table = pa.Table.from_pandas(pd_data, preserve_index=False)
+    as_rechunked_pyarrow = pa.Table.from_batches(pa_table.to_batches(max_chunksize=2))
+    local_entry = local_data.ManagedArrowTable.from_pyarrow(as_rechunked_pyarrow)
+    result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns)
+    pandas.testing.assert_frame_equal(pd_data_normalized, result, check_dtype=False)
+
+
+def test_local_data_well_formed_round_trip_sliced():
+    pa_table = pa.Table.from_pandas(pd_data, preserve_index=False)
+    as_rechunked_pyarrow = pa.Table.from_batches(pa_table.slice(2, 4).to_batches())
+    local_entry = local_data.ManagedArrowTable.from_pyarrow(as_rechunked_pyarrow)
+    result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns)
+    pandas.testing.assert_frame_equal(
+        pd_data_normalized[2:4].reset_index(drop=True),
+        result.reset_index(drop=True),
+        check_dtype=False,
+    )
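The pd_data and pd_data_normalized objects referenced here are defined earlier in tests/unit/test_local_data.py; the two new tests feed ManagedArrowTable a multi-chunk table and a sliced (offset) table. The sketch below, using a hypothetical stand-in table rather than those fixtures, shows how to_batches(max_chunksize=...) and Table.slice() produce exactly those shapes:

import pyarrow as pa

# A hypothetical 6-row table standing in for the test fixtures.
table = pa.table({"a": [1, 2, 3, 4, 5, 6], "b": list("uvwxyz")})

# Re-batching with a small max_chunksize yields multi-chunk columns,
# mirroring test_local_data_well_formed_round_trip_chunked.
rechunked = pa.Table.from_batches(table.to_batches(max_chunksize=2))
print(rechunked.column("a").num_chunks)    # 3

# Slicing is zero-copy, so the resulting chunk is a view into the original
# buffers with a non-zero offset, mirroring the sliced round-trip test.
sliced = pa.Table.from_batches(table.slice(2, 4).to_batches())
print(sliced.num_rows)                     # 4
print(sliced.column("a").chunk(0).offset)  # 2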
