Skip to content

Commit f469af2

Browse files
Merge pull request #798 from IntelPython/broadcast_arrays
Implementation of broadcast_arrays function
2 parents e96cce3 + bb9085b commit f469af2

File tree

3 files changed

+218
-0
lines changed

3 files changed

+218
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
2828
from dpctl.tensor._manipulation_functions import (
29+
broadcast_arrays,
2930
broadcast_to,
3031
expand_dims,
3132
permute_dims,
@@ -42,6 +43,7 @@
4243
"copy",
4344
"empty",
4445
"reshape",
46+
"broadcast_arrays",
4547
"broadcast_to",
4648
"expand_dims",
4749
"permute_dims",

dpctl/tensor/_manipulation_functions.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616

1717

18+
from itertools import chain, repeat
19+
1820
import numpy as np
1921
from numpy.core.numeric import normalize_axis_tuple
2022

@@ -38,6 +40,47 @@ def _broadcast_strides(X_shape, X_strides, res_ndim):
3840
return tuple(out_strides)
3941

4042

43+
def _broadcast_shapes(*args):
44+
"""
45+
Broadcast the input shapes into a single shape;
46+
returns tuple broadcasted shape.
47+
"""
48+
shapes = [array.shape for array in args]
49+
if len(set(shapes)) == 1:
50+
return shapes[0]
51+
mutable_shapes = False
52+
nds = [len(s) for s in shapes]
53+
biggest = max(nds)
54+
for i in range(len(args)):
55+
diff = biggest - nds[i]
56+
if diff > 0:
57+
ty = type(shapes[i])
58+
shapes[i] = ty(chain(repeat(1, diff), shapes[i]))
59+
common_shape = []
60+
for axis in range(biggest):
61+
lengths = [s[axis] for s in shapes]
62+
unique = set(lengths + [1])
63+
if len(unique) > 2:
64+
raise ValueError(
65+
"Shape mismatch: two or more arrays have "
66+
f"incompatible dimensions on axis ({axis},)"
67+
)
68+
elif len(unique) == 2:
69+
unique.remove(1)
70+
new_length = unique.pop()
71+
common_shape.append(new_length)
72+
for i in range(len(args)):
73+
if shapes[i][axis] == 1:
74+
if not mutable_shapes:
75+
shapes = [list(s) for s in shapes]
76+
mutable_shapes = True
77+
shapes[i][axis] = new_length
78+
else:
79+
common_shape.append(1)
80+
81+
return tuple(common_shape)
82+
83+
4184
def permute_dims(X, axes):
4285
"""
4386
permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray
@@ -146,3 +189,21 @@ def broadcast_to(X, shape):
146189
strides=new_sts,
147190
offset=X.__sycl_usm_array_interface__.get("offset", 0),
148191
)
192+
193+
194+
def broadcast_arrays(*args):
195+
"""
196+
broadcast_arrays(*args: usm_ndarrays) -> list of usm_ndarrays
197+
198+
Broadcasts one or more usm_ndarrays against one another.
199+
"""
200+
for X in args:
201+
if not isinstance(X, dpt.usm_ndarray):
202+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
203+
204+
shape = _broadcast_shapes(*args)
205+
206+
if all(X.shape == shape for X in args):
207+
return args
208+
209+
return [broadcast_to(X, shape) for X in args]

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,158 @@ def test_broadcast_to_raises(data):
326326
Xnp = np.zeros(orig_shape)
327327
X = dpt.asarray(Xnp, sycl_queue=q)
328328
pytest.raises(ValueError, dpt.broadcast_to, X, target_shape)
329+
330+
331+
def assert_broadcast_correct(input_shapes):
332+
try:
333+
q = dpctl.SyclQueue()
334+
except dpctl.SyclQueueCreationError:
335+
pytest.skip("Queue could not be created")
336+
np_arrays = [np.zeros(s) for s in input_shapes]
337+
out_np_arrays = np.broadcast_arrays(*np_arrays)
338+
usm_arrays = [dpt.asarray(Xnp, sycl_queue=q) for Xnp in np_arrays]
339+
out_usm_arrays = dpt.broadcast_arrays(*usm_arrays)
340+
for Xnp, X in zip(out_np_arrays, out_usm_arrays):
341+
assert_array_equal(
342+
Xnp, dpt.asnumpy(X), err_msg=f"Failed for {input_shapes})"
343+
)
344+
345+
346+
def assert_broadcast_arrays_raise(input_shapes):
347+
try:
348+
q = dpctl.SyclQueue()
349+
except dpctl.SyclQueueCreationError:
350+
pytest.skip("Queue could not be created")
351+
usm_arrays = [dpt.asarray(np.zeros(s), sycl_queue=q) for s in input_shapes]
352+
pytest.raises(ValueError, dpt.broadcast_arrays, *usm_arrays)
353+
354+
355+
def test_broadcast_arrays_same():
356+
try:
357+
q = dpctl.SyclQueue()
358+
except dpctl.SyclQueueCreationError:
359+
pytest.skip("Queue could not be created")
360+
Xnp = np.arange(10)
361+
Ynp = np.arange(10)
362+
res_Xnp, res_Ynp = np.broadcast_arrays(Xnp, Ynp)
363+
X = dpt.asarray(Xnp, sycl_queue=q)
364+
Y = dpt.asarray(Ynp, sycl_queue=q)
365+
res_X, res_Y = dpt.broadcast_arrays(X, Y)
366+
assert_array_equal(res_Xnp, dpt.asnumpy(res_X))
367+
assert_array_equal(res_Ynp, dpt.asnumpy(res_Y))
368+
369+
370+
def test_broadcast_arrays_one_off():
371+
try:
372+
q = dpctl.SyclQueue()
373+
except dpctl.SyclQueueCreationError:
374+
pytest.skip("Queue could not be created")
375+
Xnp = np.array([[1, 2, 3]])
376+
Ynp = np.array([[1], [2], [3]])
377+
res_Xnp, res_Ynp = np.broadcast_arrays(Xnp, Ynp)
378+
X = dpt.asarray(Xnp, sycl_queue=q)
379+
Y = dpt.asarray(Ynp, sycl_queue=q)
380+
res_X, res_Y = dpt.broadcast_arrays(X, Y)
381+
assert_array_equal(res_Xnp, dpt.asnumpy(res_X))
382+
assert_array_equal(res_Ynp, dpt.asnumpy(res_Y))
383+
384+
385+
@pytest.mark.parametrize(
386+
"shapes",
387+
[
388+
(),
389+
(1,),
390+
(3,),
391+
(0, 1),
392+
(0, 3),
393+
(1, 0),
394+
(3, 0),
395+
(1, 3),
396+
(3, 1),
397+
(3, 3),
398+
],
399+
)
400+
def test_broadcast_arrays_same_shapes(shapes):
401+
for shape in shapes:
402+
single_input_shapes = [shape]
403+
assert_broadcast_correct(single_input_shapes)
404+
double_input_shapes = [shape, shape]
405+
assert_broadcast_correct(double_input_shapes)
406+
triple_input_shapes = [shape, shape, shape]
407+
assert_broadcast_correct(triple_input_shapes)
408+
409+
410+
@pytest.mark.parametrize(
411+
"shapes",
412+
[
413+
[[(1,), (3,)]],
414+
[[(1, 3), (3, 3)]],
415+
[[(3, 1), (3, 3)]],
416+
[[(1, 3), (3, 1)]],
417+
[[(1, 1), (3, 3)]],
418+
[[(1, 1), (1, 3)]],
419+
[[(1, 1), (3, 1)]],
420+
[[(1, 0), (0, 0)]],
421+
[[(0, 1), (0, 0)]],
422+
[[(1, 0), (0, 1)]],
423+
[[(1, 1), (0, 0)]],
424+
[[(1, 1), (1, 0)]],
425+
[[(1, 1), (0, 1)]],
426+
],
427+
)
428+
def test_broadcast_arrays_same_len_shapes(shapes):
429+
# Check that two different input shapes of the same length, but some have
430+
# ones, broadcast to the correct shape.
431+
432+
for input_shapes in shapes:
433+
assert_broadcast_correct(input_shapes)
434+
assert_broadcast_correct(input_shapes[::-1])
435+
436+
437+
@pytest.mark.parametrize(
438+
"shapes",
439+
[
440+
[[(), (3,)]],
441+
[[(3,), (3, 3)]],
442+
[[(3,), (3, 1)]],
443+
[[(1,), (3, 3)]],
444+
[[(), (3, 3)]],
445+
[[(1, 1), (3,)]],
446+
[[(1,), (3, 1)]],
447+
[[(1,), (1, 3)]],
448+
[[(), (1, 3)]],
449+
[[(), (3, 1)]],
450+
[[(), (0,)]],
451+
[[(0,), (0, 0)]],
452+
[[(0,), (0, 1)]],
453+
[[(1,), (0, 0)]],
454+
[[(), (0, 0)]],
455+
[[(1, 1), (0,)]],
456+
[[(1,), (0, 1)]],
457+
[[(1,), (1, 0)]],
458+
[[(), (1, 0)]],
459+
[[(), (0, 1)]],
460+
],
461+
)
462+
def test_broadcast_arrays_different_len_shapes(shapes):
463+
# Check that two different input shapes (of different lengths) broadcast
464+
# to the correct shape.
465+
466+
for input_shapes in shapes:
467+
assert_broadcast_correct(input_shapes)
468+
assert_broadcast_correct(input_shapes[::-1])
469+
470+
471+
@pytest.mark.parametrize(
472+
"shapes",
473+
[
474+
[[(3,), (4,)]],
475+
[[(2, 3), (2,)]],
476+
[[(3,), (3,), (4,)]],
477+
[[(1, 3, 4), (2, 3, 3)]],
478+
],
479+
)
480+
def test_incompatible_shapes_raise_valueerror(shapes):
481+
for input_shapes in shapes:
482+
assert_broadcast_arrays_raise(input_shapes)
483+
assert_broadcast_arrays_raise(input_shapes[::-1])

0 commit comments

Comments
 (0)