@@ -212,3 +212,54 @@ def test_explain(df):
212
212
column ("a" ) - column ("b" ),
213
213
)
214
214
df .explain ()
215
+
216
+
217
+ def test_intersect ():
218
+ ctx = SessionContext ()
219
+
220
+ batch = pa .RecordBatch .from_arrays (
221
+ [pa .array ([1 , 2 , 3 ]), pa .array ([4 , 5 , 6 ])],
222
+ names = ["a" , "b" ],
223
+ )
224
+ df_a = ctx .create_dataframe ([[batch ]])
225
+
226
+ batch = pa .RecordBatch .from_arrays (
227
+ [pa .array ([3 , 4 , 5 ]), pa .array ([6 , 7 , 8 ])],
228
+ names = ["a" , "b" ],
229
+ )
230
+ df_b = ctx .create_dataframe ([[batch ]])
231
+
232
+ batch = pa .RecordBatch .from_arrays (
233
+ [pa .array ([3 ]), pa .array ([6 ])],
234
+ names = ["a" , "b" ],
235
+ )
236
+ df_c = ctx .create_dataframe ([[batch ]]).sort (column ("a" ).sort (ascending = True ))
237
+
238
+ df_c .show ()
239
+ df_a .intersect (df_b ).sort (column ("a" ).sort (ascending = True )).show ()
240
+
241
+ assert df_c .collect () == df_a .intersect (df_b ).sort (column ("a" ).sort (ascending = True )).collect ()
242
+
243
+
244
+ def test_except_all ():
245
+ ctx = SessionContext ()
246
+
247
+ batch = pa .RecordBatch .from_arrays (
248
+ [pa .array ([1 , 2 , 3 ]), pa .array ([4 , 5 , 6 ])],
249
+ names = ["a" , "b" ],
250
+ )
251
+ df_a = ctx .create_dataframe ([[batch ]])
252
+
253
+ batch = pa .RecordBatch .from_arrays (
254
+ [pa .array ([3 , 4 , 5 ]), pa .array ([6 , 7 , 8 ])],
255
+ names = ["a" , "b" ],
256
+ )
257
+ df_b = ctx .create_dataframe ([[batch ]])
258
+
259
+ batch = pa .RecordBatch .from_arrays (
260
+ [pa .array ([1 , 2 ]), pa .array ([4 , 5 ])],
261
+ names = ["a" , "b" ],
262
+ )
263
+ df_c = ctx .create_dataframe ([[batch ]]).sort (column ("a" ).sort (ascending = True ))
264
+
265
+ assert df_c .collect () == df_a .except_all (df_b ).sort (column ("a" ).sort (ascending = True )).collect ()
0 commit comments