
Commit 69ed7fe

Add PyCapsule support for Arrow import and export (#825)
1 parent 766e2ed commit 69ed7fe

9 files changed (+330 −41 lines changed)
docs/source/user-guide/io/arrow.rst (new file)

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
.. Licensed to the Apache Software Foundation (ASF) under one
.. or more contributor license agreements. See the NOTICE file
.. distributed with this work for additional information
.. regarding copyright ownership. The ASF licenses this file
.. to you under the Apache License, Version 2.0 (the
.. "License"); you may not use this file except in compliance
.. with the License. You may obtain a copy of the License at

..   http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
.. software distributed under the License is distributed on an
.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
.. KIND, either express or implied. See the License for the
.. specific language governing permissions and limitations
.. under the License.
Arrow
=====

DataFusion implements the
`Apache Arrow PyCapsule interface <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html>`_
for importing and exporting DataFrames with zero copy. Any Python project that
implements this interface can share data back and forth with DataFusion.

We can demonstrate this using `pyarrow <https://arrow.apache.org/docs/python/index.html>`_.
Importing to DataFusion
-----------------------

Here we will create an Arrow table and import it to DataFusion.

To import an Arrow table, use :py:func:`datafusion.context.SessionContext.from_arrow`.
This accepts any Python object that implements
`__arrow_c_stream__ <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html#arrowstream-export>`_
or `__arrow_c_array__ <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html#arrowarray-export>`_;
in the latter case, the object must be a struct array. Common pyarrow sources you can use are:

- `Array <https://arrow.apache.org/docs/python/generated/pyarrow.Array.html>`_ (but it must be a struct array)
- `Record Batch <https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatch.html>`_
- `Record Batch Reader <https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html>`_
- `Table <https://arrow.apache.org/docs/python/generated/pyarrow.Table.html>`_
45+
.. ipython:: python
46+
47+
from datafusion import SessionContext
48+
import pyarrow as pa
49+
50+
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
51+
table = pa.Table.from_pydict(data)
52+
53+
ctx = SessionContext()
54+
df = ctx.from_arrow(table)
55+
df
Exporting from DataFusion
-------------------------

DataFusion DataFrames implement the ``__arrow_c_stream__`` PyCapsule interface, so any
Python library that accepts this interface can import a DataFusion DataFrame directly.

.. warning::
    Exporting a DataFrame triggers its execution, which may be time consuming. That
    is, it causes a :py:func:`datafusion.dataframe.DataFrame.collect` operation to
    occur.

.. ipython:: python

    from datafusion import col, lit

    df = df.select((col("a") * lit(1.5)).alias("c"), lit("df").alias("d"))
    pa.table(df)
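Beyond pyarrow, any object that exposes the PyCapsule stream protocol can be handed to ``from_arrow``. Below is a minimal sketch (``MyStreamSource`` is a hypothetical class for illustration, not part of this commit) that delegates capsule export to an underlying pyarrow Table:

    import pyarrow as pa
    from datafusion import SessionContext

    class MyStreamSource:
        """Hypothetical source that speaks the Arrow PyCapsule stream protocol."""

        def __init__(self, table: pa.Table):
            self._table = table

        def __arrow_c_stream__(self, requested_schema=None):
            # Delegate capsule creation to pyarrow (requires pyarrow >= 14)
            return self._table.__arrow_c_stream__(requested_schema)

    ctx = SessionContext()
    df = ctx.from_arrow(MyStreamSource(pa.table({"a": [1, 2, 3]})))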

docs/source/user-guide/io/index.rst

Lines changed: 3 additions & 3 deletions

@@ -21,8 +21,8 @@ IO
 .. toctree::
    :maxdepth: 2
 
+   arrow
+   avro
    csv
-   parquet
    json
-   avro
-
+   parquet
examples/import.py

Lines changed: 1 addition & 1 deletion

@@ -54,5 +54,5 @@
 
 # Convert Arrow Table to datafusion DataFrame
 arrow_table = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
-df = ctx.from_arrow_table(arrow_table)
+df = ctx.from_arrow(arrow_table)
 assert type(df) == datafusion.DataFrame

python/datafusion/context.py

Lines changed: 18 additions & 6 deletions

@@ -586,19 +586,31 @@ def from_pydict(
         """
         return DataFrame(self.ctx.from_pydict(data, name))
 
-    def from_arrow_table(
-        self, data: pyarrow.Table, name: str | None = None
-    ) -> DataFrame:
-        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table.
+    def from_arrow(self, data: Any, name: str | None = None) -> DataFrame:
+        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow source.
+
+        The Arrow data source can be any object that implements either
+        ``__arrow_c_stream__`` or ``__arrow_c_array__``. For the latter, it must
+        return a struct array. Common pyarrow sources include ``RecordBatch``, ``RecordBatchReader``, and ``Table``.
 
         Args:
-            data: Arrow table.
+            data: Arrow data source.
             name: Name of the DataFrame.
 
         Returns:
             DataFrame representation of the Arrow table.
         """
-        return DataFrame(self.ctx.from_arrow_table(data, name))
+        return DataFrame(self.ctx.from_arrow(data, name))
+
+    @deprecated("Use ``from_arrow`` instead.")
+    def from_arrow_table(
+        self, data: pyarrow.Table, name: str | None = None
+    ) -> DataFrame:
+        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table.
+
+        This is an alias for :py:func:`from_arrow`.
+        """
+        return self.from_arrow(data, name)
 
     def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFrame:
         """Create a :py:class:`~datafusion.dataframe.DataFrame` from a Pandas DataFrame.

python/datafusion/dataframe.py

Lines changed: 16 additions & 0 deletions

@@ -524,3 +524,19 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
         """
         columns = [c for c in columns]
         return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
+
+    def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any:
+        """Export an Arrow PyCapsule Stream.
+
+        This will execute and collect the DataFrame. We will attempt to respect the
+        requested schema, but only trivial transformations will be applied, such as
+        returning only the fields listed in the requested schema when their data
+        types match those in the DataFrame.
+
+        Args:
+            requested_schema: Attempt to provide the DataFrame using this schema.
+
+        Returns:
+            Arrow PyCapsule object.
+        """
+        return self.df.__arrow_c_stream__(requested_schema)
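Since DataFrames now export ``__arrow_c_stream__``, the canonical pyarrow consumer can read them directly. A brief sketch of the observable behavior (plain export plus a requested-schema export; assumes pyarrow >= 14 for capsule support):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Triggers execution (a collect) and reads the resulting stream
    table = pa.table(df)

    # Request a narrower schema; only trivial projections are honored
    subset = pa.table(df, schema=pa.schema([("a", pa.int64())]))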

python/datafusion/tests/test_context.py

Lines changed: 33 additions & 4 deletions

@@ -156,7 +156,7 @@ def test_from_arrow_table(ctx):
     table = pa.Table.from_pydict(data)
 
     # convert to DataFrame
-    df = ctx.from_arrow_table(table)
+    df = ctx.from_arrow(table)
     tables = list(ctx.catalog().database().names())
 
     assert df
@@ -166,13 +166,42 @@ def test_from_arrow_table(ctx):
     assert df.collect()[0].num_rows == 3
 
 
+def record_batch_generator(num_batches: int):
+    schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
+    for _ in range(num_batches):
+        yield pa.RecordBatch.from_arrays(
+            [pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema
+        )
+
+
+@pytest.mark.parametrize(
+    "source",
+    [
+        # __arrow_c_array__ sources
+        pa.array([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]),
+        # __arrow_c_stream__ sources
+        pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
+        pa.RecordBatchReader.from_batches(
+            pa.schema([("a", pa.int64()), ("b", pa.int64())]), record_batch_generator(1)
+        ),
+        pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
+    ],
+)
+def test_from_arrow_sources(ctx, source) -> None:
+    df = ctx.from_arrow(source)
+    assert df
+    assert isinstance(df, DataFrame)
+    assert df.schema().names == ["a", "b"]
+    assert df.count() == 3
+
+
 def test_from_arrow_table_with_name(ctx):
     # create a PyArrow table
     data = {"a": [1, 2, 3], "b": [4, 5, 6]}
     table = pa.Table.from_pydict(data)
 
     # convert to DataFrame with optional name
-    df = ctx.from_arrow_table(table, name="tbl")
+    df = ctx.from_arrow(table, name="tbl")
     tables = list(ctx.catalog().database().names())
 
     assert df
@@ -185,7 +214,7 @@ def test_from_arrow_table_empty(ctx):
     table = pa.Table.from_pydict(data, schema=schema)
 
     # convert to DataFrame
-    df = ctx.from_arrow_table(table)
+    df = ctx.from_arrow(table)
     tables = list(ctx.catalog().database().names())
 
     assert df
@@ -200,7 +229,7 @@ def test_from_arrow_table_empty_no_schema(ctx):
     table = pa.Table.from_pydict(data)
 
     # convert to DataFrame
-    df = ctx.from_arrow_table(table)
+    df = ctx.from_arrow(table)
     tables = list(ctx.catalog().database().names())
 
     assert df

python/datafusion/tests/test_dataframe.py

Lines changed: 40 additions & 11 deletions

@@ -47,7 +47,7 @@ def df():
         names=["a", "b", "c"],
     )
 
-    return ctx.create_dataframe([[batch]])
+    return ctx.from_arrow(batch)
 
 
 @pytest.fixture
@@ -835,13 +835,42 @@ def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compre
     df.write_parquet(str(path), compression=compression)
 
 
-# ctx = SessionContext()
-
-# # create a RecordBatch and a new DataFrame from it
-# batch = pa.RecordBatch.from_arrays(
-#     [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
-#     names=["a", "b", "c"],
-# )
-
-# df = ctx.create_dataframe([[batch]])
-# test_execute_stream(df)
+def test_dataframe_export(df) -> None:
+    # Guarantees that we have the canonical implementation
+    # reading our dataframe export
+    table = pa.table(df)
+    assert table.num_columns == 3
+    assert table.num_rows == 3
+
+    desired_schema = pa.schema([("a", pa.int64())])
+
+    # Verify we can request a schema
+    table = pa.table(df, schema=desired_schema)
+    assert table.num_columns == 1
+    assert table.num_rows == 3
+
+    # Expect a table of nulls if the schemas don't overlap
+    desired_schema = pa.schema([("g", pa.string())])
+    table = pa.table(df, schema=desired_schema)
+    assert table.num_columns == 1
+    assert table.num_rows == 3
+    for i in range(3):
+        assert table[0][i].as_py() is None
+
+    # Expect an error when we cannot convert the schema
+    desired_schema = pa.schema([("a", pa.float32())])
+    failed_convert = False
+    try:
+        table = pa.table(df, schema=desired_schema)
+    except Exception:
+        failed_convert = True
+    assert failed_convert
+
+    # Expect an error when a non-nullable field is not provided
+    desired_schema = pa.schema([("g", pa.string(), False)])
+    failed_convert = False
+    try:
+        table = pa.table(df, schema=desired_schema)
+    except Exception:
+        failed_convert = True
+    assert failed_convert
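As an aside, the two try/except blocks above could be written more compactly with ``pytest.raises``; a sketch of that alternative (not part of this commit):

    import pyarrow as pa
    import pytest

    def test_dataframe_export_errors(df) -> None:
        # int64 -> float32 is not a trivial transformation, so the export fails
        with pytest.raises(Exception):
            pa.table(df, schema=pa.schema([("a", pa.float32())]))

        # A non-nullable field that the DataFrame cannot supply also fails
        with pytest.raises(Exception):
            pa.table(df, schema=pa.schema([("g", pa.string(), False)]))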

src/context.rs

Lines changed: 31 additions & 15 deletions

@@ -20,12 +20,15 @@ use std::path::PathBuf;
 use std::str::FromStr;
 use std::sync::Arc;
 
+use arrow::array::RecordBatchReader;
+use arrow::ffi_stream::ArrowArrayStreamReader;
+use arrow::pyarrow::FromPyArrow;
 use datafusion::execution::session_state::SessionStateBuilder;
 use object_store::ObjectStore;
 use url::Url;
 use uuid::Uuid;
 
-use pyo3::exceptions::{PyKeyError, PyValueError};
+use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
 use pyo3::prelude::*;
 
 use crate::catalog::{PyCatalog, PyTable};
@@ -444,7 +447,7 @@ impl PySessionContext {
         let table = table_class.call_method1("from_pylist", args)?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, py)?;
+        let df = self.from_arrow(table, name, py)?;
         Ok(df)
     }
 
@@ -463,29 +466,42 @@ impl PySessionContext {
         let table = table_class.call_method1("from_pydict", args)?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, py)?;
+        let df = self.from_arrow(table, name, py)?;
         Ok(df)
     }
 
     /// Construct datafusion dataframe from Arrow Table
-    pub fn from_arrow_table(
+    pub fn from_arrow(
         &mut self,
         data: Bound<'_, PyAny>,
         name: Option<&str>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
-        // Instantiate pyarrow Table object & convert to batches
-        let table = data.call_method0("to_batches")?;
+        let (schema, batches) =
+            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
+                // Works for any object that implements __arrow_c_stream__ via PyCapsule.
+
+                let schema = stream_reader.schema().as_ref().to_owned();
+                let batches = stream_reader
+                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()
+                    .map_err(DataFusionError::from)?;
+
+                (schema, batches)
+            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
+                // While this says RecordBatch, it will work for any object that implements
+                // __arrow_c_array__ and returns a StructArray.
+
+                (array.schema().as_ref().to_owned(), vec![array])
+            } else {
+                return Err(PyTypeError::new_err(
+                    "Expected either an Arrow Array or Arrow Stream in from_arrow().",
+                ));
+            };
 
-        let schema = data.getattr("schema")?;
-        let schema = schema.extract::<PyArrowType<Schema>>()?;
-
-        // Cast PyAny to RecordBatch type
         // Because create_dataframe() expects a vector of vectors of record batches
         // here we need to wrap the vector of record batches in an additional vector
-        let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>()?;
-        let list_of_batches = PyArrowType::from(vec![batches.0]);
-        self.create_dataframe(list_of_batches, name, Some(schema), py)
+        let list_of_batches = PyArrowType::from(vec![batches]);
+        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
     }
 
     /// Construct datafusion dataframe from pandas
@@ -504,7 +520,7 @@ impl PySessionContext {
         let table = table_class.call_method1("from_pandas", args)?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, py)?;
+        let df = self.from_arrow(table, name, py)?;
         Ok(df)
     }
 
@@ -518,7 +534,7 @@ impl PySessionContext {
         let table = data.call_method0("to_arrow")?;
 
        // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, data.py())?;
+        let df = self.from_arrow(table, name, data.py())?;
         Ok(df)
     }
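In the Rust code above, the stream path is tried first, then a struct array via ``__arrow_c_array__``, and anything else raises ``PyTypeError``. A small sketch of the behavior observable from Python (using the renamed ``from_arrow`` from this commit):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()

    # A struct array takes the __arrow_c_array__ path and becomes one RecordBatch
    struct_array = pa.array([{"a": 1, "b": 4}, {"a": 2, "b": 5}])
    df = ctx.from_arrow(struct_array)

    # Objects that are neither a stream nor a struct array are rejected
    try:
        ctx.from_arrow(42)
    except TypeError as err:
        print(err)  # "Expected either an Arrow Array or Arrow Stream in from_arrow()."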
