From e86fc1189d1fb06afc1a323dbeaeaeec3dbcb22a Mon Sep 17 00:00:00 2001 From: francis-du Date: Thu, 8 Sep 2022 16:06:59 +0800 Subject: [PATCH 1/3] fix: conflicting --- datafusion/tests/test_dataframe.py | 51 ++++++++++++++++++++++++++++++ src/dataframe.rs | 12 +++++++ 2 files changed, 63 insertions(+) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index c9544abac..e92e67d87 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -256,3 +256,54 @@ def test_repartition(df): def test_repartition_by_hash(df): df.repartition_by_hash(column("a"), num=2) + + +def test_intersect(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3]), pa.array([6])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + + df_c.show() + df_a.intersect(df_b).sort(column("a").sort(ascending=True)).show() + + assert df_c.collect() == df_a.intersect(df_b).sort(column("a").sort(ascending=True)).collect() + + +def test_except_all(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([4, 5])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a.except_all(df_b).sort(column("a").sort(ascending=True)).collect() diff --git 
a/src/dataframe.rs b/src/dataframe.rs index f6cb4f11e..e491c3d9d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -192,4 +192,16 @@ impl PyDataFrame { let new_df = self.df.repartition(Partitioning::Hash(expr, num))?; Ok(Self::new(new_df)) } + + /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema. + fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> { + let new_df = self.df.intersect(py_df.df)?; + Ok(Self::new(new_df)) + } + + /// Calculate the difference of two `DataFrame`s: the rows of this `DataFrame` that do not appear in `py_df`. The two `DataFrame`s must have exactly the same schema. + fn except_all(&self, py_df: PyDataFrame) -> PyResult<Self> { + let new_df = self.df.except(py_df.df)?; + Ok(Self::new(new_df)) + } } From ffd15ea31ea211fb6b81a3b9d92277cc0e2a9a06 Mon Sep 17 00:00:00 2001 From: francis-du Date: Tue, 6 Sep 2022 12:01:30 +0800 Subject: [PATCH 2/3] fix: python linter --- datafusion/tests/test_dataframe.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index e92e67d87..d056dd185 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -277,12 +277,19 @@ def test_intersect(): [pa.array([3]), pa.array([6])], names=["a", "b"], ) - df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) df_c.show() df_a.intersect(df_b).sort(column("a").sort(ascending=True)).show() - assert df_c.collect() == df_a.intersect(df_b).sort(column("a").sort(ascending=True)).collect() + assert ( + df_c.collect() + == df_a.intersect(df_b) + .sort(column("a").sort(ascending=True)) + .collect() + ) def test_except_all(): @@ -304,6 +311,13 @@ def test_except_all(): [pa.array([1, 2]), pa.array([4, 5])], names=["a", "b"], ) - df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + df_c = 
ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) - assert df_c.collect() == df_a.except_all(df_b).sort(column("a").sort(ascending=True)).collect() + assert ( + df_c.collect() + == df_a.except_all(df_b) + .sort(column("a").sort(ascending=True)) + .collect() + ) From 7278fda8eb698e10dc80a69fb79258ade318fc70 Mon Sep 17 00:00:00 2001 From: francis-du Date: Thu, 8 Sep 2022 16:37:26 +0800 Subject: [PATCH 3/3] fix: flake W503 issue --- datafusion/tests/test_dataframe.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index d056dd185..bbbdddd41 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -281,15 +281,9 @@ def test_intersect(): column("a").sort(ascending=True) ) - df_c.show() - df_a.intersect(df_b).sort(column("a").sort(ascending=True)).show() + df_a_i_b = df_a.intersect(df_b).sort(column("a").sort(ascending=True)) - assert ( - df_c.collect() - == df_a.intersect(df_b) - .sort(column("a").sort(ascending=True)) - .collect() - ) + assert df_c.collect() == df_a_i_b.collect() def test_except_all(): @@ -315,9 +309,6 @@ def test_except_all(): column("a").sort(ascending=True) ) - assert ( - df_c.collect() - == df_a.except_all(df_b) - .sort(column("a").sort(ascending=True)) - .collect() - ) + df_a_e_b = df_a.except_all(df_b).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_e_b.collect()