diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index bed0a91a6..7d2ee06fd 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -212,3 +212,11 @@ def test_explain(df):
         column("a") - column("b"),
     )
     df.explain()
+
+
+def test_repartition(df):
+    df.repartition(2)
+
+
+def test_repartition_by_hash(df):
+    df.repartition_by_hash(column("a"), num=2)
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 80963f7f0..4dc645583 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -22,6 +22,7 @@ use datafusion::arrow::pyarrow::PyArrowConvert;
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::DataFrame;
 use datafusion::logical_plan::JoinType;
+use datafusion::prelude::*;
 use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
@@ -147,7 +148,7 @@ impl PyDataFrame {
                     "The join type {} does not exist or is not implemented",
                     how
                 ))
-                .into())
+                .into());
             }
         };
 
@@ -164,4 +165,18 @@ impl PyDataFrame {
         let batches = wait_for_future(py, df.collect())?;
         Ok(pretty::print_batches(&batches)?)
     }
+
+    /// Repartition a `DataFrame` into `num` partitions using round-robin batch distribution.
+    fn repartition(&self, num: usize) -> PyResult<Self> {
+        let new_df = self.df.repartition(Partitioning::RoundRobinBatch(num))?;
+        Ok(Self::new(new_df))
+    }
+
+    /// Repartition a `DataFrame` into `num` partitions by hashing the given expressions.
+    #[args(args = "*", num)]
+    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyResult<Self> {
+        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
+        let new_df = self.df.repartition(Partitioning::Hash(expr, num))?;
+        Ok(Self::new(new_df))
+    }
 }