Skip to content

Commit 44bbe03

Browse files
Merge branch 'review-pr-192' into merged-192-200
2 parents bbd586a + 897afc9 commit 44bbe03

File tree

3 files changed

+341
-0
lines changed

3 files changed

+341
-0
lines changed

dpctl/dparray.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
##===---------- dparray.py - dpctl -------*- Python -*----===##
2+
##
3+
## Data Parallel Control (dpCtl)
4+
##
5+
## Copyright 2020 Intel Corporation
6+
##
7+
## Licensed under the Apache License, Version 2.0 (the "License");
8+
## you may not use this file except in compliance with the License.
9+
## You may obtain a copy of the License at
10+
##
11+
## http://www.apache.org/licenses/LICENSE-2.0
12+
##
13+
## Unless required by applicable law or agreed to in writing, software
14+
## distributed under the License is distributed on an "AS IS" BASIS,
15+
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
## See the License for the specific language governing permissions and
17+
## limitations under the License.
18+
##
19+
##===----------------------------------------------------------------------===##
20+
###
21+
### \file
22+
### This file implements a dparray - USM aware implementation of ndarray.
23+
##===----------------------------------------------------------------------===##
24+
25+
import numpy as np
26+
from inspect import getmembers, isfunction, isclass, isbuiltin
27+
from numbers import Number
28+
from types import FunctionType as ftype, BuiltinFunctionType as bftype
29+
import sys
30+
import inspect
31+
import dpctl
32+
from dpctl.memory import MemoryUSMShared
33+
34+
debug = False
35+
36+
37+
def dprint(*args):
38+
if debug:
39+
print(*args)
40+
sys.stdout.flush()
41+
42+
43+
functions_list = [o[0] for o in getmembers(np) if isfunction(o[1]) or isbuiltin(o[1])]
44+
class_list = [o for o in getmembers(np) if isclass(o[1])]
45+
46+
array_interface_property = "__sycl_usm_array_interface__"
47+
48+
49+
def has_array_interface(x):
50+
return hasattr(x, array_interface_property)
51+
52+
53+
class ndarray(np.ndarray):
54+
"""
55+
numpy.ndarray subclass whose underlying memory buffer is allocated
56+
with a foreign allocator.
57+
"""
58+
59+
def __new__(
60+
subtype, shape, dtype=float, buffer=None, offset=0, strides=None, order=None
61+
):
62+
# Create a new array.
63+
if buffer is None:
64+
dprint("dparray::ndarray __new__ buffer None")
65+
nelems = np.prod(shape)
66+
dt = np.dtype(dtype)
67+
isz = dt.itemsize
68+
nbytes = int(isz * max(1, nelems))
69+
buf = MemoryUSMShared(nbytes)
70+
new_obj = np.ndarray.__new__(
71+
subtype,
72+
shape,
73+
dtype=dt,
74+
buffer=buf,
75+
offset=0,
76+
strides=strides,
77+
order=order,
78+
)
79+
if hasattr(new_obj, array_interface_property):
80+
dprint("buffer None new_obj already has sycl_usm")
81+
else:
82+
dprint("buffer None new_obj will add sycl_usm")
83+
setattr(new_obj, array_interface_property, {})
84+
return new_obj
85+
# zero copy if buffer is a usm backed array-like thing
86+
elif hasattr(buffer, array_interface_property):
87+
dprint("dparray::ndarray __new__ buffer", array_interface_property)
88+
# also check for array interface
89+
new_obj = np.ndarray.__new__(
90+
subtype,
91+
shape,
92+
dtype=dtype,
93+
buffer=buffer,
94+
offset=offset,
95+
strides=strides,
96+
order=order,
97+
)
98+
if hasattr(new_obj, array_interface_property):
99+
dprint("buffer None new_obj already has sycl_usm")
100+
else:
101+
dprint("buffer None new_obj will add sycl_usm")
102+
setattr(new_obj, array_interface_property, {})
103+
return new_obj
104+
else:
105+
dprint("dparray::ndarray __new__ buffer not None and not sycl_usm")
106+
nelems = np.prod(shape)
107+
# must copy
108+
ar = np.ndarray(
109+
shape,
110+
dtype=dtype,
111+
buffer=buffer,
112+
offset=offset,
113+
strides=strides,
114+
order=order,
115+
)
116+
nbytes = int(ar.nbytes)
117+
buf = MemoryUSMShared(nbytes)
118+
new_obj = np.ndarray.__new__(
119+
subtype,
120+
shape,
121+
dtype=dtype,
122+
buffer=buf,
123+
offset=0,
124+
strides=strides,
125+
order=order,
126+
)
127+
np.copyto(new_obj, ar, casting="no")
128+
if hasattr(new_obj, array_interface_property):
129+
dprint("buffer None new_obj already has sycl_usm")
130+
else:
131+
dprint("buffer None new_obj will add sycl_usm")
132+
setattr(new_obj, array_interface_property, {})
133+
return new_obj
134+
135+
def __array_finalize__(self, obj):
136+
dprint("__array_finalize__:", obj, hex(id(obj)), type(obj))
137+
# When called from the explicit constructor, obj is None
138+
if obj is None:
139+
return
140+
# When called in new-from-template, `obj` is another instance of our own
141+
# subclass, that we might use to update the new `self` instance.
142+
# However, when called from view casting, `obj` can be an instance of any
143+
# subclass of ndarray, including our own.
144+
if hasattr(obj, array_interface_property):
145+
return
146+
if isinstance(obj, np.ndarray):
147+
ob = self
148+
while isinstance(ob, np.ndarray):
149+
if hasattr(ob, array_interface_property):
150+
return
151+
ob = ob.base
152+
153+
# Just raise an exception since __array_ufunc__ makes all reasonable cases not
154+
# need the code below.
155+
raise ValueError(
156+
"Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
157+
)
158+
159+
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
160+
# This way it will use the custom dparray allocator.
161+
__numba_no_subtype_ndarray__ = True
162+
163+
# Convert to a NumPy ndarray.
164+
def as_ndarray(self):
165+
return np.copy(self)
166+
167+
def __array__(self):
168+
return self
169+
170+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
171+
if method == "__call__":
172+
N = None
173+
scalars = []
174+
typing = []
175+
for inp in inputs:
176+
if isinstance(inp, Number):
177+
scalars.append(inp)
178+
typing.append(inp)
179+
elif isinstance(inp, (self.__class__, np.ndarray)):
180+
if isinstance(inp, self.__class__):
181+
scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
182+
typing.append(np.ndarray(inp.shape, inp.dtype))
183+
else:
184+
scalars.append(inp)
185+
typing.append(inp)
186+
if N is not None:
187+
if N != inp.shape:
188+
raise TypeError("inconsistent sizes")
189+
else:
190+
N = inp.shape
191+
else:
192+
return NotImplemented
193+
# Have to avoid recursive calls to array_ufunc here.
194+
# If no out kwarg then we create a dparray out so that we get
195+
# USM memory. However, if kwarg has dparray-typed out then
196+
# array_ufunc is called recursively so we cast out as regular
197+
# NumPy ndarray (having a USM data pointer).
198+
if kwargs.get("out", None) is None:
199+
# maybe copy?
200+
# deal with multiple returned arrays, so kwargs['out'] can be tuple
201+
res_type = np.result_type(*typing)
202+
out = empty(inputs[0].shape, dtype=res_type)
203+
out_as_np = np.ndarray(out.shape, out.dtype, out)
204+
kwargs["out"] = out_as_np
205+
else:
206+
# If they manually gave dparray as out kwarg then we have to also
207+
# cast as regular NumPy ndarray to avoid recursion.
208+
if isinstance(kwargs["out"], ndarray):
209+
out = kwargs["out"]
210+
kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
211+
else:
212+
out = kwargs["out"]
213+
ret = ufunc(*scalars, **kwargs)
214+
return out
215+
else:
216+
return NotImplemented
217+
218+
219+
def isdef(x):
220+
try:
221+
eval(x)
222+
return True
223+
except NameError:
224+
return False
225+
226+
227+
for c in class_list:
228+
cname = c[0]
229+
if isdef(cname):
230+
continue
231+
# For now we do the simple thing and copy the types from NumPy module into dparray module.
232+
new_func = "%s = np.%s" % (cname, cname)
233+
try:
234+
the_code = compile(new_func, "__init__", "exec")
235+
exec(the_code)
236+
except:
237+
print("Failed to exec type propagation", cname)
238+
pass
239+
240+
# Redefine all Numpy functions in this module and if they
241+
# return a Numpy array, transform that to a USM-backed array
242+
# instead. This is a stop-gap. We should eventually find a
243+
# way to do the allocation correct to start with.
244+
for fname in functions_list:
245+
if isdef(fname):
246+
continue
247+
new_func = "def %s(*args, **kwargs):\n" % fname
248+
new_func += " ret = np.%s(*args, **kwargs)\n" % fname
249+
new_func += " if type(ret) == np.ndarray:\n"
250+
new_func += " ret = ndarray(ret.shape, ret.dtype, ret)\n"
251+
new_func += " return ret\n"
252+
the_code = compile(new_func, "__init__", "exec")
253+
exec(the_code)
254+
255+
256+
def from_ndarray(x):
257+
return copy(x)
258+
259+
260+
def as_ndarray(x):
261+
return np.copy(x)

dpctl/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# Top-level module of all dpctl Python unit test cases.
2323
# ===-----------------------------------------------------------------------===#
2424

25+
from .test_dparray import *
2526
from .test_dump_functions import *
2627
from .test_sycl_device import *
2728
from .test_sycl_kernel_submit import *

dpctl/tests/test_dparray.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
##===---------- test_dparray.py - dpctl -------*- Python -*----===##
2+
##
3+
## Data Parallel Control (dpCtl)
4+
##
5+
## Copyright 2020 Intel Corporation
6+
##
7+
## Licensed under the Apache License, Version 2.0 (the "License");
8+
## you may not use this file except in compliance with the License.
9+
## You may obtain a copy of the License at
10+
##
11+
## http://www.apache.org/licenses/LICENSE-2.0
12+
##
13+
## Unless required by applicable law or agreed to in writing, software
14+
## distributed under the License is distributed on an "AS IS" BASIS,
15+
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
## See the License for the specific language governing permissions and
17+
## limitations under the License.
18+
##
19+
##===----------------------------------------------------------------------===##
20+
###
21+
### \file
22+
### A basic unit test for dpctl.dparray.
23+
##===----------------------------------------------------------------------===##
24+
25+
import unittest
26+
from dpctl import dparray
27+
import numpy
28+
29+
30+
class Test_dparray(unittest.TestCase):
31+
def setUp(self):
32+
self.X = dparray.ndarray((256, 4), dtype="d")
33+
self.X.fill(1.0)
34+
35+
def test_dparray_type(self):
36+
self.assertIsInstance(self.X, dparray.ndarray)
37+
38+
def test_dparray_as_ndarray_self(self):
39+
Y = self.X.as_ndarray()
40+
self.assertEqual(type(Y), numpy.ndarray)
41+
42+
def test_dparray_as_ndarray(self):
43+
Y = dparray.as_ndarray(self.X)
44+
self.assertEqual(type(Y), numpy.ndarray)
45+
46+
def test_dparray_from_ndarray(self):
47+
Y = dparray.as_ndarray(self.X)
48+
dp1 = dparray.from_ndarray(Y)
49+
self.assertIsInstance(dp1, dparray.ndarray)
50+
51+
def test_multiplication_dparray(self):
52+
C = self.X * 5
53+
self.assertIsInstance(C, dparray.ndarray)
54+
55+
def test_dparray_mixing_dpctl_and_numpy(self):
56+
dp_numpy = numpy.ones((256, 4), dtype="d")
57+
res = dp_numpy * self.X
58+
self.assertIsInstance(res, dparray.ndarray)
59+
60+
def test_dparray_shape(self):
61+
res = self.X.shape
62+
self.assertEqual(res, (256, 4))
63+
64+
def test_dparray_T(self):
65+
res = self.X.T
66+
self.assertEqual(res.shape, (4, 256))
67+
68+
def test_numpy_ravel_with_dparray(self):
69+
res = numpy.ravel(self.X)
70+
self.assertEqual(res.shape, (1024,))
71+
72+
@unittest.expectedFailure
73+
def test_numpy_sum_with_dparray(self):
74+
res = numpy.sum(self.X)
75+
self.assertEqual(res, 1024.0)
76+
77+
78+
if __name__ == "__main__":
79+
unittest.main()

0 commit comments

Comments
 (0)