@@ -373,7 +373,11 @@ def test_datapi_device():
373
373
374
374
375
375
def _pyx_capi_fnptr_to_callable (
376
- X , pyx_capi_name , caps_name , fn_restype = ctypes .c_void_p
376
+ X ,
377
+ pyx_capi_name ,
378
+ caps_name ,
379
+ fn_restype = ctypes .c_void_p ,
380
+ fn_argtypes = (ctypes .py_object ,),
377
381
):
378
382
import sys
379
383
@@ -388,7 +392,7 @@ def _pyx_capi_fnptr_to_callable(
388
392
cap_ptr_fn .restype = ctypes .c_void_p
389
393
cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
390
394
fn_ptr = cap_ptr_fn (cap , caps_name )
391
- callable_maker_ptr = ctypes .PYFUNCTYPE (fn_restype , ctypes . py_object )
395
+ callable_maker_ptr = ctypes .PYFUNCTYPE (fn_restype , * fn_argtypes )
392
396
return callable_maker_ptr (fn_ptr )
393
397
394
398
@@ -399,6 +403,7 @@ def test_pyx_capi_get_data():
399
403
"UsmNDArray_GetData" ,
400
404
b"char *(struct PyUSMArrayObject *)" ,
401
405
fn_restype = ctypes .c_void_p ,
406
+ fn_argtypes = (ctypes .py_object ,),
402
407
)
403
408
r1 = get_data_fn (X )
404
409
sua_iface = X .__sycl_usm_array_interface__
@@ -412,6 +417,7 @@ def test_pyx_capi_get_shape():
412
417
"UsmNDArray_GetShape" ,
413
418
b"Py_ssize_t *(struct PyUSMArrayObject *)" ,
414
419
fn_restype = ctypes .c_void_p ,
420
+ fn_argtypes = (ctypes .py_object ,),
415
421
)
416
422
c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
417
423
shape0 = ctypes .cast (get_shape_fn (X ), c_longlong_p ).contents .value
@@ -425,6 +431,7 @@ def test_pyx_capi_get_strides():
425
431
"UsmNDArray_GetStrides" ,
426
432
b"Py_ssize_t *(struct PyUSMArrayObject *)" ,
427
433
fn_restype = ctypes .c_void_p ,
434
+ fn_argtypes = (ctypes .py_object ,),
428
435
)
429
436
c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
430
437
strides0_p = get_strides_fn (X )
@@ -441,6 +448,7 @@ def test_pyx_capi_get_ndim():
441
448
"UsmNDArray_GetNDim" ,
442
449
b"int (struct PyUSMArrayObject *)" ,
443
450
fn_restype = ctypes .c_int ,
451
+ fn_argtypes = (ctypes .py_object ,),
444
452
)
445
453
assert get_ndim_fn (X ) == X .ndim
446
454
@@ -452,6 +460,7 @@ def test_pyx_capi_get_typenum():
452
460
"UsmNDArray_GetTypenum" ,
453
461
b"int (struct PyUSMArrayObject *)" ,
454
462
fn_restype = ctypes .c_int ,
463
+ fn_argtypes = (ctypes .py_object ,),
455
464
)
456
465
typenum = get_typenum_fn (X )
457
466
assert type (typenum ) is int
@@ -465,6 +474,7 @@ def test_pyx_capi_get_elemsize():
465
474
"UsmNDArray_GetElementSize" ,
466
475
b"int (struct PyUSMArrayObject *)" ,
467
476
fn_restype = ctypes .c_int ,
477
+ fn_argtypes = (ctypes .py_object ,),
468
478
)
469
479
itemsize = get_elemsize_fn (X )
470
480
assert type (itemsize ) is int
@@ -478,6 +488,7 @@ def test_pyx_capi_get_flags():
478
488
"UsmNDArray_GetFlags" ,
479
489
b"int (struct PyUSMArrayObject *)" ,
480
490
fn_restype = ctypes .c_int ,
491
+ fn_argtypes = (ctypes .py_object ,),
481
492
)
482
493
flags = get_flags_fn (X )
483
494
assert type (flags ) is int and X .flags == flags
@@ -490,6 +501,7 @@ def test_pyx_capi_get_offset():
490
501
"UsmNDArray_GetOffset" ,
491
502
b"Py_ssize_t (struct PyUSMArrayObject *)" ,
492
503
fn_restype = ctypes .c_longlong ,
504
+ fn_argtypes = (ctypes .py_object ,),
493
505
)
494
506
offset = get_offset_fn (X )
495
507
assert type (offset ) is int
@@ -503,11 +515,123 @@ def test_pyx_capi_get_queue_ref():
503
515
"UsmNDArray_GetQueueRef" ,
504
516
b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)" ,
505
517
fn_restype = ctypes .c_void_p ,
518
+ fn_argtypes = (ctypes .py_object ,),
506
519
)
507
520
queue_ref = get_queue_ref_fn (X ) # address of a copy, should be unequal
508
521
assert queue_ref != X .sycl_queue .addressof_ref ()
509
522
510
523
524
+ def test_pyx_capi_make_from_memory ():
525
+ q = get_queue_or_skip ()
526
+ n0 , n1 = 4 , 6
527
+ c_tuple = (ctypes .c_ssize_t * 2 )(n0 , n1 )
528
+ mem = dpm .MemoryUSMShared (n0 * n1 * 4 , queue = q )
529
+ typenum = dpt .dtype ("single" ).num
530
+ any_usm_ndarray = dpt .empty (tuple (), dtype = "i4" , sycl_queue = q )
531
+ make_from_memory_fn = _pyx_capi_fnptr_to_callable (
532
+ any_usm_ndarray ,
533
+ "UsmNDArray_MakeFromMemory" ,
534
+ b"PyObject *(int, Py_ssize_t const *, int, "
535
+ b"struct Py_MemoryObject *, Py_ssize_t, char)" ,
536
+ fn_restype = ctypes .py_object ,
537
+ fn_argtypes = (
538
+ ctypes .c_int ,
539
+ ctypes .POINTER (ctypes .c_ssize_t ),
540
+ ctypes .c_int ,
541
+ ctypes .py_object ,
542
+ ctypes .c_ssize_t ,
543
+ ctypes .c_char ,
544
+ ),
545
+ )
546
+ r = make_from_memory_fn (
547
+ ctypes .c_int (2 ),
548
+ c_tuple ,
549
+ ctypes .c_int (typenum ),
550
+ mem ,
551
+ ctypes .c_ssize_t (0 ),
552
+ ctypes .c_char (b"C" ),
553
+ )
554
+ assert isinstance (r , dpt .usm_ndarray )
555
+ assert r .ndim == 2
556
+ assert r .shape == (n0 , n1 )
557
+ assert r ._pointer == mem ._pointer
558
+ assert r .usm_type == "shared"
559
+ assert r .sycl_queue == q
560
+ assert r .flags ["C" ]
561
+ r2 = make_from_memory_fn (
562
+ ctypes .c_int (2 ),
563
+ c_tuple ,
564
+ ctypes .c_int (typenum ),
565
+ mem ,
566
+ ctypes .c_ssize_t (0 ),
567
+ ctypes .c_char (b"F" ),
568
+ )
569
+ ptr = mem ._pointer
570
+ del mem
571
+ del r
572
+ assert isinstance (r2 , dpt .usm_ndarray )
573
+ assert r2 ._pointer == ptr
574
+ assert r2 .usm_type == "shared"
575
+ assert r2 .sycl_queue == q
576
+ assert r2 .flags ["F" ]
577
+
578
+
579
+ def test_pyx_capi_set_writable_flag ():
580
+ q = get_queue_or_skip ()
581
+ usm_ndarray = dpt .empty ((4 , 5 ), dtype = "i4" , sycl_queue = q )
582
+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
583
+ assert usm_ndarray .flags ["WRITABLE" ] is True
584
+ set_writable = _pyx_capi_fnptr_to_callable (
585
+ usm_ndarray ,
586
+ "UsmNDArray_SetWritableFlag" ,
587
+ b"void (struct PyUSMArrayObject *, int)" ,
588
+ fn_restype = None ,
589
+ fn_argtypes = (ctypes .py_object , ctypes .c_int ),
590
+ )
591
+ set_writable (usm_ndarray , ctypes .c_int (0 ))
592
+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
593
+ assert usm_ndarray .flags ["WRITABLE" ] is False
594
+ set_writable (usm_ndarray , ctypes .c_int (1 ))
595
+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
596
+ assert usm_ndarray .flags ["WRITABLE" ] is True
597
+
598
+
599
+ def test_pyx_capi_make_from_ptr ():
600
+ q = get_queue_or_skip ()
601
+ usm_ndarray = dpt .empty (tuple (), dtype = "i4" , sycl_queue = q )
602
+ make_from_ptr = _pyx_capi_fnptr_to_callable (
603
+ usm_ndarray ,
604
+ "UsmNDArray_MakeFromPtr" ,
605
+ b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
606
+ b"DPCTLSyclQueueRef, PyObject *)" ,
607
+ fn_restype = ctypes .py_object ,
608
+ fn_argtypes = (
609
+ ctypes .c_size_t ,
610
+ ctypes .c_int ,
611
+ ctypes .c_void_p ,
612
+ ctypes .c_void_p ,
613
+ ctypes .py_object ,
614
+ ),
615
+ )
616
+ nelems = 10
617
+ dt = dpt .int64
618
+ mem = dpm .MemoryUSMDevice (nelems * dt .itemsize , queue = q )
619
+ arr = make_from_ptr (
620
+ ctypes .c_size_t (nelems ),
621
+ dt .num ,
622
+ mem ._pointer ,
623
+ mem .sycl_queue .addressof_ref (),
624
+ mem ,
625
+ )
626
+ assert isinstance (arr , dpt .usm_ndarray )
627
+ assert arr .shape == (nelems ,)
628
+ assert arr .dtype == dt
629
+ assert arr .sycl_queue == q
630
+ assert arr ._pointer == mem ._pointer
631
+ del mem
632
+ assert isinstance (arr .__repr__ (), str )
633
+
634
+
511
635
def _pyx_capi_int (X , pyx_capi_name , caps_name = b"int" , val_restype = ctypes .c_int ):
512
636
import sys
513
637
0 commit comments