From ea1e66343ece7207deb364311ad2e9bac0842e10 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 26 Jun 2024 23:58:03 +0000 Subject: [PATCH 1/2] Fixes unhelpful ValueError when no arguments are passed to `broadcast_arrays` --- dpctl/tensor/_manipulation_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index af44104288..27873e2b3a 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -265,6 +265,8 @@ def broadcast_arrays(*args): `device` and `usm_type` attributes as its corresponding input array. """ + if len(args) == 0: + raise ValueError("`broadcast_arrays` requires at least one argument") for X in args: if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") From 21e1958b67ba67259a8f156b09717e11893c673c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 27 Jun 2024 00:13:34 +0000 Subject: [PATCH 2/2] Add a test for `broadcast_arrays` empty input value error --- dpctl/tests/test_usm_ndarray_manipulation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index b64f68cbb8..563d5bf9dd 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -435,6 +435,11 @@ def test_incompatible_shapes_raise_valueerror(shapes): assert_broadcast_arrays_raise(input_shapes[::-1]) +def test_broadcast_arrays_no_args(): + with pytest.raises(ValueError): + dpt.broadcast_arrays() + + def test_flip_axis_incorrect(): q = get_queue_or_skip()