diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 760c37610..9880b6d33 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -22,6 +22,11 @@ from datafusion import DataFrame, SessionContext, column, literal, udf +@pytest.fixture +def ctx(): + return SessionContext() + + @pytest.fixture def df(): ctx = SessionContext() @@ -323,3 +328,56 @@ def test_collect_partitioned(): ) assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() + + +def test_union(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) + + df_a_u_b = df_a.union(df_b).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_u_b.collect() + + +def test_union_distinct(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) + + df_a_u_b = df_a.union(df_b, True).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_u_b.collect() + assert df_c.collect() == df_a_u_b.collect() diff --git a/src/dataframe.rs b/src/dataframe.rs index 4ae0160a9..4d8c0a3fb 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -204,6 +204,26 @@ impl PyDataFrame { Ok(Self::new(new_df)) } + /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The + /// two `DataFrame`s must have exactly the same schema + #[args(distinct = false)] + fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult { + let new_df = if distinct { + self.df.union_distinct(py_df.df)? + } else { + self.df.union(py_df.df)? + }; + + Ok(Self::new(new_df)) + } + + /// Calculate the distinct union of two `DataFrame`s. The + /// two `DataFrame`s must have exactly the same schema + fn union_distinct(&self, py_df: PyDataFrame) -> PyResult { + let new_df = self.df.union_distinct(py_df.df)?; + Ok(Self::new(new_df)) + } + /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema fn intersect(&self, py_df: PyDataFrame) -> PyResult { let new_df = self.df.intersect(py_df.df)?;