Skip to content

Commit b683961

Browse files
committed
feat: add intersect and except bindings for DataFrame
1 parent abc6d54 commit b683961

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

datafusion/tests/test_dataframe.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,54 @@ def test_explain(df):
212212
column("a") - column("b"),
213213
)
214214
df.explain()
215+
216+
217+
def test_intersect():
218+
ctx = SessionContext()
219+
220+
batch = pa.RecordBatch.from_arrays(
221+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
222+
names=["a", "b"],
223+
)
224+
df_a = ctx.create_dataframe([[batch]])
225+
226+
batch = pa.RecordBatch.from_arrays(
227+
[pa.array([3, 4, 5]), pa.array([6, 7, 8])],
228+
names=["a", "b"],
229+
)
230+
df_b = ctx.create_dataframe([[batch]])
231+
232+
batch = pa.RecordBatch.from_arrays(
233+
[pa.array([3]), pa.array([6])],
234+
names=["a", "b"],
235+
)
236+
df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
237+
238+
df_c.show()
239+
df_a.intersect(df_b).sort(column("a").sort(ascending=True)).show()
240+
241+
assert df_c.collect() == df_a.intersect(df_b).sort(column("a").sort(ascending=True)).collect()
242+
243+
244+
def test_except_all():
245+
ctx = SessionContext()
246+
247+
batch = pa.RecordBatch.from_arrays(
248+
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
249+
names=["a", "b"],
250+
)
251+
df_a = ctx.create_dataframe([[batch]])
252+
253+
batch = pa.RecordBatch.from_arrays(
254+
[pa.array([3, 4, 5]), pa.array([6, 7, 8])],
255+
names=["a", "b"],
256+
)
257+
df_b = ctx.create_dataframe([[batch]])
258+
259+
batch = pa.RecordBatch.from_arrays(
260+
[pa.array([1, 2]), pa.array([4, 5])],
261+
names=["a", "b"],
262+
)
263+
df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
264+
265+
assert df_c.collect() == df_a.except_all(df_b).sort(column("a").sort(ascending=True)).collect()

src/dataframe.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ impl PyDataFrame {
147147
"The join type {} does not exist or is not implemented",
148148
how
149149
))
150-
.into())
150+
.into());
151151
}
152152
};
153153

@@ -164,4 +164,16 @@ impl PyDataFrame {
164164
let batches = wait_for_future(py, df.collect())?;
165165
Ok(pretty::print_batches(&batches)?)
166166
}
167+
168+
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
169+
fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
170+
let new_df = self.df.intersect(py_df.df)?;
171+
Ok(Self::new(new_df))
172+
}
173+
174+
/// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
175+
fn except_all(&self, py_df: PyDataFrame) -> PyResult<Self> {
176+
let new_df = self.df.except(py_df.df)?;
177+
Ok(Self::new(new_df))
178+
}
167179
}

0 commit comments

Comments
 (0)