diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index c9544abac..bbbdddd41 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -256,3 +256,59 @@ def test_repartition(df): def test_repartition_by_hash(df): df.repartition_by_hash(column("a"), num=2) + + +def test_intersect(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3]), pa.array([6])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) + + df_a_i_b = df_a.intersect(df_b).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_i_b.collect() + + +def test_except_all(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([4, 5])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) + + df_a_e_b = df_a.except_all(df_b).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_e_b.collect() diff --git a/src/dataframe.rs b/src/dataframe.rs index f6cb4f11e..e491c3d9d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -192,4 +192,16 @@ impl PyDataFrame { let new_df = self.df.repartition(Partitioning::Hash(expr, num))?; Ok(Self::new(new_df)) } + + /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema + fn intersect(&self, py_df: PyDataFrame) -> PyResult { + let new_df = self.df.intersect(py_df.df)?; + Ok(Self::new(new_df)) + } + + /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema + fn except_all(&self, py_df: PyDataFrame) -> PyResult { + let new_df = self.df.except(py_df.df)?; + Ok(Self::new(new_df)) + } }