diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index af44104288..27873e2b3a 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -265,6 +265,8 @@ def broadcast_arrays(*args): `device` and `usm_type` attributes as its corresponding input array. """ + if len(args) == 0: + raise ValueError("`broadcast_arrays` requires at least one argument") for X in args: if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index b64f68cbb8..563d5bf9dd 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -435,6 +435,11 @@ def test_incompatible_shapes_raise_valueerror(shapes): assert_broadcast_arrays_raise(input_shapes[::-1]) +def test_broadcast_arrays_no_args(): + with pytest.raises(ValueError): + dpt.broadcast_arrays() + + def test_flip_axis_incorrect(): q = get_queue_or_skip()