From b2cd5c18e6bcd9f3e81b19b1c39d798e0f3836b6 Mon Sep 17 00:00:00 2001 From: PiotrekB416 Date: Mon, 28 Aug 2023 00:08:13 +0200 Subject: [PATCH 1/2] Add support for lowercase order in tensor.copy and tensor.astype --- dpctl/tensor/_copy_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 3eae29f057..321c1393c4 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -517,6 +517,7 @@ def copy(usm_ary, order="K"): - "K": match the layout of `usm_ary` as closely as possible. """ + order = order.upper() if not isinstance(usm_ary, dpt.usm_ndarray): return TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" @@ -581,6 +582,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): A view can be returned, if possible, when `copy=False` is used. """ + order = order.upper() if not isinstance(usm_ary, dpt.usm_ndarray): return TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" From 5390456db2b1958f44e89ad20f57c65384d6a2bb Mon Sep 17 00:00:00 2001 From: PiotrekB416 Date: Mon, 28 Aug 2023 00:18:09 +0200 Subject: [PATCH 2/2] Refactored --- dpctl/tensor/_copy_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 321c1393c4..ad5b956851 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -517,7 +517,11 @@ def copy(usm_ary, order="K"): - "K": match the layout of `usm_ary` as closely as possible. """ - order = order.upper() + if len(order) == 0 or order[0] not in "KkAaCcFf": + raise ValueError( + "Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'." + ) + order = order[0].upper() if not isinstance(usm_ary, dpt.usm_ndarray): return TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" @@ -582,16 +586,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): A view can be returned, if possible, when `copy=False` is used. """ - order = order.upper() if not isinstance(usm_ary, dpt.usm_ndarray): return TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" ) - if not isinstance(order, str) or order not in ["A", "C", "F", "K"]: + if len(order) == 0 or order[0] not in "KkAaCcFf": raise ValueError( - "Unrecognized value of the order keyword. " - "Recognized values are 'A', 'C', 'F', or 'K'" + "Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'." ) + order = order[0].upper() ary_dtype = usm_ary.dtype target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):