Skip to content

Commit 6985c77

Browse files
Use numpy.promote_types to allow concatenation of different dtypes
1 parent 04142cd commit 6985c77

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def concat(arrays, axis=0):
298298
if n == 0:
299299
raise TypeError("Missing 1 required positional argument: 'arrays'")
300300

301-
if not isinstance(arrays, list) and not isinstance(arrays, tuple):
301+
if not isinstance(arrays, (list, tuple)):
302302
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
303303

304304
for X in arrays:
@@ -314,8 +314,12 @@ def concat(arrays, axis=0):
314314
raise ValueError("All the input arrays must have usm_type")
315315

316316
X0 = arrays[0]
317-
if any(X0.dtype != arrays[i].dtype for i in range(1, n)):
318-
raise ValueError("All the input arrays must have same dtype")
317+
if not all(Xi.dtype.char in "?bBhHiIlLefdFD" for Xi in arrays):
318+
raise ValueError("Unsupported dtype encountered.")
319+
320+
res_dtype = X0.dtype
321+
for i in range(1, n):
322+
res_dtype = np.promote_types(res_dtype, arrays[i])
319323

320324
for i in range(1, n):
321325
if X0.ndim != arrays[i].ndim:
@@ -349,7 +353,7 @@ def concat(arrays, axis=0):
349353
)
350354

351355
res = dpt.empty(
352-
res_shape, dtype=X0.dtype, usm_type=res_usm_type, sycl_queue=exec_q
356+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
353357
)
354358

355359
hev_list = []

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -748,16 +748,20 @@ def test_concat_incorrect_queue():
748748
pytest.raises(ValueError, dpt.concat, [X, Y])
749749

750750

751-
def test_concat_incorrect_dtype():
751+
def test_concat_different_dtype():
752752
try:
753753
q = dpctl.SyclQueue()
754754
except dpctl.SyclQueueCreationError:
755755
pytest.skip("Queue could not be created")
756756

757757
X = dpt.ones((2, 2), dtype=np.int64, sycl_queue=q)
758-
Y = dpt.ones((2, 2), dtype=np.uint64, sycl_queue=q)
758+
Y = dpt.ones((3, 2), dtype=np.uint32, sycl_queue=q)
759759

760-
pytest.raises(ValueError, dpt.concat, [X, Y])
760+
XY = dpt.concat([X, Y])
761+
762+
assert XY.dtype is X.dtype
763+
assert XY.shape == (5, 2)
764+
assert XY.sycl_queue == q
761765

762766

763767
def test_concat_incorrect_ndim():

0 commit comments

Comments
 (0)