@@ -389,45 +389,75 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
389
389
return R
390
390
391
391
392
def _extract_impl(ary, ary_mask, axis=0):
    """Extract elements of ``ary`` selected by ``ary_mask``.

    The mask is applied to the ``ary_mask.ndim`` consecutive dimensions of
    ``ary`` beginning at dimension ``axis``. In the result those dimensions
    are collapsed into a single dimension whose length is the count
    returned by ``ti.mask_positions``.

    Args:
        ary (usm_ndarray): array to extract elements from.
        ary_mask (usm_ndarray): mask array; its dimensions must fit inside
            ``ary`` starting at ``axis``.
        axis (int): first dimension of ``ary`` the mask applies to.
            Anything supporting ``operator.index`` is accepted.

    Returns:
        usm_ndarray: newly allocated array of shape
        ``ary.shape[:axis] + (mask_count,) + ary.shape[axis + mask_nd:]``
        with the data type of ``ary``, allocated on the common queue of
        the inputs with the USM type coerced from both inputs.

    Raises:
        TypeError: if ``ary`` or ``ary_mask`` is not a ``usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if the inputs are associated
            with different queues.
        ValueError: if the mask does not fit into ``ary`` at ``axis``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if not isinstance(ary_mask, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
        )
    exec_q = dpctl.utils.get_execution_queue(
        (ary.sycl_queue, ary_mask.sycl_queue)
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "arrays have different associated queues. "
            "Use `Y.to_device(X.device)` to migrate."
        )
    # The output depends on both inputs, so its USM type is coerced from
    # both of them rather than copied from ``ary`` alone.
    dst_usm_type = dpctl.utils.get_coerced_usm_type(
        (ary.usm_type, ary_mask.usm_type)
    )
    ary_nd = ary.ndim
    pp = normalize_axis_index(operator.index(axis), ary_nd)
    mask_nd = ary_mask.ndim
    if pp < 0 or pp + mask_nd > ary_nd:
        raise ValueError(
            "Parameter axis is inconsistent with input array dimensions"
        )
    # Scratch buffer filled by ``ti.mask_positions``; allocated on the
    # validated execution queue so every step runs on the same queue.
    cumsum = dpt.empty(ary_mask.size, dtype=dpt.int64, sycl_queue=exec_q)
    mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
    dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd:]
    dst = dpt.empty(
        dst_shape,
        dtype=ary.dtype,
        usm_type=dst_usm_type,
        sycl_queue=exec_q,
    )
    hev, _ = ti._extract(
        src=ary,
        cumsum=cumsum,
        axis_start=pp,
        axis_end=pp + mask_nd,
        dst=dst,
        sycl_queue=exec_q,
    )
    # Block until the submitted work completes before handing out ``dst``.
    hev.wait()
    return dst
421
436
422
437
423
def _nonzero_impl(ary):
    """Build per-dimension index arrays for the selected elements of ``ary``.

    ``ti.mask_positions`` is run over ``ary`` to obtain the number of
    selected elements (presumably the nonzero ones -- confirm against the
    ``ti`` implementation), then ``ti._nonzero`` fills a single
    ``(ary.ndim, count)`` int64 array with their indices.

    Args:
        ary (usm_ndarray): input array.

    Returns:
        tuple of usm_ndarray: ``ary.ndim`` one-dimensional views, one row
        of the index array per dimension of ``ary``, sharing the queue and
        USM type of ``ary``.

    Raises:
        TypeError: if ``ary`` is not a ``usm_ndarray``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    queue = ary.sycl_queue
    ndim = ary.ndim
    # Scratch buffer filled by ``ti.mask_positions``.
    positions = dpt.empty(
        ary.size, dtype=dpt.int64, sycl_queue=queue, order="C"
    )
    n_selected = ti.mask_positions(ary, positions, sycl_queue=queue)
    index_rows = dpt.empty(
        (ndim, n_selected),
        dtype=positions.dtype,
        usm_type=ary.usm_type,
        sycl_queue=queue,
        order="C",
    )
    host_event, _ = ti._nonzero(positions, index_rows, ary.shape, queue)
    # Block until the index array is populated before exposing views of it.
    host_event.wait()
    return tuple(index_rows[dim, :] for dim in range(ndim))
431
461
432
462
433
463
def _take_multi_index (ary , inds , p ):
@@ -473,34 +503,57 @@ def _take_multi_index(ary, inds, p):
473
503
return res
474
504
475
505
476
def _place_impl(ary, ary_mask, vals, axis=0):
    """Place ``vals`` into the elements of ``ary`` selected by ``ary_mask``.

    The mask is applied to the ``ary_mask.ndim`` consecutive dimensions of
    ``ary`` beginning at dimension ``axis``. ``vals`` is cast to the data
    type of ``ary`` when needed and broadcast to the shape
    ``ary.shape[:axis] + (mask_count,) + ary.shape[axis + mask_nd:]``,
    where ``mask_count`` is the count returned by ``ti.mask_positions``.
    ``ary`` is modified in place.

    Args:
        ary (usm_ndarray): destination array, modified in place.
        ary_mask (usm_ndarray): mask array; its dimensions must fit inside
            ``ary`` starting at ``axis``.
        vals (usm_ndarray): values to write; must be broadcastable to the
            shape described above.
        axis (int): first dimension of ``ary`` the mask applies to.
            Anything supporting ``operator.index`` is accepted.

    Returns:
        None

    Raises:
        TypeError: if ``ary``, ``ary_mask`` or ``vals`` is not a
            ``usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if the inputs are associated
            with different queues.
        ValueError: if the mask does not fit into ``ary`` at ``axis``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if not isinstance(ary_mask, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
        )
    if not isinstance(vals, dpt.usm_ndarray):
        # Report the type of the offending argument (``vals``).
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(vals)}"
        )
    exec_q = dpctl.utils.get_execution_queue(
        (ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue)
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "arrays have different associated queues. "
            "Use `Y.to_device(X.device)` to migrate."
        )
    ary_nd = ary.ndim
    pp = normalize_axis_index(operator.index(axis), ary_nd)
    mask_nd = ary_mask.ndim
    if pp < 0 or pp + mask_nd > ary_nd:
        raise ValueError(
            "Parameter axis is inconsistent with input array dimensions"
        )
    # Scratch buffer filled by ``ti.mask_positions``; allocated on the
    # validated execution queue so every step runs on the same queue.
    cumsum = dpt.empty(ary_mask.size, dtype=dpt.int64, sycl_queue=exec_q)
    mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
    expected_vals_shape = (
        ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd:]
    )
    # Only cast when the data types differ, so matching input is used as is.
    if vals.dtype == ary.dtype:
        rhs = vals
    else:
        rhs = dpt.astype(vals, ary.dtype)
    rhs = dpt.broadcast_to(rhs, expected_vals_shape)
    hev, _ = ti._place(
        dst=ary,
        cumsum=cumsum,
        axis_start=pp,
        axis_end=pp + mask_nd,
        rhs=rhs,
        sycl_queue=exec_q,
    )
    # Block until the submitted work completes so ``ary`` is fully updated
    # when this function returns.
    hev.wait()
    return
505
558
506
559
0 commit comments