@@ -326,3 +326,158 @@ def test_broadcast_to_raises(data):
326
326
Xnp = np .zeros (orig_shape )
327
327
X = dpt .asarray (Xnp , sycl_queue = q )
328
328
pytest .raises (ValueError , dpt .broadcast_to , X , target_shape )
329
+
330
+
331
+ def assert_broadcast_correct (input_shapes ):
332
+ try :
333
+ q = dpctl .SyclQueue ()
334
+ except dpctl .SyclQueueCreationError :
335
+ pytest .skip ("Queue could not be created" )
336
+ np_arrays = [np .zeros (s ) for s in input_shapes ]
337
+ out_np_arrays = np .broadcast_arrays (* np_arrays )
338
+ usm_arrays = [dpt .asarray (Xnp , sycl_queue = q ) for Xnp in np_arrays ]
339
+ out_usm_arrays = dpt .broadcast_arrays (* usm_arrays )
340
+ for Xnp , X in zip (out_np_arrays , out_usm_arrays ):
341
+ assert_array_equal (
342
+ Xnp , dpt .asnumpy (X ), err_msg = f"Failed for { input_shapes } )"
343
+ )
344
+
345
+
346
+ def assert_broadcast_arrays_raise (input_shapes ):
347
+ try :
348
+ q = dpctl .SyclQueue ()
349
+ except dpctl .SyclQueueCreationError :
350
+ pytest .skip ("Queue could not be created" )
351
+ usm_arrays = [dpt .asarray (np .zeros (s ), sycl_queue = q ) for s in input_shapes ]
352
+ pytest .raises (ValueError , dpt .broadcast_arrays , * usm_arrays )
353
+
354
+
355
+ def test_broadcast_arrays_same ():
356
+ try :
357
+ q = dpctl .SyclQueue ()
358
+ except dpctl .SyclQueueCreationError :
359
+ pytest .skip ("Queue could not be created" )
360
+ Xnp = np .arange (10 )
361
+ Ynp = np .arange (10 )
362
+ res_Xnp , res_Ynp = np .broadcast_arrays (Xnp , Ynp )
363
+ X = dpt .asarray (Xnp , sycl_queue = q )
364
+ Y = dpt .asarray (Ynp , sycl_queue = q )
365
+ res_X , res_Y = dpt .broadcast_arrays (X , Y )
366
+ assert_array_equal (res_Xnp , dpt .asnumpy (res_X ))
367
+ assert_array_equal (res_Ynp , dpt .asnumpy (res_Y ))
368
+
369
+
370
+ def test_broadcast_arrays_one_off ():
371
+ try :
372
+ q = dpctl .SyclQueue ()
373
+ except dpctl .SyclQueueCreationError :
374
+ pytest .skip ("Queue could not be created" )
375
+ Xnp = np .array ([[1 , 2 , 3 ]])
376
+ Ynp = np .array ([[1 ], [2 ], [3 ]])
377
+ res_Xnp , res_Ynp = np .broadcast_arrays (Xnp , Ynp )
378
+ X = dpt .asarray (Xnp , sycl_queue = q )
379
+ Y = dpt .asarray (Ynp , sycl_queue = q )
380
+ res_X , res_Y = dpt .broadcast_arrays (X , Y )
381
+ assert_array_equal (res_Xnp , dpt .asnumpy (res_X ))
382
+ assert_array_equal (res_Ynp , dpt .asnumpy (res_Y ))
383
+
384
+
385
+ @pytest .mark .parametrize (
386
+ "shapes" ,
387
+ [
388
+ (),
389
+ (1 ,),
390
+ (3 ,),
391
+ (0 , 1 ),
392
+ (0 , 3 ),
393
+ (1 , 0 ),
394
+ (3 , 0 ),
395
+ (1 , 3 ),
396
+ (3 , 1 ),
397
+ (3 , 3 ),
398
+ ],
399
+ )
400
+ def test_broadcast_arrays_same_shapes (shapes ):
401
+ for shape in shapes :
402
+ single_input_shapes = [shape ]
403
+ assert_broadcast_correct (single_input_shapes )
404
+ double_input_shapes = [shape , shape ]
405
+ assert_broadcast_correct (double_input_shapes )
406
+ triple_input_shapes = [shape , shape , shape ]
407
+ assert_broadcast_correct (triple_input_shapes )
408
+
409
+
410
+ @pytest .mark .parametrize (
411
+ "shapes" ,
412
+ [
413
+ [[(1 ,), (3 ,)]],
414
+ [[(1 , 3 ), (3 , 3 )]],
415
+ [[(3 , 1 ), (3 , 3 )]],
416
+ [[(1 , 3 ), (3 , 1 )]],
417
+ [[(1 , 1 ), (3 , 3 )]],
418
+ [[(1 , 1 ), (1 , 3 )]],
419
+ [[(1 , 1 ), (3 , 1 )]],
420
+ [[(1 , 0 ), (0 , 0 )]],
421
+ [[(0 , 1 ), (0 , 0 )]],
422
+ [[(1 , 0 ), (0 , 1 )]],
423
+ [[(1 , 1 ), (0 , 0 )]],
424
+ [[(1 , 1 ), (1 , 0 )]],
425
+ [[(1 , 1 ), (0 , 1 )]],
426
+ ],
427
+ )
428
+ def test_broadcast_arrays_same_len_shapes (shapes ):
429
+ # Check that two different input shapes of the same length, but some have
430
+ # ones, broadcast to the correct shape.
431
+
432
+ for input_shapes in shapes :
433
+ assert_broadcast_correct (input_shapes )
434
+ assert_broadcast_correct (input_shapes [::- 1 ])
435
+
436
+
437
+ @pytest .mark .parametrize (
438
+ "shapes" ,
439
+ [
440
+ [[(), (3 ,)]],
441
+ [[(3 ,), (3 , 3 )]],
442
+ [[(3 ,), (3 , 1 )]],
443
+ [[(1 ,), (3 , 3 )]],
444
+ [[(), (3 , 3 )]],
445
+ [[(1 , 1 ), (3 ,)]],
446
+ [[(1 ,), (3 , 1 )]],
447
+ [[(1 ,), (1 , 3 )]],
448
+ [[(), (1 , 3 )]],
449
+ [[(), (3 , 1 )]],
450
+ [[(), (0 ,)]],
451
+ [[(0 ,), (0 , 0 )]],
452
+ [[(0 ,), (0 , 1 )]],
453
+ [[(1 ,), (0 , 0 )]],
454
+ [[(), (0 , 0 )]],
455
+ [[(1 , 1 ), (0 ,)]],
456
+ [[(1 ,), (0 , 1 )]],
457
+ [[(1 ,), (1 , 0 )]],
458
+ [[(), (1 , 0 )]],
459
+ [[(), (0 , 1 )]],
460
+ ],
461
+ )
462
+ def test_broadcast_arrays_different_len_shapes (shapes ):
463
+ # Check that two different input shapes (of different lengths) broadcast
464
+ # to the correct shape.
465
+
466
+ for input_shapes in shapes :
467
+ assert_broadcast_correct (input_shapes )
468
+ assert_broadcast_correct (input_shapes [::- 1 ])
469
+
470
+
471
+ @pytest .mark .parametrize (
472
+ "shapes" ,
473
+ [
474
+ [[(3 ,), (4 ,)]],
475
+ [[(2 , 3 ), (2 ,)]],
476
+ [[(3 ,), (3 ,), (4 ,)]],
477
+ [[(1 , 3 , 4 ), (2 , 3 , 3 )]],
478
+ ],
479
+ )
480
+ def test_incompatible_shapes_raise_valueerror (shapes ):
481
+ for input_shapes in shapes :
482
+ assert_broadcast_arrays_raise (input_shapes )
483
+ assert_broadcast_arrays_raise (input_shapes [::- 1 ])
0 commit comments