Skip to content

Commit 439c2f4

Browse files
Merge pull request #708 from IntelPython/tensor-reshuffling
Exported dpctl.tensor.Device as array-API device class
2 parents 9310e08 + 8374c9f commit 439c2f4

File tree

6 files changed

+138
-97
lines changed

6 files changed

+138
-97
lines changed

dpctl/tensor/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,15 @@
2929
3030
"""
3131

32-
from dpctl.tensor._copy_utils import astype, copy
33-
from dpctl.tensor._copy_utils import copy_from_numpy as from_numpy
34-
from dpctl.tensor._copy_utils import copy_to_numpy as asnumpy
35-
from dpctl.tensor._copy_utils import copy_to_numpy as to_numpy
32+
from dpctl.tensor._copy_utils import asnumpy, astype, copy, from_numpy, to_numpy
3633
from dpctl.tensor._ctors import asarray, empty
34+
from dpctl.tensor._device import Device
3735
from dpctl.tensor._dlpack import from_dlpack
3836
from dpctl.tensor._reshape import reshape
3937
from dpctl.tensor._usmarray import usm_ndarray
4038

4139
__all__ = [
40+
"Device",
4241
"usm_ndarray",
4342
"asarray",
4443
"astype",

dpctl/tensor/_copy_utils.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import dpctl.memory as dpm
2121
import dpctl.tensor as dpt
22+
from dpctl.tensor._device import normalize_queue_device
2223

2324

2425
def contract_iter2(shape, strides1, strides2):
@@ -64,7 +65,7 @@ def contract_iter2(shape, strides1, strides2):
6465
return (sh, st1, disp1, st2, disp2)
6566

6667

67-
def has_memory_overlap(x1, x2):
68+
def _has_memory_overlap(x1, x2):
6869
m1 = dpm.as_usm_memory(x1)
6970
m2 = dpm.as_usm_memory(x2)
7071
if m1.sycl_device == m2.sycl_device:
@@ -77,7 +78,7 @@ def has_memory_overlap(x1, x2):
7778
return False
7879

7980

80-
def copy_to_numpy(ary):
81+
def _copy_to_numpy(ary):
8182
if type(ary) is not dpt.usm_ndarray:
8283
raise TypeError
8384
h = ary.usm_data.copy_to_host().view(ary.dtype)
@@ -93,7 +94,7 @@ def copy_to_numpy(ary):
9394
)
9495

9596

96-
def copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
97+
def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
9798
"Copies numpy array `np_ary` into a new usm_ndarray"
9899
# This may peform a copy to meet stated requirements
99100
Xnp = np.require(np_ary, requirements=["A", "O", "C", "E"])
@@ -111,7 +112,8 @@ def copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
111112
return Xusm
112113

113114

114-
def copy_from_numpy_into(dst, np_ary):
115+
def _copy_from_numpy_into(dst, np_ary):
116+
"Copies `np_ary` into `dst` of type :class:`dpctl.tensor.usm_ndarray"
115117
if not isinstance(np_ary, np.ndarray):
116118
raise TypeError("Expected numpy.ndarray, got {}".format(type(np_ary)))
117119
src_ary = np.broadcast_to(np.asarray(np_ary, dtype=dst.dtype), dst.shape)
@@ -122,6 +124,54 @@ def copy_from_numpy_into(dst, np_ary):
122124
usm_mem.copy_from_host(host_buf)
123125

124126

127+
def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None):
128+
"""
129+
from_numpy(arg, device=None, usm_type="device", sycl_queue=None)
130+
131+
Creates :class:`dpctl.tensor.usm_ndarray` from instance of
132+
`numpy.ndarray`.
133+
134+
Args:
135+
arg: An instance of `numpy.ndarray`
136+
device: array API specification of device where the output array
137+
is created.
138+
sycl_queue: a :class:`dpctl.SyclQueue` used to create the output
139+
array is created
140+
"""
141+
q = normalize_queue_device(sycl_queue=sycl_queue, device=device)
142+
return _copy_from_numpy(np_ary, usm_type=usm_type, sycl_queue=q)
143+
144+
145+
def to_numpy(usm_ary):
146+
"""
147+
to_numpy(usm_ary)
148+
149+
Copies content of :class:`dpctl.tensor.usm_ndarray` instance `usm_ary`
150+
into `numpy.ndarray` instance of the same shape and same data type.
151+
152+
Args:
153+
usm_ary: An instance of :class:`dpctl.tensor.usm_ndarray`
154+
Returns:
155+
An instance of `numpy.ndarray` populated with content of `usm_ary`.
156+
"""
157+
return _copy_to_numpy(usm_ary)
158+
159+
160+
def asnumpy(usm_ary):
161+
"""
162+
asnumpy(usm_ary)
163+
164+
Copies content of :class:`dpctl.tensor.usm_ndarray` instance `usm_ary`
165+
into `numpy.ndarray` instance of the same shape and same data type.
166+
167+
Args:
168+
usm_ary: An instance of :class:`dpctl.tensor.usm_ndarray`
169+
Returns:
170+
An instance of `numpy.ndarray` populated with content of `usm_ary`.
171+
"""
172+
return _copy_to_numpy(usm_ary)
173+
174+
125175
class Dummy:
126176
def __init__(self, iface):
127177
self.__sycl_usm_array_interface__ = iface
@@ -138,9 +188,9 @@ def copy_same_dtype(dst, src):
138188
raise ValueError
139189

140190
# check that memory regions do not overlap
141-
if has_memory_overlap(dst, src):
142-
tmp = copy_to_numpy(src)
143-
copy_from_numpy_into(dst, tmp)
191+
if _has_memory_overlap(dst, src):
192+
tmp = _copy_to_numpy(src)
193+
_copy_from_numpy_into(dst, tmp)
144194
return
145195

146196
if (dst.flags & 1) and (src.flags & 1):
@@ -184,10 +234,10 @@ def copy_same_shape(dst, src):
184234
return
185235

186236
# check that memory regions do not overlap
187-
if has_memory_overlap(dst, src):
188-
tmp = copy_to_numpy(src)
237+
if _has_memory_overlap(dst, src):
238+
tmp = _copy_to_numpy(src)
189239
tmp = tmp.astype(dst.dtype)
190-
copy_from_numpy_into(dst, tmp)
240+
_copy_from_numpy_into(dst, tmp)
191241
return
192242

193243
# simplify strides
@@ -218,7 +268,7 @@ def copy_same_shape(dst, src):
218268
mdst.copy_from_host(tmp.view("u1"))
219269

220270

221-
def copy_from_usm_ndarray_to_usm_ndarray(dst, src):
271+
def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
222272
if type(dst) is not dpt.usm_ndarray or type(src) is not dpt.usm_ndarray:
223273
raise TypeError
224274

@@ -389,7 +439,7 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
389439
buffer=R.usm_data,
390440
strides=new_strides,
391441
)
392-
copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
442+
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
393443
return R
394444
else:
395445
return usm_ary

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import dpctl.memory as dpm
2121
import dpctl.tensor as dpt
2222
import dpctl.utils
23+
from dpctl.tensor._device import normalize_queue_device
2324

2425
_empty_tuple = tuple()
2526
_host_set = frozenset([None])
@@ -72,29 +73,6 @@ def _array_info_sequence(li):
7273
return (n,) + dim, dt, device
7374

7475

75-
def _normalize_queue_device(q=None, d=None):
76-
if q is None:
77-
d = dpt._device.Device.create_device(d)
78-
return d.sycl_queue
79-
else:
80-
if not isinstance(q, dpctl.SyclQueue):
81-
raise TypeError(f"Expected dpctl.SyclQueue, got {type(q)}")
82-
if d is None:
83-
return q
84-
d = dpt._device.Device.create_device(d)
85-
qq = dpctl.utils.get_execution_queue(
86-
(
87-
q,
88-
d.sycl_queue,
89-
)
90-
)
91-
if qq is None:
92-
raise TypeError(
93-
"sycl_queue and device keywords can not be both specified"
94-
)
95-
return qq
96-
97-
9876
def _asarray_from_usm_ndarray(
9977
usm_ndary,
10078
dtype=None,
@@ -115,7 +93,7 @@ def _asarray_from_usm_ndarray(
11593
exec_q = dpctl.utils.get_execution_queue(
11694
[usm_ndary.sycl_queue, sycl_queue]
11795
)
118-
copy_q = _normalize_queue_device(q=sycl_queue, d=exec_q)
96+
copy_q = normalize_queue_device(sycl_queue=sycl_queue, device=exec_q)
11997
else:
12098
copy_q = usm_ndary.sycl_queue
12199
# Conditions for zero copy:
@@ -194,7 +172,7 @@ def _asarray_from_numpy_ndarray(
194172
usm_type = "device"
195173
if dtype is None:
196174
dtype = ary.dtype
197-
copy_q = _normalize_queue_device(q=None, d=sycl_queue)
175+
copy_q = normalize_queue_device(sycl_queue=None, device=sycl_queue)
198176
f_contig = ary.flags["F"]
199177
c_contig = ary.flags["C"]
200178
fc_contig = f_contig or c_contig
@@ -327,7 +305,9 @@ def asarray(
327305
)
328306
# 5. Normalize device/sycl_queue [keep it None if was None]
329307
if device is not None or sycl_queue is not None:
330-
sycl_queue = _normalize_queue_device(q=sycl_queue, d=device)
308+
sycl_queue = normalize_queue_device(
309+
sycl_queue=sycl_queue, device=device
310+
)
331311

332312
# handle instance(obj, usm_ndarray)
333313
if isinstance(obj, dpt.usm_ndarray):
@@ -459,7 +439,7 @@ def empty(
459439
raise TypeError(
460440
f"Expected usm_type to be of type str, got {type(usm_type)}"
461441
)
462-
sycl_queue = _normalize_queue_device(q=sycl_queue, d=device)
442+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
463443
res = dpt.usm_ndarray(
464444
sh,
465445
dtype=dtype,

dpctl/tensor/_device.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,54 @@ def __repr__(self):
105105
except TypeError:
106106
# This is a sub-device
107107
return repr(self.sycl_queue)
108+
109+
110+
def normalize_queue_device(sycl_queue=None, device=None):
111+
"""
112+
normalize_queue_device(sycl_queue=None, device=None)
113+
114+
Utility to process exclusive keyword arguments 'device'
115+
and 'sycl_queue' in functions of `dpctl.tensor`.
116+
117+
Args:
118+
sycl_queue(:class:`dpctl.SyclQueue`, optional):
119+
explicitly indicates where USM allocation is done
120+
and the population code (if any) is executed.
121+
Value `None` is interpreted as get the SYCL queue
122+
from `device` keyword, or use default queue.
123+
Default: None
124+
device (string, :class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue,
125+
:class:`dpctl.tensor.Device`, optional):
126+
array-API keyword indicating non-partitioned SYCL device
127+
where array is allocated.
128+
Returns
129+
:class:`dpctl.SyclQueue` object implied by either of provided
130+
keywords. If both are None, `dpctl.SyclQueue()` is returned.
131+
If both are specified and imply the same queue, `sycl_queue`
132+
is returned.
133+
Raises:
134+
TypeError: if argument is not of the expected type, or keywords
135+
imply incompatible queues.
136+
"""
137+
q = sycl_queue
138+
d = device
139+
if q is None:
140+
d = Device.create_device(d)
141+
return d.sycl_queue
142+
else:
143+
if not isinstance(q, dpctl.SyclQueue):
144+
raise TypeError(f"Expected dpctl.SyclQueue, got {type(q)}")
145+
if d is None:
146+
return q
147+
d = Device.create_device(d)
148+
qq = dpctl.utils.get_execution_queue(
149+
(
150+
q,
151+
d.sycl_queue,
152+
)
153+
)
154+
if qq is None:
155+
raise TypeError(
156+
"sycl_queue and device keywords can not be both specified"
157+
)
158+
return qq

dpctl/tensor/_usmarray.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -884,13 +884,13 @@ cdef class usm_ndarray:
884884
except (ValueError, IndexError) as e:
885885
raise e
886886
from ._copy_utils import (
887-
copy_from_numpy_into,
888-
copy_from_usm_ndarray_to_usm_ndarray,
887+
_copy_from_numpy_into,
888+
_copy_from_usm_ndarray_to_usm_ndarray,
889889
)
890890
if isinstance(val, usm_ndarray):
891-
copy_from_usm_ndarray_to_usm_ndarray(Xv, val)
891+
_copy_from_usm_ndarray_to_usm_ndarray(Xv, val)
892892
else:
893-
copy_from_numpy_into(Xv, np.asarray(val))
893+
_copy_from_numpy_into(Xv, np.asarray(val))
894894

895895
def __sub__(first, other):
896896
"See comment in __add__"

0 commit comments

Comments
 (0)