From ed6c3984dca887f07c2c52be4eef791853711c46 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Tue, 16 Jul 2024 19:06:51 +0200 Subject: [PATCH 1/2] Align with dpt.where() which support scalar as input --- dpnp/dpnp_iface_searching.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index b031ecab3559..5e96c3bcbfeb 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -395,13 +395,6 @@ def where(condition, x=None, y=None, /, *, order="K", out=None): usm_y = dpnp.get_usm_ndarray_or_scalar(y) usm_condition = dpnp.get_usm_ndarray(condition) - usm_type, queue = get_usm_allocations([condition, x, y]) - if dpnp.isscalar(usm_x): - usm_x = dpt.asarray(usm_x, usm_type=usm_type, sycl_queue=queue) - - if dpnp.isscalar(usm_y): - usm_y = dpt.asarray(usm_y, usm_type=usm_type, sycl_queue=queue) - usm_out = None if out is None else dpnp.get_usm_ndarray(out) result = dpnp_array._create_from_usm_ndarray( dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out) From 9f001f44cfd810a3cd4eb46197fa8ba52c4b453f Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Tue, 16 Jul 2024 19:12:43 +0200 Subject: [PATCH 2/2] Removed unused import of get_usm_allocations --- dpnp/dpnp_iface_searching.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 5e96c3bcbfeb..88173b79ed72 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -37,17 +37,12 @@ """ -# pylint: disable=no-name-in-module - import dpctl.tensor as dpt import dpctl.tensor._tensor_impl as dti import dpnp from .dpnp_array import dpnp_array -from .dpnp_utils import ( - get_usm_allocations, -) from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call __all__ = ["argmax", "argmin", "searchsorted", "where"]