Skip to content

Commit 9694941

Browse files
Merge branch 'spokhode/dparray' of https://github.com/IntelPython/dpctl into develop
2 parents 3d2dbd6 + 7ebd06c commit 9694941

File tree

5 files changed

+307
-14
lines changed

5 files changed

+307
-14
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
include versioneer.py
2+
include dparray.py
23
recursive-include dpctl/include *.h *.hpp
34
include dpctl/*.pxd
45
include dpctl/*DPPL*Interface.*

dpctl/__init__.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@
2929

3030
from dpctl._sycl_core cimport *
3131
from dpctl._memory import *
32-
32+
from .dparray import *

dpctl/dparray.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import numpy as np
2+
from inspect import getmembers, isfunction, isclass, isbuiltin
3+
from numbers import Number
4+
from types import FunctionType as ftype, BuiltinFunctionType as bftype
5+
import sys
6+
import inspect
7+
import dpctl
8+
from dpctl.memory import MemoryUSMShared
9+
10+
debug = False
11+
12+
13+
def dprint(*args):
14+
if debug:
15+
print(*args)
16+
sys.stdout.flush()
17+
18+
19+
functions_list = [o[0] for o in getmembers(np) if isfunction(o[1]) or isbuiltin(o[1])]
20+
class_list = [o for o in getmembers(np) if isclass(o[1])]
21+
22+
array_interface_property = "__sycl_usm_array_interface__"
23+
24+
25+
def has_array_interface(x):
26+
return hasattr(x, array_interface_property)
27+
28+
29+
class ndarray(np.ndarray):
30+
"""
31+
numpy.ndarray subclass whose underlying memory buffer is allocated
32+
with a foreign allocator.
33+
"""
34+
35+
def __new__(
36+
subtype, shape, dtype=float, buffer=None, offset=0, strides=None, order=None
37+
):
38+
# Create a new array.
39+
if buffer is None:
40+
dprint("dparray::ndarray __new__ buffer None")
41+
nelems = np.prod(shape)
42+
dt = np.dtype(dtype)
43+
isz = dt.itemsize
44+
nbytes = int(isz * max(1, nelems))
45+
buf = MemoryUSMShared(nbytes)
46+
new_obj = np.ndarray.__new__(
47+
subtype,
48+
shape,
49+
dtype=dt,
50+
buffer=buf,
51+
offset=0,
52+
strides=strides,
53+
order=order,
54+
)
55+
if hasattr(new_obj, array_interface_property):
56+
dprint("buffer None new_obj already has sycl_usm")
57+
else:
58+
dprint("buffer None new_obj will add sycl_usm")
59+
setattr(new_obj, array_interface_property, {})
60+
return new_obj
61+
# zero copy if buffer is a usm backed array-like thing
62+
elif hasattr(buffer, array_interface_property):
63+
dprint("dparray::ndarray __new__ buffer", array_interface_property)
64+
# also check for array interface
65+
new_obj = np.ndarray.__new__(
66+
subtype,
67+
shape,
68+
dtype=dtype,
69+
buffer=buffer,
70+
offset=offset,
71+
strides=strides,
72+
order=order,
73+
)
74+
if hasattr(new_obj, array_interface_property):
75+
dprint("buffer None new_obj already has sycl_usm")
76+
else:
77+
dprint("buffer None new_obj will add sycl_usm")
78+
setattr(new_obj, array_interface_property, {})
79+
return new_obj
80+
else:
81+
dprint("dparray::ndarray __new__ buffer not None and not sycl_usm")
82+
nelems = np.prod(shape)
83+
# must copy
84+
ar = np.ndarray(
85+
shape,
86+
dtype=dtype,
87+
buffer=buffer,
88+
offset=offset,
89+
strides=strides,
90+
order=order,
91+
)
92+
nbytes = int(ar.nbytes)
93+
buf = MemoryUSMShared(nbytes)
94+
new_obj = np.ndarray.__new__(
95+
subtype,
96+
shape,
97+
dtype=dtype,
98+
buffer=buf,
99+
offset=0,
100+
strides=strides,
101+
order=order,
102+
)
103+
np.copyto(new_obj, ar, casting="no")
104+
if hasattr(new_obj, array_interface_property):
105+
dprint("buffer None new_obj already has sycl_usm")
106+
else:
107+
dprint("buffer None new_obj will add sycl_usm")
108+
setattr(new_obj, array_interface_property, {})
109+
return new_obj
110+
111+
def __array_finalize__(self, obj):
112+
dprint("__array_finalize__:", obj, hex(id(obj)), type(obj))
113+
# When called from the explicit constructor, obj is None
114+
if obj is None:
115+
return
116+
# When called in new-from-template, `obj` is another instance of our own
117+
# subclass, that we might use to update the new `self` instance.
118+
# However, when called from view casting, `obj` can be an instance of any
119+
# subclass of ndarray, including our own.
120+
if hasattr(obj, array_interface_property):
121+
return
122+
if isinstance(obj, np.ndarray):
123+
ob = self
124+
while isinstance(ob, np.ndarray):
125+
if hasattr(ob, array_interface_property):
126+
return
127+
ob = ob.base
128+
129+
# Just raise an exception since __array_ufunc__ makes all reasonable cases not
130+
# need the code below.
131+
raise ValueError(
132+
"Non-USM allocated ndarray can not viewed as a USM-allocated one without a copy"
133+
)
134+
135+
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
136+
# This way it will use the custom dparray allocator.
137+
__numba_no_subtype_ndarray__ = True
138+
139+
# Convert to a NumPy ndarray.
140+
def as_ndarray(self):
141+
return np.copy(self)
142+
143+
def __array__(self):
144+
return self
145+
146+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
147+
if method == "__call__":
148+
N = None
149+
scalars = []
150+
typing = []
151+
for inp in inputs:
152+
if isinstance(inp, Number):
153+
scalars.append(inp)
154+
typing.append(inp)
155+
elif isinstance(inp, (self.__class__, np.ndarray)):
156+
if isinstance(inp, self.__class__):
157+
scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
158+
typing.append(np.ndarray(inp.shape, inp.dtype))
159+
else:
160+
scalars.append(inp)
161+
typing.append(inp)
162+
if N is not None:
163+
if N != inp.shape:
164+
raise TypeError("inconsistent sizes")
165+
else:
166+
N = inp.shape
167+
else:
168+
return NotImplemented
169+
# Have to avoid recursive calls to array_ufunc here.
170+
# If no out kwarg then we create a dparray out so that we get
171+
# USM memory. However, if kwarg has dparray-typed out then
172+
# array_ufunc is called recursively so we cast out as regular
173+
# NumPy ndarray (having a USM data pointer).
174+
if kwargs.get("out", None) is None:
175+
# maybe copy?
176+
# deal with multiple returned arrays, so kwargs['out'] can be tuple
177+
res_type = np.result_type(*typing)
178+
out = empty(inputs[0].shape, dtype=res_type)
179+
out_as_np = np.ndarray(out.shape, out.dtype, out)
180+
kwargs["out"] = out_as_np
181+
else:
182+
# If they manually gave dparray as out kwarg then we have to also
183+
# cast as regular NumPy ndarray to avoid recursion.
184+
if isinstance(kwargs["out"], ndarray):
185+
out = kwargs["out"]
186+
kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
187+
else:
188+
out = kwargs["out"]
189+
ret = ufunc(*scalars, **kwargs)
190+
return out
191+
else:
192+
return NotImplemented
193+
194+
195+
def isdef(x):
196+
try:
197+
eval(x)
198+
return True
199+
except NameError:
200+
return False
201+
202+
203+
for c in class_list:
204+
cname = c[0]
205+
if isdef(cname):
206+
continue
207+
# For now we do the simple thing and copy the types from NumPy module into dparray module.
208+
new_func = "%s = np.%s" % (cname, cname)
209+
try:
210+
the_code = compile(new_func, "__init__", "exec")
211+
exec(the_code)
212+
except:
213+
print("Failed to exec type propagation", cname)
214+
pass
215+
216+
# Redefine all Numpy functions in this module and if they
217+
# return a Numpy array, transform that to a USM-backed array
218+
# instead. This is a stop-gap. We should eventually find a
219+
# way to do the allocation correct to start with.
220+
for fname in functions_list:
221+
if isdef(fname):
222+
continue
223+
new_func = "def %s(*args, **kwargs):\n" % fname
224+
new_func += " ret = np.%s(*args, **kwargs)\n" % fname
225+
new_func += " if type(ret) == np.ndarray:\n"
226+
new_func += " ret = ndarray(ret.shape, ret.dtype, ret)\n"
227+
new_func += " return ret\n"
228+
the_code = compile(new_func, "__init__", "exec")
229+
exec(the_code)
230+
231+
232+
def from_ndarray(x):
233+
return copy(x)
234+
235+
236+
def as_ndarray(x):
237+
return np.copy(x)

dpctl/tests/test_dparray.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import unittest
2+
from dpctl import dparray
3+
import numpy
4+
5+
6+
class TestOverloadList(unittest.TestCase):
7+
def setUp(self):
8+
self.X = dparray.ndarray((256, 4), dtype="d")
9+
self.X.fill(1.0)
10+
11+
def test_dparray_type(self):
12+
self.assertIsInstance(self.X, dparray.ndarray)
13+
14+
def test_dparray_as_ndarray_self(self):
15+
Y = self.X.as_ndarray()
16+
self.assertEqual(type(Y), numpy.ndarray)
17+
18+
def test_dparray_as_ndarray(self):
19+
Y = dparray.as_ndarray(self.X)
20+
self.assertEqual(type(Y), numpy.ndarray)
21+
22+
def test_dparray_from_ndarray(self):
23+
Y = dparray.as_ndarray(self.X)
24+
dp1 = dparray.from_ndarray(Y)
25+
self.assertIsInstance(dp1, dparray.ndarray)
26+
27+
def test_multiplication_dparray(self):
28+
C = self.X * 5
29+
self.assertIsInstance(C, dparray.ndarray)
30+
31+
def test_dparray_mixing_dpctl_and_numpy(self):
32+
dp_numpy = numpy.ones((256, 4), dtype="d")
33+
res = dp_numpy * self.X
34+
self.assertIsInstance(res, dparray.ndarray)
35+
36+
def test_dparray_shape(self):
37+
res = self.X.shape
38+
self.assertEqual(res, (256, 4))
39+
40+
def test_dparray_T(self):
41+
res = self.X.T
42+
self.assertEqual(res.shape, (4, 256))
43+
44+
def test_numpy_ravel_with_dparray(self):
45+
res = numpy.ravel(self.X)
46+
self.assertEqual(res.shape, (1024,))
47+
48+
@unittest.expectedFailure
49+
def test_numpy_sum_with_dparray(self):
50+
res = numpy.sum(self.X)
51+
self.assertEqual(res, 1024.0)
52+
53+
54+
if __name__ == "__main__":
55+
unittest.main()

setup.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
##===---------- setup.py - dpctl.ocldrv interface -----*- Python -*-----===##
22
##
3-
## Data Parallel Control Library (dpCtl)
3+
# Data Parallel Control Library (dpCtl)
44
##
5-
## Copyright 2020 Intel Corporation
5+
# Copyright 2020 Intel Corporation
66
##
7-
## Licensed under the Apache License, Version 2.0 (the "License");
8-
## you may not use this file except in compliance with the License.
9-
## You may obtain a copy of the License at
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
1010
##
11-
## http://www.apache.org/licenses/LICENSE-2.0
11+
# http://www.apache.org/licenses/LICENSE-2.0
1212
##
13-
## Unless required by applicable law or agreed to in writing, software
14-
## distributed under the License is distributed on an "AS IS" BASIS,
15-
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16-
## See the License for the specific language governing permissions and
17-
## limitations under the License.
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
1818
##
1919
##===----------------------------------------------------------------------===##
2020
###
21-
### \file
22-
### This file builds the dpctl and dpctl.ocldrv extension modules.
21+
# \file
22+
# This file builds the dpctl and dpctl.ocldrv extension modules.
2323
##===----------------------------------------------------------------------===##
2424
import os
2525
import os.path

0 commit comments

Comments
 (0)