@@ -278,12 +278,15 @@ def test_sycl_usm_array_interface(memory_ctor):
278
278
279
279
280
280
class View :
281
- def __init__ (self , buf , shape , strides , offset , syclobj = None ):
281
+ def __init__ (
282
+ self , buf , shape , strides , offset , syclobj = None , transf_fn = None
283
+ ):
282
284
self .buffer_ = buf
283
285
self .shape_ = shape
284
286
self .strides_ = strides
285
287
self .offset_ = offset
286
288
self .syclobj_ = syclobj
289
+ self .transf_fn_ = transf_fn
287
290
288
291
@property
289
292
def __sycl_usm_array_interface__ (self ):
@@ -293,6 +296,8 @@ def __sycl_usm_array_interface__(self):
293
296
sua_iface ["strides" ] = self .strides_
294
297
if self .syclobj_ :
295
298
sua_iface ["syclobj" ] = self .syclobj_
299
+ if self .transf_fn_ :
300
+ sua_iface = self .transf_fn_ (sua_iface )
296
301
return sua_iface
297
302
298
303
@@ -360,6 +365,95 @@ def test_suai_non_contig_2D(memory_ctor):
360
365
assert np .array_equal (res , expected_res )
361
366
362
367
368
+ def test_suai_invalid_suai ():
369
+ n_bytes = 2 * 3 * 5 * 128
370
+ try :
371
+ q = dpctl .SyclQueue ()
372
+ except dpctl .SyclQueueCreationError :
373
+ pytest .skip ("Could not create default queue" )
374
+ try :
375
+ buf = MemoryUSMShared (n_bytes , queue = q )
376
+ except Exception :
377
+ pytest .skip ("USM-shared allocation failed" )
378
+ # different syclobj values
379
+ for syclobj in [
380
+ q ,
381
+ q .sycl_context ,
382
+ q ._get_capsule (),
383
+ q .sycl_context ._get_capsule (),
384
+ ]:
385
+ v = View (buf , shape = (n_bytes ,), strides = (1 ,), offset = 0 , syclobj = syclobj )
386
+ MemoryUSMShared (v )
387
+ with pytest .raises (ValueError ):
388
+ MemoryUSMDevice (v )
389
+ with pytest .raises (ValueError ):
390
+ MemoryUSMHost (v )
391
+
392
+ # version validation
393
+ def invalid_version (suai_iface ):
394
+ "Set version to invalid"
395
+ suai_iface ["version" ] = 0
396
+ return suai_iface
397
+
398
+ v = View (
399
+ buf , shape = (n_bytes ,), strides = (1 ,), offset = 0 , transf_fn = invalid_version
400
+ )
401
+ with pytest .raises (ValueError ):
402
+ MemoryUSMShared (v )
403
+
404
+ # data validation
405
+ def invalid_data (suai_iface ):
406
+ "Set data to invalid"
407
+ suai_iface ["data" ] = tuple ()
408
+ return suai_iface
409
+
410
+ v = View (
411
+ buf , shape = (n_bytes ,), strides = (1 ,), offset = 0 , transf_fn = invalid_data
412
+ )
413
+ with pytest .raises (ValueError ):
414
+ MemoryUSMShared (v )
415
+ # set shape to a negative value
416
+ v = View (buf , shape = (- n_bytes ,), strides = (2 ,), offset = 0 )
417
+ with pytest .raises (ValueError ):
418
+ MemoryUSMShared (v )
419
+ v = View (buf , shape = (- n_bytes ,), strides = None , offset = 0 )
420
+ with pytest .raises (ValueError ):
421
+ MemoryUSMShared (v )
422
+ # shape validation
423
+ v = View (buf , shape = None , strides = (1 ,), offset = 0 )
424
+ with pytest .raises (ValueError ):
425
+ MemoryUSMShared (v )
426
+
427
+ # typestr validation
428
+ def invalid_typestr (suai_iface ):
429
+ suai_iface ["typestr" ] = "invalid"
430
+ return suai_iface
431
+
432
+ v = View (
433
+ buf , shape = (n_bytes ,), strides = (1 ,), offset = 0 , transf_fn = invalid_typestr
434
+ )
435
+ with pytest .raises (ValueError ):
436
+ MemoryUSMShared (v )
437
+
438
+ def unsupported_typestr (suai_iface ):
439
+ suai_iface ["typestr" ] = "O"
440
+ return suai_iface
441
+
442
+ v = View (
443
+ buf ,
444
+ shape = (n_bytes ,),
445
+ strides = (1 ,),
446
+ offset = 0 ,
447
+ transf_fn = unsupported_typestr ,
448
+ )
449
+ with pytest .raises (ValueError ):
450
+ MemoryUSMShared (v )
451
+ # set strides to invalid value
452
+ v = View (buf , shape = (n_bytes ,), strides = Ellipsis , offset = 0 )
453
+ with pytest .raises (ValueError ):
454
+ MemoryUSMShared (v )
455
+
456
+
363
457
def check_view (v ):
364
458
"""
365
459
Memory object created from duck __sycl_usm_array_interface__ argument
0 commit comments