@@ -20,12 +20,15 @@ use std::path::PathBuf;
2020use std:: str:: FromStr ;
2121use std:: sync:: Arc ;
2222
23+ use arrow:: array:: RecordBatchReader ;
24+ use arrow:: ffi_stream:: ArrowArrayStreamReader ;
25+ use arrow:: pyarrow:: FromPyArrow ;
2326use datafusion:: execution:: session_state:: SessionStateBuilder ;
2427use object_store:: ObjectStore ;
2528use url:: Url ;
2629use uuid:: Uuid ;
2730
28- use pyo3:: exceptions:: { PyKeyError , PyValueError } ;
31+ use pyo3:: exceptions:: { PyKeyError , PyTypeError , PyValueError } ;
2932use pyo3:: prelude:: * ;
3033
3134use crate :: catalog:: { PyCatalog , PyTable } ;
@@ -444,7 +447,7 @@ impl PySessionContext {
444447 let table = table_class. call_method1 ( "from_pylist" , args) ?;
445448
446449 // Convert Arrow Table to datafusion DataFrame
447- let df = self . from_arrow_table ( table, name, py) ?;
450+ let df = self . from_arrow ( table, name, py) ?;
448451 Ok ( df)
449452 }
450453
@@ -463,29 +466,42 @@ impl PySessionContext {
463466 let table = table_class. call_method1 ( "from_pydict" , args) ?;
464467
465468 // Convert Arrow Table to datafusion DataFrame
466- let df = self . from_arrow_table ( table, name, py) ?;
469+ let df = self . from_arrow ( table, name, py) ?;
467470 Ok ( df)
468471 }
469472
470473 /// Construct datafusion dataframe from Arrow Table
471- pub fn from_arrow_table (
474+ pub fn from_arrow (
472475 & mut self ,
473476 data : Bound < ' _ , PyAny > ,
474477 name : Option < & str > ,
475478 py : Python ,
476479 ) -> PyResult < PyDataFrame > {
477- // Instantiate pyarrow Table object & convert to batches
478- let table = data. call_method0 ( "to_batches" ) ?;
480+ let ( schema, batches) =
481+ if let Ok ( stream_reader) = ArrowArrayStreamReader :: from_pyarrow_bound ( & data) {
482+ // Works for any object that implements __arrow_c_stream__ in pycapsule.
483+
484+ let schema = stream_reader. schema ( ) . as_ref ( ) . to_owned ( ) ;
485+ let batches = stream_reader
486+ . collect :: < Result < Vec < RecordBatch > , arrow:: error:: ArrowError > > ( )
487+ . map_err ( DataFusionError :: from) ?;
488+
489+ ( schema, batches)
490+ } else if let Ok ( array) = RecordBatch :: from_pyarrow_bound ( & data) {
491+ // While this says RecordBatch, it will work for any object that implements
492+ // __arrow_c_array__ and returns a StructArray.
493+
494+ ( array. schema ( ) . as_ref ( ) . to_owned ( ) , vec ! [ array] )
495+ } else {
496+ return Err ( PyTypeError :: new_err (
497+ "Expected either a Arrow Array or Arrow Stream in from_arrow()." ,
498+ ) ) ;
499+ } ;
479500
480- let schema = data. getattr ( "schema" ) ?;
481- let schema = schema. extract :: < PyArrowType < Schema > > ( ) ?;
482-
483- // Cast PyAny to RecordBatch type
484501 // Because create_dataframe() expects a vector of vectors of record batches
485502 // here we need to wrap the vector of record batches in an additional vector
486- let batches = table. extract :: < PyArrowType < Vec < RecordBatch > > > ( ) ?;
487- let list_of_batches = PyArrowType :: from ( vec ! [ batches. 0 ] ) ;
488- self . create_dataframe ( list_of_batches, name, Some ( schema) , py)
503+ let list_of_batches = PyArrowType :: from ( vec ! [ batches] ) ;
504+ self . create_dataframe ( list_of_batches, name, Some ( schema. into ( ) ) , py)
489505 }
490506
491507 /// Construct datafusion dataframe from pandas
@@ -504,7 +520,7 @@ impl PySessionContext {
504520 let table = table_class. call_method1 ( "from_pandas" , args) ?;
505521
506522 // Convert Arrow Table to datafusion DataFrame
507- let df = self . from_arrow_table ( table, name, py) ?;
523+ let df = self . from_arrow ( table, name, py) ?;
508524 Ok ( df)
509525 }
510526
@@ -518,7 +534,7 @@ impl PySessionContext {
518534 let table = data. call_method0 ( "to_arrow" ) ?;
519535
520536 // Convert Arrow Table to datafusion DataFrame
521- let df = self . from_arrow_table ( table, name, data. py ( ) ) ?;
537+ let df = self . from_arrow ( table, name, data. py ( ) ) ?;
522538 Ok ( df)
523539 }
524540
0 commit comments