diff --git a/src/dataframe.rs b/src/dataframe.rs index c2ad4771..ab4749e3 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -51,7 +51,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{ - get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, + get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, }; use crate::{ errors::PyDataFusionResult, @@ -289,21 +289,33 @@ impl PyParquetColumnOptions { #[derive(Clone)] pub struct PyDataFrame { df: Arc, + + // In IPython environment cache batches between __repr__ and _repr_html_ calls. + batches: Option<(Vec, bool)>, } impl PyDataFrame { /// creates a new PyDataFrame pub fn new(df: DataFrame) -> Self { - Self { df: Arc::new(df) } + Self { + df: Arc::new(df), + batches: None, + } } - fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult { + fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; - let (batches, has_more) = wait_for_future( - py, - collect_record_batches_to_display(self.df.as_ref().clone(), config), - )??; + + let should_cache = *is_ipython_env(py) && self.batches.is_none(); + let (batches, has_more) = match self.batches.take() { + Some(b) => b, + None => wait_for_future( + py, + collect_record_batches_to_display(self.df.as_ref().clone(), config), + )??, + }; + if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below return Ok("No data to display".to_string()); @@ -313,7 +325,7 @@ impl PyDataFrame { // Convert record batches to PyObject list let py_batches = batches - .into_iter() + .iter() .map(|rb| rb.to_pyarrow(py)) .collect::>>()?; @@ -334,6 +346,10 @@ impl PyDataFrame { let html_result = formatter.call_method(method_name, (), Some(&kwargs))?; let html_str: String = html_result.extract()?; + if should_cache { + self.batches = Some((batches, has_more)); + } + Ok(html_str) } } @@ -361,7 +377,7 @@ impl PyDataFrame { } } - fn __repr__(&self, py: Python) -> PyDataFusionResult { + fn __repr__(&mut self, py: Python) -> PyDataFusionResult { self.prepare_repr_string(py, false) } @@ -396,7 +412,7 @@ impl PyDataFrame { Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } - fn _repr_html_(&self, py: Python) -> PyDataFusionResult { + fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { self.prepare_repr_string(py, true) } diff --git a/src/utils.rs b/src/utils.rs index 90d65438..f4e121fd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -39,6 +39,17 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap())) } +#[inline] +pub(crate) fn is_ipython_env(py: Python) -> &'static bool { + static IS_IPYTHON_ENV: OnceLock = OnceLock::new(); + IS_IPYTHON_ENV.get_or_init(|| { + py.import("IPython") + .and_then(|ipython| ipython.call_method0("get_ipython")) + .map(|ipython| !ipython.is_none()) + .unwrap_or(false) + }) +} + /// Utility to get the Global Datafussion CTX #[inline] pub(crate) fn get_global_ctx() -> &'static SessionContext {