diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 1a907ebf20..094b959efe 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -48,18 +48,14 @@ def _broadcast_strides(X_shape, X_strides, res_ndim): return tuple(out_strides) -def _broadcast_shapes(*args): - """ - Broadcast the input shapes into a single shape; - returns tuple broadcasted shape. - """ - shapes = [array.shape for array in args] +def _broadcast_shape_impl(shapes): if len(set(shapes)) == 1: return shapes[0] mutable_shapes = False nds = [len(s) for s in shapes] biggest = max(nds) - for i in range(len(args)): + sh_len = len(shapes) + for i in range(sh_len): diff = biggest - nds[i] if diff > 0: ty = type(shapes[i]) @@ -77,7 +73,7 @@ def _broadcast_shapes(*args): unique.remove(1) new_length = unique.pop() common_shape.append(new_length) - for i in range(len(args)): + for i in range(sh_len): if shapes[i][axis] == 1: if not mutable_shapes: shapes = [list(s) for s in shapes] @@ -89,6 +85,15 @@ def _broadcast_shapes(*args): return tuple(common_shape) +def _broadcast_shapes(*args): + """ + Broadcast the input shapes into a single shape; + returns tuple broadcasted shape. + """ + array_shapes = [array.shape for array in args] + return _broadcast_shape_impl(array_shapes) + + def permute_dims(X, axes): """ permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray