feat: add execute_stream and execute_stream_partitioned #610


Merged 1 commit on Mar 21, 2024

30 changes: 30 additions & 0 deletions datafusion/tests/test_dataframe.py
@@ -623,6 +623,36 @@ def test_to_arrow_table(df):
    assert set(pyarrow_table.column_names) == {"a", "b", "c"}


def test_execute_stream(df):
    stream = df.execute_stream()
    assert all(batch is not None for batch in stream)
    assert not list(stream)  # after one full pass the stream must be exhausted


@pytest.mark.parametrize("schema", [True, False])
def test_execute_stream_to_arrow_table(df, schema):
    stream = df.execute_stream()

    if schema:
        pyarrow_table = pa.Table.from_batches(
            (batch.to_pyarrow() for batch in stream), schema=df.schema()
        )
    else:
        pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))

    assert isinstance(pyarrow_table, pa.Table)
    assert pyarrow_table.shape == (3, 3)
    assert set(pyarrow_table.column_names) == {"a", "b", "c"}


def test_execute_stream_partitioned(df):
    streams = df.execute_stream_partitioned()
    assert all(batch is not None for stream in streams for batch in stream)
    assert all(
        not list(stream) for stream in streams
    )  # after one full pass every stream must be exhausted


def test_empty_to_arrow_table(df):
    # Convert empty datafusion dataframe to pyarrow Table
    pyarrow_table = df.limit(0).to_arrow_table()
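
The tests above double as documentation for the new API. As a minimal usage sketch (the SessionContext setup and data below are illustrative assumptions, not part of this diff):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
    names=["a", "b", "c"],
)
df = ctx.create_dataframe([[batch]])

# execute_stream() yields record batches one at a time instead of
# collecting the entire result into memory, as collect() does.
for b in df.execute_stream():
    print(b.to_pyarrow().num_rows)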
45 changes: 40 additions & 5 deletions src/dataframe.rs
@@ -15,21 +15,27 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::prelude::*;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use std::sync::Arc;
use tokio::task::JoinHandle;

use crate::errors::py_datafusion_err;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{get_tokio_runtime, wait_for_future};
use crate::{errors::DataFusionError, expr::PyExpr};

/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -399,6 +405,35 @@ impl PyDataFrame {
        })
    }

    fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
        // create a Tokio runtime to run the async code
        let rt = &get_tokio_runtime(py).0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { df.execute_stream().await });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
        Ok(PyRecordBatchStream::new(stream?))
    }

    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
        // create a Tokio runtime to run the async code
        let rt = &get_tokio_runtime(py).0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion_common::Result<Vec<SendableRecordBatchStream>>> =
            rt.spawn(async move { df.execute_stream_partitioned().await });
        let streams = wait_for_future(py, fut).map_err(py_datafusion_err)?;

        match streams {
            Ok(streams) => Ok(streams.into_iter().map(PyRecordBatchStream::new).collect()),
            Err(_) => Err(PyValueError::new_err(
                "Unable to execute stream partitioned",
            )),
        }
    }

    /// Convert to pandas dataframe with pyarrow
    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
    fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
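
Continuing the earlier sketch (same hedged setup), execute_stream_partitioned() returns one stream per output partition of the plan, so each partition can be consumed independently, e.g. handed to separate worker threads:

total = 0
for stream in df.execute_stream_partitioned():
    for b in stream:
        total += b.to_pyarrow().num_rows
# once every partition is drained, total equals the full row count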
9 changes: 9 additions & 0 deletions src/record_batch.rs
@@ -20,6 +20,7 @@ use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};

#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
@@ -61,4 +62,12 @@ impl PyRecordBatchStream {
            Some(Err(e)) => Err(e.into()),
        }
    }

    fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
        self.next(py)
    }

    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }
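
The __iter__/__next__ pair wires PyRecordBatchStream into Python's iterator protocol: pyo3 translates the Ok(None) that next() returns on an exhausted stream into StopIteration, which is what lets the for loops and list(stream) calls in the tests above work. A small illustrative sketch, reusing the hedged df from the first example:

stream = df.execute_stream()
first = next(stream)  # pulls a single batch
rest = [b.to_pyarrow() for b in stream]  # drains the remainder
# a further next(stream) raises StopIteration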
}