Skip to content

Commit b97efdf

Browse files
Merge pull request #584 from IntelPython/fix/usm-ndarray-suai
Fix/usm ndarray suai
2 parents 5f155bc + 91e4243 commit b97efdf

File tree

2 files changed

+71
-4
lines changed

2 files changed

+71
-4
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ cdef class usm_ndarray:
308308
mem_ptr = <char *>(<size_t> ary_iface['data'][0])
309309
ary_ptr = <char *>(<size_t> self.data_)
310310
ro_flag = False if (self.flags_ & USM_ARRAY_WRITEABLE) else True
311-
ary_iface['data'] = (<size_t> ary_ptr, ro_flag)
311+
ary_iface['data'] = (<size_t> mem_ptr, ro_flag)
312312
ary_iface['shape'] = self.shape
313313
if (self.strides_):
314314
ary_iface['strides'] = _make_int_tuple(self.nd_, self.strides_)
@@ -335,7 +335,7 @@ cdef class usm_ndarray:
335335
"""
336336
Gives the number of indices needed to address elements of this array.
337337
"""
338-
return int(self.nd_)
338+
return self.nd_
339339

340340
@property
341341
def usm_data(self):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
import pytest
2222

2323
import dpctl
24-
25-
# import dpctl.memory as dpmem
24+
import dpctl.memory as dpm
2625
import dpctl.tensor as dpt
2726
from dpctl.tensor._usmarray import Device
2827

@@ -224,3 +223,71 @@ def test_slice_constructor_3d():
224223
assert np.array_equal(
225224
_to_numpy(Xusm[ind]), Xh[ind]
226225
), "Failed for {}".format(ind)
226+
227+
228+
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
229+
def test_slice_suai(usm_type):
230+
Xh = np.arange(0, 10, dtype="u1")
231+
default_device = dpctl.select_default_device()
232+
Xusm = _from_numpy(Xh, device=default_device, usm_type=usm_type)
233+
for ind in [slice(2, 3, None), slice(5, 7, None), slice(3, 9, None)]:
234+
assert np.array_equal(
235+
dpm.as_usm_memory(Xusm[ind]).copy_to_host(), Xh[ind]
236+
), "Failed for {}".format(ind)
237+
238+
239+
def test_slicing_basic():
240+
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
241+
Xusm[None]
242+
Xusm[...]
243+
Xusm[8]
244+
Xusm[-3]
245+
with pytest.raises(IndexError):
246+
Xusm[..., ...]
247+
with pytest.raises(IndexError):
248+
Xusm[1, 1, :, 1]
249+
Xusm[:, -4]
250+
with pytest.raises(IndexError):
251+
Xusm[:, -128]
252+
with pytest.raises(TypeError):
253+
Xusm[{1, 2, 3, 4, 5, 6, 7}]
254+
255+
256+
def test_ctor_invalid_shape():
257+
with pytest.raises(TypeError):
258+
dpt.usm_ndarray(dict())
259+
260+
261+
def test_ctor_invalid_order():
262+
with pytest.raises(ValueError):
263+
dpt.usm_ndarray((5, 5, 3), order="Z")
264+
265+
266+
def test_ctor_buffer_kwarg():
267+
dpt.usm_ndarray(10, buffer=b"device")
268+
with pytest.raises(ValueError):
269+
dpt.usm_ndarray(10, buffer="invalid_param")
270+
Xusm = dpt.usm_ndarray((10, 5), dtype="c16")
271+
X2 = dpt.usm_ndarray(Xusm.shape, buffer=Xusm, dtype=Xusm.dtype)
272+
assert np.array_equal(
273+
Xusm.usm_data.copy_to_host(), X2.usm_data.copy_to_host()
274+
)
275+
with pytest.raises(ValueError):
276+
dpt.usm_ndarray(10, buffer=dict())
277+
278+
279+
def test_usm_ndarray_props():
280+
Xusm = dpt.usm_ndarray((10, 5), dtype="c16", order="F")
281+
Xusm.ndim
282+
repr(Xusm)
283+
Xusm.flags
284+
Xusm.__sycl_usm_array_interface__
285+
Xusm.device
286+
Xusm.strides
287+
Xusm.real
288+
Xusm.imag
289+
try:
290+
dpctl.SyclQueue("cpu")
291+
except dpctl.SyclQueueCreationError:
292+
pytest.skip("Sycl device CPU was not detected")
293+
Xusm.to_device("cpu")

0 commit comments

Comments
 (0)