|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
| 16 | +import operator |
| 17 | + |
16 | 18 | import numpy as np
|
| 19 | +from numpy.core.numeric import normalize_axis_index |
17 | 20 |
|
| 21 | +import dpctl |
18 | 22 | import dpctl.memory as dpm
|
19 | 23 | import dpctl.tensor as dpt
|
20 | 24 | import dpctl.tensor._tensor_impl as ti
|
| 25 | +import dpctl.utils |
21 | 26 | from dpctl.tensor._device import normalize_queue_device
|
22 | 27 |
|
23 | 28 | __doc__ = (
|
@@ -382,3 +387,227 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
|
382 | 387 | )
|
383 | 388 | _copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
|
384 | 389 | return R
|
| 390 | + |
| 391 | + |
def _extract_impl(ary, ary_mask, axis=0):
    """Gather the elements of ``ary`` selected by ``ary_mask``.

    The dimensions of ``ary_mask`` are lined up with the dimensions of
    ``ary`` beginning at position ``axis``. In the result those dimensions
    are replaced by a single dimension whose extent is the count returned
    by ``ti.mask_positions``.

    Args:
        ary: ``dpctl.tensor.usm_ndarray`` to read elements from.
        ary_mask: ``dpctl.tensor.usm_ndarray`` used as the mask.
        axis: dimension of ``ary`` at which the mask starts; it is passed
            through ``operator.index`` and ``normalize_axis_index``.

    Returns:
        A newly allocated ``usm_ndarray`` with the dtype and USM type
        of ``ary``.

    Raises:
        TypeError: if ``ary`` or ``ary_mask`` is not a ``usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for the two arrays.
        ValueError: if the mask, placed at ``axis``, does not fit within
            the dimensions of ``ary``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if not isinstance(ary_mask, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
        )
    # Only used to verify that both inputs agree on an execution queue.
    common_q = dpctl.utils.get_execution_queue(
        (ary.sycl_queue, ary_mask.sycl_queue)
    )
    if common_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "arrays have different associated queues. "
            "Use `Y.to_device(X.device)` to migrate."
        )
    src_ndim = ary.ndim
    mask_ndim = ary_mask.ndim
    first = normalize_axis_index(operator.index(axis), src_ndim)
    last = first + mask_ndim
    if not (0 <= first and last <= src_ndim):
        raise ValueError(
            "Parameter p is inconsistent with input array dimensions"
        )
    # Scratch array with one int64 slot per mask element, filled in by
    # ti.mask_positions.
    positions = dpt.empty(
        ary_mask.size, dtype=dpt.int64, device=ary_mask.device
    )
    # Kernels are submitted to the queue the scratch array was allocated on.
    work_q = positions.sycl_queue
    n_selected = ti.mask_positions(ary_mask, positions, sycl_queue=work_q)
    src_shape = ary.shape
    out = dpt.empty(
        src_shape[:first] + (n_selected,) + src_shape[last:],
        dtype=ary.dtype,
        usm_type=ary.usm_type,
        device=ary.device,
    )
    host_event, _ = ti._extract(
        src=ary,
        cumsum=positions,
        axis_start=first,
        axis_end=last,
        dst=out,
        sycl_queue=work_q,
    )
    # Block until the submitted work has finished before handing back `out`.
    host_event.wait()
    return out
| 436 | + |
| 437 | + |
def _nonzero_impl(ary):
    """Return per-dimension index arrays computed by ``ti._nonzero``.

    Args:
        ary: ``dpctl.tensor.usm_ndarray`` to inspect.

    Returns:
        A tuple of ``ary.ndim`` arrays, each a row of a freshly allocated
        ``(ary.ndim, count)`` int64 array, where ``count`` is the value
        returned by ``ti.mask_positions`` for ``ary``.

    Raises:
        TypeError: if ``ary`` is not a ``usm_ndarray``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    q = ary.sycl_queue
    ndim = ary.ndim
    # Scratch array with one int64 slot per element of `ary`.
    positions = dpt.empty(ary.size, dtype=dpt.int64, sycl_queue=q, order="C")
    n_found = ti.mask_positions(ary, positions, sycl_queue=q)
    coords = dpt.empty(
        (ndim, n_found),
        dtype=positions.dtype,
        usm_type=ary.usm_type,
        sycl_queue=q,
        order="C",
    )
    host_event, _ = ti._nonzero(positions, coords, ary.shape, q)
    # Row k of `coords` becomes the k-th entry of the returned tuple.
    per_axis = tuple(coords[k, :] for k in range(ndim))
    host_event.wait()
    return per_axis
| 461 | + |
| 462 | + |
def _take_multi_index(ary, inds, p):
    """Take elements of ``ary`` using integer index arrays via ``ti._take``.

    Args:
        ary: ``dpctl.tensor.usm_ndarray`` to read from.
        inds: a single index ``usm_ndarray`` or a list/tuple of them.
            When more than one is given they are broadcast against each
            other with ``dpt.broadcast_arrays``.
        p: dimension of ``ary`` at which indexing starts; it is passed
            through ``operator.index`` and ``normalize_axis_index``.

    Returns:
        A newly allocated ``usm_ndarray`` of dtype ``ary.dtype`` with shape
        ``ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds):]``.

    Raises:
        TypeError: if ``ary`` or any index array is not a ``usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for ``ary`` and the index arrays.
        IndexError: if any index array has a dtype kind other than
            ``"u"`` or ``"i"``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)
    queues_ = [ary.sycl_queue]
    usm_types_ = [ary.usm_type]
    all_integers = True
    for ind in inds:
        # Reject non-array indices up front instead of failing with an
        # AttributeError on `.sycl_queue` below.
        if not isinstance(ind, dpt.usm_ndarray):
            raise TypeError(
                "all index arrays are expected to be of type "
                f"dpctl.tensor.usm_ndarray, got {type(ind)}"
            )
        queues_.append(ind.sycl_queue)
        usm_types_.append(ind.usm_type)
        all_integers = all_integers and ind.dtype.kind in "ui"
    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or perform execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )
    if not all_integers:
        raise IndexError(
            "arrays used as indices must be of integer (or boolean) type"
        )
    if len(inds) > 1:
        inds = dpt.broadcast_arrays(*inds)
    ary_ndim = ary.ndim
    p = normalize_axis_index(operator.index(p), ary_ndim)

    # The indexed dimensions are replaced by the (broadcast) index shape.
    res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
    res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
    res = dpt.empty(
        res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )

    hev, _ = ti._take(
        src=ary, ind=inds, dst=res, axis_start=p, mode=0, sycl_queue=exec_q
    )
    # Block until the submitted work has finished before returning `res`.
    hev.wait()

    return res
| 504 | + |
| 505 | + |
def _place_impl(ary, ary_mask, vals, axis=0):
    """Update ``ary`` in place from ``vals`` under ``ary_mask``.

    The dimensions of ``ary_mask`` are lined up with the dimensions of
    ``ary`` beginning at position ``axis``. ``vals`` is cast to
    ``ary.dtype`` when the dtypes differ and broadcast to the shape
    ``ary.shape[:axis] + (count,) + ary.shape[axis + ary_mask.ndim:]``,
    where ``count`` is the value returned by ``ti.mask_positions``.
    The write itself is performed by ``ti._place`` with ``ary`` as the
    destination.

    Args:
        ary: ``dpctl.tensor.usm_ndarray`` to be modified.
        ary_mask: ``dpctl.tensor.usm_ndarray`` used as the mask.
        vals: ``dpctl.tensor.usm_ndarray`` supplying the values.
        axis: dimension of ``ary`` at which the mask starts; it is passed
            through ``operator.index`` and ``normalize_axis_index``.

    Returns:
        None.

    Raises:
        TypeError: if ``ary``, ``ary_mask`` or ``vals`` is not a
            ``usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for the three arrays.
        ValueError: if the mask, placed at ``axis``, does not fit within
            the dimensions of ``ary``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if not isinstance(ary_mask, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
        )
    if not isinstance(vals, dpt.usm_ndarray):
        # Report the type of the offending argument (`vals`), not the mask.
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(vals)}"
        )
    exec_q = dpctl.utils.get_execution_queue(
        (ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue)
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "arrays have different associated queues. "
            "Use `Y.to_device(X.device)` to migrate."
        )
    ary_nd = ary.ndim
    pp = normalize_axis_index(operator.index(axis), ary_nd)
    mask_nd = ary_mask.ndim
    if pp < 0 or pp + mask_nd > ary_nd:
        raise ValueError(
            "Parameter p is inconsistent with input array dimensions"
        )
    mask_nelems = ary_mask.size
    # Scratch array with one int64 slot per mask element, filled in by
    # ti.mask_positions.
    cumsum = dpt.empty(mask_nelems, dtype=dpt.int64, device=ary_mask.device)
    exec_q = cumsum.sycl_queue
    mask_count = ti.mask_positions(ary_mask, cumsum, sycl_queue=exec_q)
    expected_vals_shape = (
        ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
    )
    if vals.dtype == ary.dtype:
        rhs = vals
    else:
        rhs = dpt.astype(vals, ary.dtype)
    rhs = dpt.broadcast_to(rhs, expected_vals_shape)
    hev, _ = ti._place(
        dst=ary,
        cumsum=cumsum,
        axis_start=pp,
        axis_end=pp + mask_nd,
        rhs=rhs,
        sycl_queue=exec_q,
    )
    # Block until the submitted work has finished before returning.
    hev.wait()
| 558 | + |
| 559 | + |
def _put_multi_index(ary, inds, p, vals):
    """Update ``ary`` in place from ``vals`` using index arrays.

    The write itself is performed by ``ti._put`` with ``ary`` as the
    destination.

    Args:
        ary: ``dpctl.tensor.usm_ndarray`` to be modified.
        inds: a single index ``usm_ndarray`` or a list/tuple of them.
            When more than one is given they are broadcast against each
            other with ``dpt.broadcast_arrays``.
        p: dimension of ``ary`` at which indexing starts; it is passed
            through ``operator.index`` and ``normalize_axis_index``.
        vals: values to write. Anything that is not a ``usm_ndarray`` is
            converted with ``dpt.asarray`` using ``ary.dtype``. The values
            are then broadcast to the shape
            ``ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds):]``.

    Returns:
        None.

    Raises:
        TypeError: if ``ary`` or any index array is not a ``usm_ndarray``.
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for the arrays involved.
        IndexError: if any index array has a dtype kind other than
            ``"u"`` or ``"i"``.
    """
    # Validate `ary` up front instead of failing with an AttributeError
    # on `.sycl_queue` below.
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if isinstance(vals, dpt.usm_ndarray):
        queues_ = [ary.sycl_queue, vals.sycl_queue]
        usm_types_ = [ary.usm_type, vals.usm_type]
    else:
        queues_ = [ary.sycl_queue]
        usm_types_ = [ary.usm_type]
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)
    all_integers = True
    for ind in inds:
        if not isinstance(ind, dpt.usm_ndarray):
            raise TypeError(
                "all index arrays are expected to be of type "
                f"dpctl.tensor.usm_ndarray, got {type(ind)}"
            )
        queues_.append(ind.sycl_queue)
        usm_types_.append(ind.usm_type)
        all_integers = all_integers and ind.dtype.kind in "ui"
    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )
    if not all_integers:
        raise IndexError(
            "arrays used as indices must be of integer (or boolean) type"
        )
    if len(inds) > 1:
        inds = dpt.broadcast_arrays(*inds)
    ary_ndim = ary.ndim

    p = normalize_axis_index(operator.index(p), ary_ndim)
    # The indexed dimensions are replaced by the (broadcast) index shape.
    vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]

    vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
    if not isinstance(vals, dpt.usm_ndarray):
        vals = dpt.asarray(
            vals, ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
        )

    vals = dpt.broadcast_to(vals, vals_shape)

    hev, _ = ti._put(
        dst=ary, ind=inds, val=vals, axis_start=p, mode=0, sycl_queue=exec_q
    )
    # Block until the submitted work has finished before returning.
    hev.wait()
0 commit comments