Skip to content

Commit 6a14e3f

Browse files
Add tests to check validation of __sycl_usm_array_interface__ in memory objects
1 parent 20d418b commit 6a14e3f

File tree

1 file changed

+95
-1
lines changed

1 file changed

+95
-1
lines changed

dpctl/tests/test_sycl_usm.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,15 @@ def test_sycl_usm_array_interface(memory_ctor):
278278

279279

280280
class View:
281-
def __init__(self, buf, shape, strides, offset, syclobj=None):
281+
def __init__(
282+
self, buf, shape, strides, offset, syclobj=None, transf_fn=None
283+
):
282284
self.buffer_ = buf
283285
self.shape_ = shape
284286
self.strides_ = strides
285287
self.offset_ = offset
286288
self.syclobj_ = syclobj
289+
self.transf_fn_ = transf_fn
287290

288291
@property
289292
def __sycl_usm_array_interface__(self):
@@ -293,6 +296,8 @@ def __sycl_usm_array_interface__(self):
293296
sua_iface["strides"] = self.strides_
294297
if self.syclobj_:
295298
sua_iface["syclobj"] = self.syclobj_
299+
if self.transf_fn_:
300+
sua_iface = self.transf_fn_(sua_iface)
296301
return sua_iface
297302

298303

@@ -360,6 +365,95 @@ def test_suai_non_contig_2D(memory_ctor):
360365
assert np.array_equal(res, expected_res)
361366

362367

368+
def test_suai_invalid_suai():
369+
n_bytes = 2 * 3 * 5 * 128
370+
try:
371+
q = dpctl.SyclQueue()
372+
except dpctl.SyclQueueCreationError:
373+
pytest.skip("Could not create default queue")
374+
try:
375+
buf = MemoryUSMShared(n_bytes, queue=q)
376+
except Exception:
377+
pytest.skip("USM-shared allocation failed")
378+
# different syclobj values
379+
for syclobj in [
380+
q,
381+
q.sycl_context,
382+
q._get_capsule(),
383+
q.sycl_context._get_capsule(),
384+
]:
385+
v = View(buf, shape=(n_bytes,), strides=(1,), offset=0, syclobj=syclobj)
386+
MemoryUSMShared(v)
387+
with pytest.raises(ValueError):
388+
MemoryUSMDevice(v)
389+
with pytest.raises(ValueError):
390+
MemoryUSMHost(v)
391+
392+
# version validation
393+
def invalid_version(suai_iface):
394+
"Set version to invalid"
395+
suai_iface["version"] = 0
396+
return suai_iface
397+
398+
v = View(
399+
buf, shape=(n_bytes,), strides=(1,), offset=0, transf_fn=invalid_version
400+
)
401+
with pytest.raises(ValueError):
402+
MemoryUSMShared(v)
403+
404+
# data validation
405+
def invalid_data(suai_iface):
406+
"Set data to invalid"
407+
suai_iface["data"] = tuple()
408+
return suai_iface
409+
410+
v = View(
411+
buf, shape=(n_bytes,), strides=(1,), offset=0, transf_fn=invalid_data
412+
)
413+
with pytest.raises(ValueError):
414+
MemoryUSMShared(v)
415+
# set shape to a negative value
416+
v = View(buf, shape=(-n_bytes,), strides=(2,), offset=0)
417+
with pytest.raises(ValueError):
418+
MemoryUSMShared(v)
419+
v = View(buf, shape=(-n_bytes,), strides=None, offset=0)
420+
with pytest.raises(ValueError):
421+
MemoryUSMShared(v)
422+
# shape validation
423+
v = View(buf, shape=None, strides=(1,), offset=0)
424+
with pytest.raises(ValueError):
425+
MemoryUSMShared(v)
426+
427+
# typestr validation
428+
def invalid_typestr(suai_iface):
429+
suai_iface["typestr"] = "invalid"
430+
return suai_iface
431+
432+
v = View(
433+
buf, shape=(n_bytes,), strides=(1,), offset=0, transf_fn=invalid_typestr
434+
)
435+
with pytest.raises(ValueError):
436+
MemoryUSMShared(v)
437+
438+
def unsupported_typestr(suai_iface):
439+
suai_iface["typestr"] = "O"
440+
return suai_iface
441+
442+
v = View(
443+
buf,
444+
shape=(n_bytes,),
445+
strides=(1,),
446+
offset=0,
447+
transf_fn=unsupported_typestr,
448+
)
449+
with pytest.raises(ValueError):
450+
MemoryUSMShared(v)
451+
# set strides to invalid value
452+
v = View(buf, shape=(n_bytes,), strides=Ellipsis, offset=0)
453+
with pytest.raises(ValueError):
454+
MemoryUSMShared(v)
455+
456+
363457
def check_view(v):
364458
"""
365459
Memory object created from duck __sycl_usm_array_interface__ argument

0 commit comments

Comments
 (0)