From 2d2bfc212d73b1b9fdf438a77adcfcc039db9c3b Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 8 May 2023 16:49:24 -0500 Subject: [PATCH 01/48] Initial check-in of elementwise functions: abs/cos/isnan/add --- dpctl/tensor/CMakeLists.txt | 8 + dpctl/tensor/__init__.py | 5 + dpctl/tensor/_elementwise_common.py | 522 +++++++++++++++ dpctl/tensor/_elementwise_funcs.py | 53 ++ dpctl/tensor/_type_utils.py | 155 +++++ .../kernels/elementwise_functions/abs.hpp | 322 +++++++++ .../kernels/elementwise_functions/add.hpp | 504 ++++++++++++++ .../kernels/elementwise_functions/common.hpp | 0 .../kernels/elementwise_functions/cos.hpp | 291 +++++++++ .../kernels/elementwise_functions/isnan.hpp | 296 +++++++++ .../libtensor/include/utils/type_dispatch.hpp | 83 +++ .../libtensor/include/utils/type_utils.hpp | 16 + .../source/elementwise_functions.cpp | 523 +++++++++++++++ .../source/elementwise_functions.hpp | 614 ++++++++++++++++++ dpctl/tensor/libtensor/source/tensor_py.cpp | 3 + dpctl/tests/test_tensor_elementwise.py | 377 +++++++++++ 16 files changed, 3772 insertions(+) create mode 100644 dpctl/tensor/_elementwise_common.py create mode 100644 dpctl/tensor/_elementwise_funcs.py create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp create mode 100644 dpctl/tensor/libtensor/source/elementwise_functions.cpp create mode 100644 dpctl/tensor/libtensor/source/elementwise_functions.hpp create mode 100644 dpctl/tests/test_tensor_elementwise.py diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 9a79830e47..40bb7cd7da 100644 --- a/dpctl/tensor/CMakeLists.txt +++ 
b/dpctl/tensor/CMakeLists.txt @@ -45,7 +45,15 @@ pybind11_add_module(${python_module_name} MODULE ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp ) +set(_clang_prefix "") +if (WIN32) + set(_clang_prefix "/clang:") +endif() +set_source_files_properties( + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp + PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-approx-func") target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) if(UNIX) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index ec9f27617b..3112e9c7b6 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -90,6 +90,7 @@ from dpctl.tensor._usmarray import usm_ndarray from ._constants import e, inf, nan, newaxis, pi +from ._elementwise_funcs import abs, add, cos, isnan __all__ = [ "Device", @@ -164,4 +165,8 @@ "pi", "nan", "inf", + "abs", + "add", + "cos", + "isnan", ] diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py new file mode 100644 index 0000000000..b0e147c436 --- /dev/null +++ b/dpctl/tensor/_elementwise_common.py @@ -0,0 +1,522 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers + +import numpy as np + +import dpctl +import dpctl.memory as dpm +import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as ti +from dpctl.tensor._manipulation_functions import _broadcast_shape_impl +from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer +from dpctl.utils import ExecutionPlacementError + +from ._type_utils import ( + _empty_like_orderK, + _empty_like_pair_orderK, + _find_buf_dtype, + _find_buf_dtype2, + _to_device_supported_dtype, +) + + +class UnaryElementwiseFunc: + """ + Class that implements unary element-wise functions. + """ + + def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs): + self.__name__ = "UnaryElementwiseFunc" + self.name_ = name + self.result_type_resolver_fn_ = result_type_resolver_fn + self.unary_fn_ = unary_dp_impl_fn + self.__doc__ = docs + + def __call__(self, x, order="K"): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + f"Expected :class:`dpctl.tensor.usm_ndarray`, got {type(x)}" + ) + if order not in ["C", "F", "K", "A"]: + order = "K" + buf_dt, res_dt = _find_buf_dtype( + x.dtype, self.result_type_resolver_fn_, x.sycl_device + ) + if res_dt is None: + raise RuntimeError + exec_q = x.sycl_queue + if buf_dt is None: + if order == "K": + r = _empty_like_orderK(x, res_dt) + else: + if order == "A": + order = "F" if x.flags.f_contiguous else "C" + r = dpt.empty_like(x, dtype=res_dt, order=order) + + ht, _ = self.unary_fn_(x, r, sycl_queue=exec_q) + ht.wait() + + return r + if order == "K": + buf = _empty_like_orderK(x, buf_dt) + else: + if order == "A": + order = "F" if x.flags.f_contiguous else "C" + buf = dpt.empty_like(x, dtype=buf_dt, order=order) + + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=x, dst=buf, sycl_queue=exec_q + ) + if order == "K": + r = _empty_like_orderK(buf, res_dt) + else: + if order == "A": + 
order = "F" if buf.flags.f_contiguous else "C" + r = dpt.empty_like(buf, dtype=res_dt, order=order) + + ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev]) + ht.wait() + + return r + + +def _get_queue_usm_type(o): + """Return SYCL device where object `o` allocated memory, or None.""" + if isinstance(o, dpt.usm_ndarray): + return o.sycl_queue, o.usm_type + elif hasattr(o, "__sycl_usm_array_interface__"): + try: + m = dpm.as_usm_memory(o) + return m.sycl_queue, m.get_usm_type() + except Exception: + return None, None + return None, None + + +class WeakBooleanType: + "Python type representing type of Python boolean objects" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakIntegralType: + "Python type representing type of Python integral objects" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +class WeakInexactType: + """Python type representing type of Python real- or + complex-valued floating point objects""" + + def __init__(self, o): + self.o_ = o + + def get(self): + return self.o_ + + +def _get_dtype(o, dev): + if isinstance(o, dpt.usm_ndarray): + return o.dtype + if _is_buffer(o): + host_dt = np.array(o).dtype + dev_dt = _to_device_supported_dtype(host_dt, dev) + return dev_dt + if hasattr(o, "dtype"): + dev_dt = _to_device_supported_dtype(o.dtype, dev) + return dev_dt + if isinstance(o, bool): + return WeakBooleanType(o) + if isinstance(o, int): + return WeakIntegralType(o) + if isinstance(o, (float, complex)): + return WeakInexactType(o) + return np.object_ + + +def _validate_dtype(dt) -> bool: + return isinstance( + dt, (WeakBooleanType, WeakInexactType, WeakIntegralType) + ) or ( + isinstance(dt, dpt.dtype) + and dt + in [ + dpt.bool, + dpt.int8, + dpt.uint8, + dpt.int16, + dpt.uint16, + dpt.int32, + dpt.uint32, + dpt.int64, + dpt.uint64, + dpt.float16, + dpt.float32, + dpt.float64, + dpt.complex64, + dpt.complex128, + ] + ) + + +def _weak_type_num_kind(o): + _map = 
{"?": 0, "i": 1, "f": 2} + if isinstance(o, WeakBooleanType): + return _map["?"] + if isinstance(o, WeakIntegralType): + return _map["i"] + if isinstance(o, WeakInexactType): + return _map["f"] + raise TypeError + + +def _strong_dtype_num_kind(o): + _map = {"?": 0, "i": 1, "u": 1, "f": 2, "c": 2} + if not isinstance(o, dpt.dtype): + raise TypeError + k = o.kind + if k in _map: + return _map[k] + raise ValueError + + +def _resolve_weak_types(o1_dtype, o2_dtype, dev): + "Resolves weak data type per NEP-0050" + if isinstance( + o1_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType) + ): + if isinstance( + o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType) + ): + raise ValueError + o1_kind_num = _weak_type_num_kind(o1_dtype) + o2_kind_num = _strong_dtype_num_kind(o2_dtype) + if o1_kind_num > o2_kind_num: + if isinstance(o1_dtype, WeakBooleanType): + return dpt.bool, o2_dtype + if isinstance(o1_dtype, WeakIntegralType): + return dpt.int64, o2_dtype + if isinstance(o1_dtype, WeakInexactType): + if isinstance(o1_dtype.get(), complex): + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + else: + return o2_dtype, o2_dtype + elif isinstance( + o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType) + ): + o1_kind_num = _strong_dtype_num_kind(o1_dtype) + o2_kind_num = _weak_type_num_kind(o2_dtype) + if o2_kind_num > o1_kind_num: + if isinstance(o2_dtype, WeakBooleanType): + return o1_dtype, dpt.bool + if isinstance(o2_dtype, WeakIntegralType): + return o1_dtype, dpt.int64 + if isinstance(o2_dtype, WeakInexactType): + if isinstance(o2_dtype.get(), complex): + return o1_dtype, _to_device_supported_dtype( + dpt.complex128, dev + ) + return ( + o1_dtype, + _to_device_supported_dtype(dpt.float64, dev), + ) + else: + return o1_dtype, o1_dtype + else: + return o1_dtype, o2_dtype + + +def _get_shape(o): + if isinstance(o, dpt.usm_ndarray): + return o.shape + if 
_is_buffer(o): + return memoryview(o).shape + if isinstance(o, numbers.Number): + return tuple() + return getattr(o, "shape", tuple()) + + +class BinaryElementwiseFunc: + """ + Class that implements binary element-wise functions. + """ + + def __init__(self, name, result_type_resolver_fn, binary_dp_impl_fn, docs): + self.__name__ = "BinaryElementwiseFunc" + self.name_ = name + self.result_type_resolver_fn_ = result_type_resolver_fn + self.binary_fn_ = binary_dp_impl_fn + self.__doc__ = docs + + def __str__(self): + return f"" + + def __repr__(self): + return f"" + + def __call__(self, o1, o2, order="K"): + q1, o1_usm_type = _get_queue_usm_type(o1) + q2, o2_usm_type = _get_queue_usm_type(o2) + if q1 is None and q2 is None: + raise ValueError( + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = o2_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = o1_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + o1_usm_type, + o2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + o1_shape = _get_shape(o1) + o2_shape = _get_shape(o2) + if not all( + isinstance(s, (tuple, list)) + for s in ( + o1_shape, + o2_shape, + ) + ): + raise TypeError( + "Shape of arguments can not be inferred. 
" + "Arguments are expected to be " + ) + try: + res_shape = _broadcast_shape_impl( + [ + o1_shape, + o2_shape, + ] + ) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{o1_shape} and {o2_shape}" + ) + sycl_dev = exec_q.sycl_device + o1_dtype = _get_dtype(o1, sycl_dev) + o2_dtype = _get_dtype(o2, sycl_dev) + if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)): + raise ValueError("Operands of unsupported types") + + o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev) + + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + o1_dtype, o2_dtype, self.result_type_resolver_fn_, sycl_dev + ) + + if res_dt is None: + raise TypeError( + "function 'add' does not support input types " + f"({o1_dtype}, {o2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + + if isinstance(o1, dpt.usm_ndarray): + src1 = o1 + else: + src1 = dpt.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q) + if isinstance(o2, dpt.usm_ndarray): + src2 = o2 + else: + src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q) + + if buf1_dt is None and buf2_dt is None: + if order == "K": + r = _empty_like_pair_orderK( + src1, src2, res_dt, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + src1, + src2, + ) + ) + else "C" + ) + r = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + src1 = dpt.broadcast_to(src1, res_shape) + src2 = dpt.broadcast_to(src2, res_shape) + ht_, _ = self.binary_fn_( + src1=src1, src2=src2, dst=r, sycl_queue=exec_q + ) + ht_.wait() + return r + elif buf1_dt is None: + if order == "K": + buf2 = _empty_like_orderK(src2, buf2_dt) + else: + if order == "A": + order = "F" if src1.flags.f_contiguous else "C" + buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order) + ht_copy_ev, copy_ev = 
ti._copy_usm_ndarray_into_usm_ndarray( + src=src2, dst=buf2, sycl_queue=exec_q + ) + if order == "K": + r = _empty_like_pair_orderK( + src1, buf2, res_dt, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + src1, + buf2, + ) + ) + else "C" + ) + r = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + src1 = dpt.broadcast_to(src1, res_shape) + buf2 = dpt.broadcast_to(buf2, res_shape) + ht_, _ = self.binary_fn_( + src1=src1, + src2=buf2, + dst=r, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_.wait() + return r + elif buf2_dt is None: + if order == "K": + buf1 = _empty_like_orderK(src1, buf1_dt) + else: + if order == "A": + order = "F" if src1.flags.f_contiguous else "C" + buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order) + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src1, dst=buf1, sycl_queue=exec_q + ) + if order == "K": + r = _empty_like_pair_orderK( + buf1, src2, res_dt, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + buf1, + src2, + ) + ) + else "C" + ) + r = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + buf1 = dpt.broadcast_to(buf1, res_shape) + src2 = dpt.broadcast_to(src2, res_shape) + ht_, _ = self.binary_fn_( + src1=buf1, + src2=src2, + dst=r, + sycl_queue=exec_q, + depends=[copy_ev], + ) + ht_copy_ev.wait() + ht_.wait() + return r + + if order in "KA": + if src1.flags.f_contiguous and src2.flags.f_contiguous: + order = "F" + else: + order = "C" + buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order) + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src1, dst=buf1, sycl_queue=exec_q + ) + buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src2, 
dst=buf2, sycl_queue=exec_q + ) + r = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + buf1 = dpt.broadcast_to(buf1, res_shape) + buf2 = dpt.broadcast_to(buf2, res_shape) + ht_, _ = self.binary_fn_( + src1=buf1, + src2=buf2, + dst=r, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) + return r diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py new file mode 100644 index 0000000000..5113ad0540 --- /dev/null +++ b/dpctl/tensor/_elementwise_funcs.py @@ -0,0 +1,53 @@ +import dpctl.tensor._tensor_impl as ti + +from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc + +# ABS +_abs_docstring_ = """ +Calculate the absolute value element-wise. +""" + +abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_) + +# ADD + +_add_docstring_ = """ +add(x1, x2, order='K') + +Calculates the sum for each element `x1_i` of the input array `x1` with +the respective element `x2_i` of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have numeric data type. + x2 (usm_ndarray): + Second input array, also expected to have numeric data type. +Returns: + usm_ndarray: + an array containing the element-wise sums. The data type of the + returned array is determined by the Type Promotion Rules. +""" +add = BinaryElementwiseFunc( + "add", ti._add_result_type, ti._add, _add_docstring_ +) + + +# COS + +_cos_docstring = """ +cos(x, order='K') + +Computes cosine for each element `x_i` for input array `x`. +""" + +cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring) + +# ISNAN + +_isnan_docstring_ = """ +Checks if each element of input array is a NaN. 
+""" + +isnan = UnaryElementwiseFunc( + "isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring_ +) diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index 3ea6875fce..ac82c67722 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import builtins + import dpctl.tensor as dpt @@ -111,3 +113,156 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool): return True return can_cast_v + + +def _empty_like_orderK(X, dt, usm_type=None, dev=None): + """Returns empty array like `x`, using order='K' + + For an array `x` that was obtained by permutation of a contiguous + array the returned array will have the same shape and the same + strides as `x`. + """ + if not isinstance(X, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray, got {type(X)}") + if usm_type is None: + usm_type = X.usm_type + if dev is None: + dev = X.device + fl = X.flags + if fl["C"] or X.size <= 1: + return dpt.empty_like( + X, dtype=dt, usm_type=usm_type, device=dev, order="C" + ) + elif fl["F"]: + return dpt.empty_like( + X, dtype=dt, usm_type=usm_type, device=dev, order="F" + ) + st = list(X.strides) + perm = sorted( + range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True + ) + inv_perm = sorted(range(X.ndim), key=lambda i: perm[i]) + st_sorted = [st[i] for i in perm] + sh = X.shape + sh_sorted = tuple(sh[i] for i in perm) + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") + if min(st_sorted) < 0: + sl = tuple( + slice(None, None, -1) + if st_sorted[i] < 0 + else slice(None, None, None) + for i in range(X.ndim) + ) + R = R[sl] + return dpt.permute_dims(R, inv_perm) + + +def _empty_like_pair_orderK(X1, X2, dt, usm_type, dev): + if not isinstance(X1, dpt.usm_ndarray): + raise TypeError(f"Expected usm_ndarray, got {type(X1)}") + if not isinstance(X2, dpt.usm_ndarray): + 
raise TypeError(f"Expected usm_ndarray, got {type(X2)}") + nd1 = X1.ndim + nd2 = X2.ndim + if nd1 > nd2: + return _empty_like_orderK(X1, dt, usm_type, dev) + elif nd1 < nd2: + return _empty_like_orderK(X2, dt, usm_type, dev) + fl1 = X1.flags + fl2 = X2.flags + if fl1["C"] or fl2["C"]: + return dpt.empty_like( + X1, dtype=dt, usm_type=usm_type, device=dev, order="C" + ) + if fl1["F"] and fl2["F"]: + return dpt.empty_like( + X1, dtype=dt, usm_type=usm_type, device=dev, order="F" + ) + st1 = list(X1.strides) + st2 = list(X2.strides) + perm = sorted( + range(nd1), + key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])), + reverse=True, + ) + inv_perm = sorted(range(nd1), key=lambda i: perm[i]) + st1_sorted = [st1[i] for i in perm] + st2_sorted = [st2[i] for i in perm] + sh = X1.shape + sh_sorted = tuple(sh[i] for i in perm) + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") + if max(min(st1_sorted), min(st2_sorted)) < 0: + sl = tuple( + slice(None, None, -1) + if (st1_sorted[i] < 0 and st2_sorted[i] < 0) + else slice(None, None, None) + for i in range(nd1) + ) + R = R[sl] + return dpt.permute_dims(R, inv_perm) + + +def _to_device_supported_dtype(dt, dev): + has_fp16 = dev.has_aspect_fp16 + has_fp64 = dev.has_aspect_fp64 + + if has_fp64: + if not has_fp16: + if dt is dpt.float16: + return dpt.float32 + else: + if dt is dpt.float64: + return dpt.float32 + elif dt is dpt.complex128: + return dpt.complex64 + if not has_fp16 and dt is dpt.float16: + return dpt.float32 + return dt + + +def _find_buf_dtype(arg_dtype, query_fn, sycl_dev): + res_dt = query_fn(arg_dtype) + if res_dt: + return None, res_dt + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + all_dts = _all_data_types(_fp16, _fp64) + for buf_dt in all_dts: + if _can_cast(arg_dtype, buf_dt, _fp16, _fp64): + res_dt = query_fn(buf_dt) + if res_dt: + return buf_dt, res_dt + + return None, None + + +def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev): + 
res_dt = query_fn(arg1_dtype, arg2_dtype) + if res_dt: + return None, None, res_dt + + _fp16 = sycl_dev.has_aspect_fp16 + _fp64 = sycl_dev.has_aspect_fp64 + all_dts = _all_data_types(_fp16, _fp64) + for buf1_dt in all_dts: + for buf2_dt in all_dts: + if _can_cast(arg1_dtype, buf1_dt, _fp16, _fp64) and _can_cast( + arg2_dtype, buf2_dt, _fp16, _fp64 + ): + res_dt = query_fn(buf1_dt, buf2_dt) + if res_dt: + ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt + ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt + return ret_buf1_dt, ret_buf2_dt, res_dt + + return None, None, None + + +__all__ = [ + "_find_buf_dtype", + "_find_buf_dtype2", + "_empty_like_orderK", + "_empty_like_pair_orderK", + "_to_device_supported_dtype", +] diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp new file mode 100644 index 0000000000..f7221cddb6 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -0,0 +1,322 @@ +#pragma once +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace abs +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct AbsContigFunctor +{ +private: + const argT *in = nullptr; + resT *out = nullptr; + const size_t nelems_; + +public: + AbsContigFunctor(const argT *inp, resT *res, const size_t n_elems) + : in(inp), out(res), nelems_(n_elems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NOTE: vec_sz must divide sg.max_local_range()[0] */ + + if constexpr (std::is_same_v || + (std::is_integral::value && + std::is_unsigned::value)) + { + static_assert(std::is_same_v); + + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = 
sg.get_local_range()[0]; + std::uint8_t max_sgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * max_sgSize); + + if (base + n_vecs * vec_sz * sgSize < nelems_ && + sgSize == max_sgSize) { + using in_ptrT = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + sycl::vec arg_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + arg_vec = sg.load(in_ptrT(&in[base + it * sgSize])); + sg.store(out_ptrT(&out[base + it * sgSize]), + arg_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = in[k]; + } + } + } + else { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + + (base % sgSize); + for (size_t offset = base; + offset < + std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); + offset += sgSize) + { + out[offset] = std::abs(in[offset]); + } + } + else { + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * maxsgSize); + + if (base + n_vecs * vec_sz * sgSize < nelems_ && sgSize == maxsgSize) { + using in_ptrT = sycl::multi_ptr< + const argT, sycl::access::address_space::global_space>; + using out_ptrT = sycl::multi_ptr< + resT, sycl::access::address_space::global_space>; + sycl::vec arg_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; + it += vec_sz) { + arg_vec = + sg.load(in_ptrT(&in[base + it * sgSize])); +#pragma unroll + for (std::uint8_t k = 0; k < vec_sz; ++k) { + arg_vec[k] = std::abs(arg_vec[k]); + } + sg.store(out_ptrT(&out[base + it * sgSize]), + arg_vec); + } + } + else { + for (size_t k = 
base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = std::abs(in[k]); + } + } + } + } + } +}; + +template struct AbsOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, float>, + td_ns::TypeMapEntry, double>, + td_ns::DefaultEntry>::result_type; +}; + +template +class abs_contig_kernel; + +typedef sycl::event (*abs_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const std::vector &); + +template +sycl::event abs_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + using resTy = typename AbsOutputType::value_type; + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for>( + sycl::nd_range<1>(gws_range, lws_range), + AbsContigFunctor(arg_tp, res_tp, + nelems)); + }); + return abs_ev; +} + +template struct AbsContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = abs_contig_impl; + return fn; + } + } +}; + +template struct AbsTypeMapFactory +{ + /*! 
@brief get typeid for output type of std::abs(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename AbsOutputType::value_type; + ; + return td_ns::GetTypeid{}.get(); + } +}; + +template +struct AbsStridedFunctor +{ +private: + const argT *in = nullptr; + resT *out = nullptr; + IndexerT inp_res_indexer_; + +public: + AbsStridedFunctor(const argT *inp_p, + resT *res_p, + IndexerT two_offsets_indexer) + : in(inp_p), out(res_p), inp_res_indexer_(two_offsets_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + auto offsets_ = inp_res_indexer_(static_cast(wid[0])); + const auto &inp_offset = offsets_.get_first_offset(); + const auto &out_offset = offsets_.get_second_offset(); + + if constexpr (std::is_same_v || + (std::is_integral::value && + std::is_unsigned::value)) + { + out[out_offset] = in[inp_offset]; + } + else { + out[out_offset] = std::abs(in[inp_offset]); + } + } +}; + +template class abs_strided_kernel; + +typedef sycl::event (*abs_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event abs_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename AbsOutputType::value_type; + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + IndexerT indexer{nd, arg_offset, res_offset, shape_and_strides}; + + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for>( + {nelems}, + AbsStridedFunctor(arg_tp, res_tp, indexer)); + }); + return abs_ev; +} + 
+template struct AbsStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = abs_strided_impl; + return fn; + } + } +}; + +} // namespace abs +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp new file mode 100644 index 0000000000..d3f1d2fd82 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -0,0 +1,504 @@ +#pragma once +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace add +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct AddContigFunctor +{ +private: + const argT1 *in1 = nullptr; + const argT2 *in2 = nullptr; + resT *out = nullptr; + const size_t nelems_; + +public: + AddContigFunctor(const argT1 *inp1, + const argT2 *inp2, + resT *res, + const size_t n_elems) + : in1(inp1), in2(inp2), out(res), nelems_(n_elems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + /* Each work-item processes vec_sz elements, contiguous in memory */ + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value || is_complex::value) { + std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); + for (size_t offset = base; + offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); + offset += sgSize) + { + out[offset] = in1[offset] + in2[offset]; + } + } + else { + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + size_t base = 
n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * maxsgSize); + + if (base + n_vecs * vec_sz * sgSize < nelems_ && sgSize == maxsgSize) { + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + sycl::vec arg1_vec; + sycl::vec arg2_vec; + sycl::vec res_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + arg1_vec = + sg.load(in_ptrT1(&in1[base + it * sgSize])); + arg2_vec = + sg.load(in_ptrT2(&in2[base + it * sgSize])); + if constexpr (std::is_same_v && + std::is_same_v) { + res_vec = arg1_vec + arg2_vec; + } + else { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = arg1_vec + arg2_vec; + res_vec = std::move( + vec_cast(tmp)); + } + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = in1[k] + in2[k]; + } + } + } + } +}; + +template struct AddOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns:: + BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultEntry>::result_type; +}; + +template +class add_contig_kernel; + +typedef sycl::event (*add_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event add_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg1_p, + py::ssize_t arg1_offset, + const char 
*arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + sycl::event add_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + using resTy = typename AddOutputType::value_type; + + const argTy1 *arg1_tp = + reinterpret_cast(arg1_p) + arg1_offset; + const argTy2 *arg2_tp = + reinterpret_cast(arg2_p) + arg2_offset; + resTy *res_tp = reinterpret_cast(res_p) + res_offset; + + cgh.parallel_for< + add_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + AddContigFunctor( + arg1_tp, arg2_tp, res_tp, nelems)); + }); + return add_ev; +} + +template struct AddContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_contig_impl; + return fn; + } + } +}; + +template struct AddTypeMapFactory +{ + /*! 
@brief get typeid for output type of std::add(T1 x, T2 y) */ + std::enable_if_t::value, int> get() + { + using rT = typename AddOutputType::value_type; + ; + return td_ns::GetTypeid{}.get(); + } +}; + +template +struct AddStridedFunctor +{ +private: + const argT1 *in1 = nullptr; + const argT2 *in2 = nullptr; + resT *out = nullptr; + ThreeOffsets_IndexerT three_offsets_indexer_; + +public: + AddStridedFunctor(const argT1 *inp1_tp, + const argT2 *inp2_tp, + resT *res_tp, + ThreeOffsets_IndexerT inps_res_indexer) + : in1(inp1_tp), in2(inp2_tp), out(res_tp), + three_offsets_indexer_(inps_res_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + const auto &three_offsets_ = + three_offsets_indexer_(static_cast(wid.get(0))); + + const auto &inp1_offset = three_offsets_.get_first_offset(); + const auto &inp2_offset = three_offsets_.get_second_offset(); + const auto &out_offset = three_offsets_.get_third_offset(); + + out[out_offset] = in1[inp1_offset] + in2[inp2_offset]; + } +}; + +template +class add_strided_strided_kernel; + +typedef sycl::event (*add_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event add_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename AddOutputType::value_type; + + using IndexerT = + typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + + IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset, + shape_and_strides}; + + const 
argTy1 *arg1_tp = reinterpret_cast(arg1_p); + const argTy2 *arg2_tp = reinterpret_cast(arg2_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + add_strided_strided_kernel>( + {nelems}, AddStridedFunctor( + arg1_tp, arg2_tp, res_tp, indexer)); + }); + return abs_ev; +} + +template struct AddStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_strided_impl; + return fn; + } + } +}; + +template +class add_matrix_vector_broadcast_sg_krn; + +typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event add_contig_matrix_contig_row_broadcast_impl( + sycl::queue exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + py::ssize_t mat_offset, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, + // res[i,j] = mat[i,j] + vec[j] + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + const argT1 *mat = reinterpret_cast(mat_p) + mat_offset; + const argT2 *vec = reinterpret_cast(vec_p) + vec_offset; + resT *res = reinterpret_cast(res_p) + res_offset; + + const auto &dev = exec_q.get_device(); + const auto &sg_sizes = dev.get_info(); + // Get device-specific kernel info max_sub_group_size + size_t max_sgSize = + *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + size_t n1_padded = n1 + max_sgSize; + argT2 *padded_vec = sycl::malloc_device(n1_padded, exec_q); + + if (padded_vec == nullptr) { + throw std::runtime_error("Could not allocate memory on the device"); + } + sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); // ensure vec contains actual data + cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) { + auto i = id[0]; + padded_vec[i] = vec[i % n1]; + }); + }); + + // sub-group spans work-items [I, I + sgSize) + // base = ndit.get_global_linear_id() - sg.get_local_id()[0] + // Generically, sg.load( &mat[base]) may load arrays from + // different rows of mat. The start corresponds to row (base / n0) + // We read sg.load(&padded_vec[(base / n0)]). 
The vector is padded to + // ensure that reads are accessible + + size_t lws = 64; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(make_padded_vec_ev); + + auto lwsRange = sycl::range<1>(lws); + size_t n_groups = (n0 * n1 + lws - 1) / lws; + auto gwsRange = sycl::range<1>(n_groups * lws); + + cgh.parallel_for>( + sycl::nd_range<1>(gwsRange, lwsRange), + [=](sycl::nd_item<1> ndit) + { + auto sg = ndit.get_sub_group(); + size_t gid = ndit.get_global_linear_id(); + + size_t base = gid - sg.get_local_id()[0]; + + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using res_ptrT = + sycl::multi_ptr; + + const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); + const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1])); + + resT res_el = mat_el + vec_el; + + sg.store(res_ptrT(&res[base]), res_el); + } + ); + }); + + sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + sycl::context ctx = exec_q.get_context(); + cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); }); + }); + host_tasks.push_back(tmp_cleanup_ev); + + return comp_ev; +} + +template +struct AddContigMatrixContigRowBroadcastFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + using resT = typename AddOutputType::value_type; + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + add_contig_matrix_contig_row_broadcast_impl; + return fn; + } + } + } +}; + +} // namespace add +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp new file mode 100644 index 0000000000..e69de29bb2 diff 
--git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp new file mode 100644 index 0000000000..97f0fd1f26 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -0,0 +1,291 @@ +#pragma once +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace cos +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct CosContigFunctor +{ +private: + const argT *in = nullptr; + resT *out = nullptr; + const size_t nelems_; + +public: + CosContigFunctor(const argT *inp, resT *res, const size_t nelems) + : in(inp), out(res), nelems_(nelems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + auto sg = ndit.get_sub_group(); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + std::uint8_t sgSize = sg.get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); + for (size_t offset = base; + offset < std::min(base + sgSize * (n_vecs * vec_sz), nelems_); + offset += sgSize) + { + using realT = typename argT::value_type; + // cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y) + auto v = std::real(in[offset]); + realT cosX_val; + const realT sinX_val = sycl::sincos(-v, &cosX_val); + v = std::imag(in[offset]); + const realT sinhY_val = sycl::sinh(v); + const realT coshY_val = sycl::cosh(v); + + const realT res_re = coshY_val * cosX_val; + const realT res_im = sinX_val * sinhY_val; + out[offset] = resT{res_re, res_im}; + } + } + else { + using dpctl::tensor::type_utils::vec_cast; + + std::uint8_t sgSize = sg.get_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + 
if (base + n_vecs * vec_sz * sg.get_max_local_range()[0] < nelems_) + { + using in_ptrT = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + sycl::vec x = + sg.load(in_ptrT(&in[base + it * sgSize])); + + sycl::vec res_vec = sycl::cos( + vec_cast(x)); + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) + out[k] = sycl::cos(static_cast(in[k])); + } + } + } +}; + +template struct CosOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, std::complex>, + td_ns::TypeMapEntry, std::complex>, + td_ns::DefaultEntry>::result_type; +}; + +typedef sycl::event (*cos_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const std::vector &); + +template +class cos_contig_kernel; + +template +sycl::event cos_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + sycl::event cos_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + constexpr size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + static_assert(lws % vec_sz == 0); + auto gws_range = sycl::range<1>( + ((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) * + lws); + auto lws_range = sycl::range<1>(lws); + + using resTy = typename CosOutputType::value_type; + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for>( + sycl::nd_range<1>(gws_range, lws_range), + CosContigFunctor(arg_tp, res_tp, + nelems)); + }); + return cos_ev; +} + +template struct CosContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + 
return fn; + } + else { + fnT fn = cos_contig_impl; + return fn; + } + } +}; + +template struct CosTypeMapFactory +{ + /*! @brief get typeid for output type of sycl::cos(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename CosOutputType::value_type; + ; + return td_ns::GetTypeid{}.get(); + } +}; + +template +struct CosStridedFunctor +{ +private: + const argT *in = nullptr; + resT *out = nullptr; + IndexerT inp_out_indexer_; + +public: + CosStridedFunctor(const argT *inp_tp, + resT *res_tp, + IndexerT arg_res_indexer) + : in(inp_tp), out(res_tp), inp_out_indexer_(arg_res_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + auto offsets_ = inp_out_indexer_(static_cast(wid.get(0))); + const py::ssize_t &inp_offset = offsets_.get_first_offset(); + const py::ssize_t &out_offset = offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using realT = typename argT::value_type; + // cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y) + auto v = std::real(in[inp_offset]); + realT cosX_val; + const realT sinX_val = sycl::sincos(-v, &cosX_val); + v = std::imag(in[inp_offset]); + const realT sinhY_val = sycl::sinh(v); + const realT coshY_val = sycl::cosh(v); + + const realT res_re = coshY_val * cosX_val; + const realT res_im = sinX_val * sinhY_val; + out[out_offset] = resT{res_re, res_im}; + } + else { + out[out_offset] = std::cos(static_cast(in[inp_offset])); + } + } +}; + +template class cos_strided_kernel; + +typedef sycl::event (*cos_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event cos_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector 
&additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename CosOutputType::value_type; + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides); + + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for>( + {nelems}, CosStridedFunctor( + arg_tp, res_tp, arg_res_indexer)); + }); + return comp_ev; +} + +template struct CosStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = cos_strided_impl; + return fn; + } + } +}; + +} // namespace cos +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp new file mode 100644 index 0000000000..aa7abecb98 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -0,0 +1,296 @@ +#pragma once +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace isnan +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct IsNanContigFunctor +{ +private: + const argT *in = nullptr; + resT *out = nullptr; + const size_t nelems_; + +public: + IsNanContigFunctor(const argT *inp, resT *res, const size_t nelems) + : in(inp), out(res), nelems_(nelems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + using dpctl::tensor::type_utils::is_complex; + using dpctl::tensor::type_utils::vec_cast; + + if constexpr (is_complex::value) { + std::uint8_t sgSize = 
ndit.get_sub_group().get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); + for (size_t offset = base; + offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); + offset += sgSize) + { + const bool real_isnan = sycl::isnan(std::real(in[offset])); + const bool imag_isnan = sycl::isnan(std::imag(in[offset])); + out[offset] = real_isnan || imag_isnan; + } + } + else if constexpr (std::is_same::value || + std::is_integral::value) + { + using out_ptrT = + sycl::multi_ptr; + + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t max_sgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + n_vecs * vec_sz * max_sgSize < nelems_ && + max_sgSize == sgSize) { + sycl::vec res_vec(false); +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = false; + } + } + } + else { + using in_ptrT = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + static_assert(std::is_same::value); + + auto sg = ndit.get_sub_group(); + std::uint16_t sgSize = sg.get_local_range()[0]; + std::uint16_t max_sgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * max_sgSize); + if (base + n_vecs * vec_sz * max_sgSize < nelems_ && + sgSize == max_sgSize) { + sycl::vec x; + +#pragma unroll + for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + x = sg.load(in_ptrT(&in[base + it * sgSize])); + // returns vec + auto res_vec = sycl::isnan(x); + // cast it to bool + sycl::vec res_bool = + vec_cast(res_vec); + sg.store(out_ptrT(&out[base + it * sgSize]), + res_bool); + } + } 
+ else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = static_cast(sycl::isnan(in[k])); + } + } + } + } +}; + +template struct IsNanOutputType +{ + using value_type = bool; +}; + +typedef sycl::event (*isnan_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const std::vector &); + +template +class isnan_contig_kernel; + +template +sycl::event isnan_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + sycl::event isnan_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + constexpr size_t lws = 64; + constexpr std::uint8_t vec_sz = 4; + constexpr std::uint8_t n_vecs = 2; + static_assert(lws % vec_sz == 0); + auto gws_range = sycl::range<1>( + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)) * + lws); + auto lws_range = sycl::range<1>(lws); + + using resTy = typename IsNanOutputType::value_type; + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + class isnan_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + IsNanContigFunctor(arg_tp, res_tp, + nelems)); + }); + return isnan_ev; +} + +template struct IsNanContigFactory +{ + fnT get() + { + fnT fn = isnan_contig_impl; + return fn; + } +}; + +template struct IsNanTypeMapFactory +{ + /*! 
@brief get typeid for output type of sycl::isnan(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename IsNanOutputType::value_type; + ; + return td_ns::GetTypeid{}.get(); + } +}; + +template +struct IsNanStridedFunctor +{ +private: + const argT *inp_ = nullptr; + resT *res_ = nullptr; + IndexerT inp_out_indexer_; + +public: + IsNanStridedFunctor(const argT *inp_p, + resT *res_p, + IndexerT inp_out_indexer) + : inp_(inp_p), res_(res_p), inp_out_indexer_(inp_out_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + const argT *const &in = inp_; + resT *const &out = res_; + + auto offsets_ = inp_out_indexer_(wid.get(0)); + const py::ssize_t &inp_offset = offsets_.get_first_offset(); + const py::ssize_t &out_offset = offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (std::is_same_v || + (std::is_integral::value)) { + out[out_offset] = false; + } + else if constexpr (is_complex::value) { + const bool real_isnan = sycl::isnan(std::real(in[inp_offset])); + const bool imag_isnan = sycl::isnan(std::imag(in[inp_offset])); + + out[out_offset] = real_isnan || imag_isnan; + } + else { + out[out_offset] = sycl::isnan(in[inp_offset]); + } + } +}; + +template class isnan_strided_kernel; + +typedef sycl::event (*isnan_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event +isnan_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename IsNanOutputType::value_type; + using IndexerT = + typename 
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + IndexerT arg_res_indexer{nd, arg_offset, res_offset, shape_and_strides}; + + const argTy *arg_tptr = reinterpret_cast(arg_p); + resTy *res_tptr = reinterpret_cast(res_p); + + cgh.parallel_for>( + {nelems}, IsNanStridedFunctor( + arg_tptr, res_tptr, arg_res_indexer)); + }); + return abs_ev; +} + +template struct IsNanStridedFactory +{ + fnT get() + { + fnT fn = isnan_strided_impl; + return fn; + } +}; + +} // namespace isnan +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp index 5d7d6b8a8c..07fbbc6baf 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp @@ -250,6 +250,89 @@ struct usm_ndarray_types } }; +/*! @brief struct to define result_type typename for Ty == ArgTy */ +template +struct TypeMapEntry : std::bool_constant> +{ + using result_type = ResTy; +}; + +/*! @brief struct to define result_type typename for Ty1 == ArgTy1 && Ty2 == + * ArgTy2 */ +template +struct BinaryTypeMapEntry + : std::bool_constant, + std::is_same>> +{ + using result_type = ResTy; +}; + +/*! @brief fall-through struct with specified result_type, usually void */ +template struct DefaultEntry : std::true_type +{ + using result_type = Ty; +}; + +/*! 
@brief Utility struct to convert C++ type into typeid integer */ +template struct GetTypeid +{ + int get() + { + if constexpr (std::is_same_v) { + return static_cast(typenum_t::BOOL); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT8); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT8); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT16); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT16); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT32); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT32); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::INT64); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::UINT64); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::HALF); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::FLOAT); + } + else if constexpr (std::is_same_v) { + return static_cast(typenum_t::DOUBLE); + } + else if constexpr (std::is_same_v>) { + return static_cast(typenum_t::CFLOAT); + } + else if constexpr (std::is_same_v>) { + return static_cast(typenum_t::CDOUBLE); + } + else if constexpr (std::is_same_v) { // special token + return -1; + } + + assert(("Unsupported type T", false)); + return -2; + } +}; + } // namespace type_dispatch } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/utils/type_utils.hpp b/dpctl/tensor/libtensor/include/utils/type_utils.hpp index 8464418d8b..36c00404c9 100644 --- a/dpctl/tensor/libtensor/include/utils/type_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_utils.hpp @@ -26,6 +26,7 @@ #include #include #include +#include namespace dpctl { @@ -100,6 +101,21 @@ template void validate_type_for_device(const sycl::queue &q) validate_type_for_device(q.get_device()); } +template +auto vec_cast_impl(const Vec &v, std::index_sequence) 
+{ + return Op{v[I]...}; +} + +template > +auto vec_cast(const sycl::vec &s) +{ + return vec_cast_impl, sycl::vec>(s, Indices{}); +} + } // namespace type_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp new file mode 100644 index 0000000000..8d7283bfaa --- /dev/null +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -0,0 +1,523 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions, +/// specifically functions for elementwise operations. 
+//===----------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include + +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/abs.hpp" +#include "kernels/elementwise_functions/add.hpp" +#include "kernels/elementwise_functions/cos.hpp" +#include "kernels/elementwise_functions/isnan.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t) +{ + switch (dst_typenum_t) { + case td_ns::typenum_t::BOOL: + return py::dtype("?"); + case td_ns::typenum_t::INT8: + return py::dtype("i1"); + case td_ns::typenum_t::UINT8: + return py::dtype("u1"); + case td_ns::typenum_t::INT16: + return py::dtype("i2"); + case td_ns::typenum_t::UINT16: + return py::dtype("u2"); + case td_ns::typenum_t::INT32: + return py::dtype("i4"); + case td_ns::typenum_t::UINT32: + return py::dtype("u4"); + case td_ns::typenum_t::INT64: + return py::dtype("i8"); + case td_ns::typenum_t::UINT64: + return py::dtype("u8"); + case td_ns::typenum_t::HALF: + return py::dtype("f2"); + case td_ns::typenum_t::FLOAT: + return py::dtype("f4"); + case td_ns::typenum_t::DOUBLE: + return py::dtype("f8"); + case td_ns::typenum_t::CFLOAT: + return py::dtype("c8"); + case td_ns::typenum_t::CDOUBLE: + return py::dtype("c16"); + default: + throw py::value_error("Unrecognized dst_typeid"); + } +} + +int _result_typeid(int arg_typeid, const int *fn_output_id) +{ + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) { + throw py::value_error("Input typeid " + std::to_string(arg_typeid) + + " is outside of expected bounds."); + } + + return fn_output_id[arg_typeid]; +} + +// ABS +namespace impl +{ + +using dpctl::tensor::kernels::abs::abs_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::abs::abs_strided_impl_fn_ptr_t; + +static 
abs_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; +static int abs_output_typeid_vector[td_ns::num_types]; +static abs_strided_impl_fn_ptr_t abs_strided_dispatch_vector[td_ns::num_types]; + +void populate_abs_dispatch_vectors(void) +{ + using namespace td_ns; + + using dpctl::tensor::kernels::abs::AbsContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(abs_contig_dispatch_vector); + + using dpctl::tensor::kernels::abs::AbsStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(abs_strided_dispatch_vector); + + using dpctl::tensor::kernels::abs::AbsTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(abs_output_typeid_vector); +}; + +} // namespace impl + +// ISNAN +namespace impl +{ + +using dpctl::tensor::kernels::isnan::isnan_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::isnan::isnan_strided_impl_fn_ptr_t; + +static isnan_contig_impl_fn_ptr_t + isnan_contig_dispatch_vector[td_ns::num_types]; +static int isnan_output_typeid_vector[td_ns::num_types]; +static isnan_strided_impl_fn_ptr_t + isnan_strided_dispatch_vector[td_ns::num_types]; + +void populate_isnan_dispatch_vectors(void) +{ + using namespace td_ns; + + using dpctl::tensor::kernels::isnan::IsNanContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isnan_contig_dispatch_vector); + + using dpctl::tensor::kernels::isnan::IsNanStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isnan_strided_dispatch_vector); + + using dpctl::tensor::kernels::isnan::IsNanTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(isnan_output_typeid_vector); +} + +} // namespace impl + +// COS +namespace impl +{ + +using dpctl::tensor::kernels::cos::cos_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::cos::cos_strided_impl_fn_ptr_t; + +static cos_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; +static int 
cos_output_typeid_vector[td_ns::num_types]; +static cos_strided_impl_fn_ptr_t cos_strided_dispatch_vector[td_ns::num_types]; + +void populate_cos_dispatch_vectors(void) +{ + using namespace td_ns; + + using dpctl::tensor::kernels::cos::CosContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); + + using dpctl::tensor::kernels::cos::CosStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); + + using dpctl::tensor::kernels::cos::CosTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(cos_output_typeid_vector); +} + +} // namespace impl + +// ADD + +namespace impl +{ + +using dpctl::tensor::kernels::add::add_contig_impl_fn_ptr_t; +using dpctl::tensor::kernels::add:: + add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using dpctl::tensor::kernels::add::add_strided_impl_fn_ptr_t; + +static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int add_output_id_table[td_ns::num_types][td_ns::num_types]; + +static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_add_dispatch_tables(void) +{ + using namespace td_ns; + + using dpctl::tensor::kernels::add::AddContigFactory; + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(add_contig_dispatch_table); + + using dpctl::tensor::kernels::add::AddStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(add_strided_dispatch_table); + + using dpctl::tensor::kernels::add::AddTypeMapFactory; + DispatchTableBuilder dtb3; + dtb3.populate_dispatch_table(add_output_id_table); + + using dpctl::tensor::kernels::add::AddContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder + dtb4; + 
dtb4.populate_dispatch_table( + add_contig_matrix_contig_row_broadcast_dispatch_table); +}; + +} // namespace impl + +namespace py = pybind11; + +void init_elementwise_functions(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + // U01: ==== ABS (x) + { + impl::populate_abs_dispatch_vectors(); + using impl::abs_contig_dispatch_vector; + using impl::abs_output_typeid_vector; + using impl::abs_strided_dispatch_vector; + + auto abs_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, abs_output_typeid_vector, + abs_contig_dispatch_vector, abs_strided_dispatch_vector); + }; + m.def("_abs", abs_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto abs_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, abs_output_typeid_vector); + }; + m.def("_abs_result_type", abs_result_type_pyapi); + } + + // U02: ==== ACOS (x) + // FIXME: + // U03: ===== ACOSH (x) + // FIXME: + + // B01: ===== ADD (x1, x2) + { + impl::populate_add_dispatch_tables(); + using impl::add_contig_dispatch_table; + using impl::add_contig_matrix_contig_row_broadcast_dispatch_table; + using impl::add_output_id_table; + using impl::add_strided_dispatch_table; + + auto add_pyapi = [&](dpctl::tensor::usm_ndarray src1, + dpctl::tensor::usm_ndarray src2, + dpctl::tensor::usm_ndarray dst, sycl::queue exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, add_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + add_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + add_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + 
add_contig_matrix_contig_row_broadcast_dispatch_table); + }; + auto add_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + add_output_id_table); + }; + m.def("_add", add_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_add_result_type", add_result_type_pyapi, ""); + } + + // U04: ===== ASIN (x) + // FIXME: + + // U05: ===== ASINH (x) + // FIXME: + + // U06: ===== ATAN (x) + // FIXME: + + // B02: ===== ATAN2 (x1, x2) + // FIXME: + + // U07: ===== ATANH (x) + // FIXME: + + // B03: ===== BITWISE_AND (x1, x2) + // FIXME: + + // B04: ===== BITWISE_LEFT_SHIFT (x1, x2) + // FIXME: + + // U08: ===== BITWISE_INVERT (x) + // FIXME: + + // B05: ===== BITWISE_OR (x1, x2) + // FIXME: + + // B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) + // FIXME: + + // B07: ===== BITWISE_XOR (x1, x2) + // FIXME: + + // U09: ==== CEIL (x) + // FIXME: + + // U10: ==== CONJ (x) + // FIXME: + + // U11: ==== COS (x) + { + impl::populate_cos_dispatch_vectors(); + using impl::cos_contig_dispatch_vector; + using impl::cos_output_typeid_vector; + using impl::cos_strided_dispatch_vector; + + auto cos_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, cos_output_typeid_vector, + cos_contig_dispatch_vector, cos_strided_dispatch_vector); + }; + m.def("_cos", cos_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto cos_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, cos_output_typeid_vector); + }; + m.def("_cos_result_type", cos_result_type_pyapi); + } + + // U12: ==== COSH (x) + // FIXME: + + // B08: ==== DIVIDE (x1, x2) + // FIXME: + + // B09: ==== EQUAL (x1, x2) + // FIXME: + + // U13: ==== EXP (x) + // FIXME: + + // U14: ==== EXPM1 (x) + // FIXME: + + // U15: ==== FLOOR (x) 
+ // FIXME: + + // B10: ==== FLOOR_DIVIDE (x1, x2) + // FIXME: + + // B11: ==== GREATER (x1, x2) + // FIXME: + + // B12: ==== GREATER_EQUAL (x1, x2) + // FIXME: + + // U16: ==== IMAG (x) + // FIXME: + + // U17: ==== ISFINITE (x) + // FIXME: + + // U18: ==== ISINF (x) + // FIXME: + + // U19: ==== ISNAN (x) + { + impl::populate_isnan_dispatch_vectors(); + + using impl::isnan_contig_dispatch_vector; + using impl::isnan_output_typeid_vector; + using impl::isnan_strided_dispatch_vector; + auto isnan_pyapi = [&](dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, isnan_output_typeid_vector, + isnan_contig_dispatch_vector, isnan_strided_dispatch_vector); + }; + auto isnan_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, + isnan_output_typeid_vector); + }; + m.def("_isnan", isnan_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_isnan_result_type", isnan_result_type_pyapi, ""); + } + + // B13: ==== LESS (x1, x2) + // FIXME: + + // B14: ==== LESS_EQUAL (x1, x2) + // FIXME: + + // U20: ==== LOG (x) + // FIXME: + + // U21: ==== LOG1P (x) + // FIXME: + + // U22: ==== LOG2 (x) + // FIXME: + + // U23: ==== LOG10 (x) + // FIXME: + + // B15: ==== LOGADDEXP (x1, x2) + // FIXME: + + // B16: ==== LOGICAL_AND (x1, x2) + // FIXME: + + // U24: ==== LOGICAL_NOT (x) + // FIXME: + + // B17: ==== LOGICAL_OR (x1, x2) + // FIXME: + + // B18: ==== LOGICAL_XOR (x1, x2) + // FIXME: + + // B19: ==== MULTIPLY (x1, x2) + // FIXME: + + // U25: ==== NEGATIVE (x) + // FIXME: + + // B20: ==== NOT_EQUAL (x1, x2) + // FIXME: + + // U26: ==== POSITIVE (x) + // FIXME: + + // B21: ==== POW (x1, x2) + // FIXME: + + // U27: ==== REAL (x) + // FIXME: + + // B22: ==== REMAINDER (x1, x2) + // FIXME: + + // U28: ==== ROUND (x) + // FIXME: + + // U29: ==== SIGN (x) + // FIXME: + + // U30: 
==== SIN (x) + // FIXME: + + // U31: ==== SINH (x) + // FIXME: + + // U32: ==== SQUARE (x) + // FIXME: + + // U33: ==== SQRT (x) + // FIXME: + + // B23: ==== SUBTRACT (x1, x2) + // FIXME: + + // U34: ==== TAN (x) + // FIXME: + + // U35: ==== TANH (x) + // FIXME: + + // U36: ==== TRUNC (x) + // FIXME: +} + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.hpp b/dpctl/tensor/libtensor/source/elementwise_functions.hpp new file mode 100644 index 0000000000..c29ced5e1d --- /dev/null +++ b/dpctl/tensor/libtensor/source/elementwise_functions.hpp @@ -0,0 +1,614 @@ +//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions, +/// specifically functions for elementwise operations. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include "dpctl4pybind11.hpp" +#include +#include +#include +#include +#include +#include + +#include "simplify_iteration_space.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +extern py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t); +extern int _result_typeid(int arg_typeid, const int *fn_output_id); + +template +std::pair +py_unary_ufunc(dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + sycl::queue q, + const std::vector &depends, + // + const output_typesT &output_type_vec, + const contig_dispatchT &contig_dispatch_vector, + const strided_dispatchT &strided_dispatch_vector) +{ + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int func_output_typeid = output_type_vec[src_typeid]; + + // check that types are supported + if (dst_typeid != func_output_typeid) { + throw py::value_error( + "Destination array has unexpected elemental data type."); + } + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // check that dimensions are the same + int src_nd = src.get_ndim(); + if (src_nd != dst.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + // check that shapes are the same + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + bool shapes_equal(true); + size_t src_nelems(1); + + for (int i = 0; i < src_nd; ++i) { + src_nelems *= 
static_cast(src_shape[i]); + shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + // ensure that output is ample enough to accomodate all elements + auto dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accomodate all elements + { + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < src_nelems) { + throw py::value_error( + "Destination array can not accomodate all the " + "elements of source array."); + } + } + + // check memory overlap + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + // check that arrays do not overlap, and concurrent copying is safe. + auto src_offsets = src.get_minmax_offsets(); + int src_elem_size = src.get_elemsize(); + int dst_elem_size = dst.get_elemsize(); + + bool memory_overlap = + ((dst_data - src_data > src_offsets.second * src_elem_size - + dst_offsets.first * dst_elem_size) && + (src_data - dst_data > dst_offsets.second * dst_elem_size - + src_offsets.first * src_elem_size)); + if (memory_overlap) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // handle contiguous inputs + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool both_c_contig = (is_src_c_contig && is_dst_c_contig); + bool both_f_contig = (is_src_f_contig && is_dst_f_contig); + + if (both_c_contig || both_f_contig) { + auto contig_fn = contig_dispatch_vector[src_typeid]; + + if (contig_fn == nullptr) { + throw std::runtime_error( + "Contiguous implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + auto comp_ev = contig_fn(q, src_nelems, src_data, 
dst_data, depends); + sycl::event ht_ev = + dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + + // simplify iteration space + // if 1d with strides 1 - input is contig + // dispatch to strided + + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t dst_offset(0); + + int nd = src_nd; + const py::ssize_t *shape = src_shape; + + dpctl::tensor::py_internal::simplify_iteration_space( + nd, shape, src_strides, dst_strides, + // output + simplified_shape, simplified_src_strides, simplified_dst_strides, + src_offset, dst_offset); + + if (nd == 1 && simplified_src_strides[0] == 1 && + simplified_dst_strides[0] == 1) { + // Special case of contiguous data + auto contig_fn = contig_dispatch_vector[src_typeid]; + + if (contig_fn == nullptr) { + throw std::runtime_error( + "Contiguous implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + int src_elem_size = src.get_elemsize(); + int dst_elem_size = dst.get_elemsize(); + auto comp_ev = + contig_fn(q, src_nelems, src_data + src_elem_size * src_offset, + dst_data + dst_elem_size * dst_offset, depends); + + sycl::event ht_ev = + dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + + // Strided implementation + auto strided_fn = strided_dispatch_vector[src_typeid]; + + if (strided_fn == nullptr) { + throw std::runtime_error( + "Strided implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + std::vector host_tasks{}; + host_tasks.reserve(2); + + const auto &ptr_size_event_triple_ = device_allocate_and_pack( + q, host_tasks, simplified_shape, simplified_src_strides, + simplified_dst_strides); + 
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_triple_); + sycl::event copy_shape_ev = std::get<2>(ptr_size_event_triple_); + + if (shape_strides == nullptr) { + throw std::runtime_error("Device memory allocation failed"); + } + + sycl::event strided_fn_ev = + strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset, + dst_data, dst_offset, depends, {copy_shape_ev}); + + // async free of shape_strides temporary + auto ctx = q.get_context(); + sycl::event tmp_cleanup_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(strided_fn_ev); + cgh.host_task( + [ctx, shape_strides]() { sycl::free(shape_strides, ctx); }); + }); + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks), + strided_fn_ev); +} + +template +py::object py_unary_ufunc_result_type(py::dtype input_dtype, + const output_typesT &output_types) +{ + int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl + int src_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src_typeid = array_types.typenum_to_lookup_id(tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + int dst_typeid = _result_typeid(src_typeid, output_types); + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } + else { + auto dst_typenum_t = static_cast(dst_typeid); + + auto dt = _dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +// ======================== Binary functions =========================== + +namespace +{ +template +bool isEqual(Container const &c, std::initializer_list const &l) +{ + return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l)); +} +} // namespace + +template +std::pair py_binary_ufunc( + dpctl::tensor::usm_ndarray src1, + dpctl::tensor::usm_ndarray src2, + dpctl::tensor::usm_ndarray dst, // dst = op(src1, src2), elementwise + sycl::queue exec_q, + const std::vector depends, + // + const 
output_typesT &output_type_table, + const contig_dispatchT &contig_dispatch_table, + const strided_dispatchT &strided_dispatch_table, + const matrix_row_dispatchT &contig_matrix_row_broadcast_dispatch_table) +{ + // check type_nums + int src1_typenum = src1.get_typenum(); + int src2_typenum = src2.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto array_types = td_ns::usm_ndarray_types(); + int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum); + int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int output_typeid = output_type_table[src1_typeid][src2_typeid]; + + if (output_typeid != dst_typeid) { + throw py::value_error( + "Destination array has unexpected elemental data type."); + } + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // check shapes, broadcasting is assumed done by caller + // check that dimensions are the same + int dst_nd = dst.get_ndim(); + if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + // check that shapes are the same + const py::ssize_t *src1_shape = src1.get_shape_raw(); + const py::ssize_t *src2_shape = src2.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + bool shapes_equal(true); + size_t src_nelems(1); + + for (int i = 0; i < dst_nd; ++i) { + src_nelems *= static_cast(src1_shape[i]); + shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] && + src2_shape[i] == dst_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + // ensure that output is ample enough to accomodate all elements + auto 
dst_offsets = dst.get_minmax_offsets(); + // destination must be ample enough to accomodate all elements + { + size_t range = + static_cast(dst_offsets.second - dst_offsets.first); + if (range + 1 < src_nelems) { + throw py::value_error( + "Destination array can not accomodate all the " + "elements of source array."); + } + } + + // check memory overlap + const char *src1_data = src1.get_data(); + const char *src2_data = src2.get_data(); + char *dst_data = dst.get_data(); + + // check that arrays do not overlap, and concurrent copying is safe. + auto src1_offsets = src1.get_minmax_offsets(); + int src1_elem_size = src1.get_elemsize(); + auto src2_offsets = src2.get_minmax_offsets(); + int src2_elem_size = src2.get_elemsize(); + int dst_elem_size = dst.get_elemsize(); + + bool memory_overlap_src1_dst = + ((dst_data - src1_data > src1_offsets.second * src1_elem_size - + dst_offsets.first * dst_elem_size) && + (src1_data - dst_data > dst_offsets.second * dst_elem_size - + src1_offsets.first * src1_elem_size)); + bool memory_overlap_src2_dst = + ((dst_data - src2_data > src2_offsets.second * src2_elem_size - + dst_offsets.first * dst_elem_size) && + (src2_data - dst_data > dst_offsets.second * dst_elem_size - + src2_offsets.first * src2_elem_size)); + if (memory_overlap_src1_dst || memory_overlap_src2_dst) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + // handle contiguous inputs + bool is_src1_c_contig = src1.is_c_contiguous(); + bool is_src1_f_contig = src1.is_f_contiguous(); + + bool is_src2_c_contig = src2.is_c_contiguous(); + bool is_src2_f_contig = src2.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool all_c_contig = + (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig); + bool all_f_contig = + (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig); + + // dispatch for contiguous inputs + if (all_c_contig || all_f_contig) { + auto contig_fn = 
contig_dispatch_table[src1_typeid][src2_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0, + src2_data, 0, dst_data, 0, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + + // simplify strides + auto const &src1_strides = src1.get_strides_vector(); + auto const &src2_strides = src2.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src1_strides; + shT simplified_src2_strides; + shT simplified_dst_strides; + py::ssize_t src1_offset(0); + py::ssize_t src2_offset(0); + py::ssize_t dst_offset(0); + + int nd = dst_nd; + const py::ssize_t *shape = src1_shape; + + // all args except itemsizes and is_?_contig bools can be modified by + // reference + dpctl::tensor::py_internal::simplify_iteration_space_3( + nd, shape, src1_strides, src2_strides, dst_strides, + // outputs + simplified_shape, simplified_src1_strides, simplified_src2_strides, + simplified_dst_strides, src1_offset, src2_offset, dst_offset); + + std::vector host_tasks{}; + + if (nd < 3) { + static constexpr auto unit_stride = + std::initializer_list{1}; + + if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) && + isEqual(simplified_src2_strides, unit_stride) && + isEqual(simplified_dst_strides, unit_stride)) + { + auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, + src1_offset, src2_data, src2_offset, + dst_data, dst_offset, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + if (nd == 2) { + static constexpr auto zero_one_strides = + std::initializer_list{0, 1}; + static constexpr auto one_zero_strides = + std::initializer_list{1, 0}; + 
constexpr py::ssize_t one{1}; + // special case of C-contiguous matrix and a row + if (isEqual(simplified_src2_strides, zero_one_strides) && + isEqual(simplified_src1_strides, {simplified_shape[1], one}) && + isEqual(simplified_dst_strides, {simplified_shape[1], one})) + { + auto matrix_row_broadcast_fn = + contig_matrix_row_broadcast_dispatch_table[src1_typeid] + [src2_typeid]; + if (matrix_row_broadcast_fn != nullptr) { + size_t n0 = simplified_shape[0]; + size_t n1 = simplified_shape[1]; + sycl::event comp_ev = matrix_row_broadcast_fn( + exec_q, host_tasks, n0, n1, src1_data, src1_offset, + src2_data, src2_offset, dst_data, dst_offset, depends); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, + host_tasks), + comp_ev); + } + } + if (isEqual(simplified_src1_strides, one_zero_strides) && + isEqual(simplified_src2_strides, {one, simplified_shape[0]}) && + isEqual(simplified_dst_strides, {one, simplified_shape[0]})) + { + auto matrix_row_broadcast_fn = + contig_matrix_row_broadcast_dispatch_table[src2_typeid] + [src1_typeid]; + if (matrix_row_broadcast_fn != nullptr) { + size_t n0 = simplified_shape[1]; + size_t n1 = simplified_shape[0]; + sycl::event comp_ev = matrix_row_broadcast_fn( + exec_q, host_tasks, n0, n1, src2_data, src2_offset, + src1_data, src1_offset, dst_data, dst_offset, depends); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, + host_tasks), + comp_ev); + } + } + } + } + + // dispatch to strided code + auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid]; + + if (strided_fn == nullptr) { + throw std::runtime_error( + "Contiguous implementation is missing for src1_typeid=" + + std::to_string(src1_typeid) + + " and src2_typeid=" + std::to_string(src2_typeid)); + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_sz_event_triple_ = device_allocate_and_pack( + exec_q, host_tasks, simplified_shape, simplified_src1_strides, + 
simplified_src2_strides, simplified_dst_strides); + + py::ssize_t *shape_strides = std::get<0>(ptr_sz_event_triple_); + sycl::event copy_shape_ev = std::get<2>(ptr_sz_event_triple_); + + if (shape_strides == nullptr) { + throw std::runtime_error("Unabled to allocate device memory"); + } + + sycl::event strided_fn_ev = strided_fn( + exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset, + src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev}); + + // async free of shape_strides temporary + auto ctx = exec_q.get_context(); + + sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(strided_fn_ev); + cgh.host_task( + [ctx, shape_strides]() { sycl::free(shape_strides, ctx); }); + }); + + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks), + strided_fn_ev); +} + +template +py::object py_binary_ufunc_result_type(py::dtype input1_dtype, + py::dtype input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + int dst_typeid = output_types_table[src1_typeid][src2_typeid]; + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } + else { + auto dst_typenum_t = static_cast(dst_typeid); + + auto dt = _dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +extern void 
init_elementwise_functions(py::module_ m); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index 2cf627be18..6d4901f135 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -38,6 +38,7 @@ #include "copy_for_reshape.hpp" #include "copy_numpy_ndarray_into_usm_ndarray.hpp" #include "device_support_queries.hpp" +#include "elementwise_functions.hpp" #include "eye_ctor.hpp" #include "full_ctor.hpp" #include "integer_advanced_indexing.hpp" @@ -346,4 +347,6 @@ PYBIND11_MODULE(_tensor_impl, m) m.def("_where", &py_where, "", py::arg("condition"), py::arg("x1"), py::arg("x2"), py::arg("dst"), py::arg("sycl_queue"), py::arg("depends") = py::list()); + + dpctl::tensor::py_internal::init_elementwise_functions(m); } diff --git a/dpctl/tests/test_tensor_elementwise.py b/dpctl/tests/test_tensor_elementwise.py new file mode 100644 index 0000000000..4936574843 --- /dev/null +++ b/dpctl/tests/test_tensor_elementwise.py @@ -0,0 +1,377 @@ +import itertools + +import numpy as np +import pytest + +import dpctl +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +_all_dtypes = [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] +_usm_types = ["device", "shared", "host"] + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_abs_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + X = dpt.asarray(0, dtype=arg_dt, sycl_queue=q) + if np.issubdtype(arg_dt, np.complexfloating): + type_map = { + np.dtype("c8"): np.dtype("f4"), + np.dtype("c16"): np.dtype("f8"), + } + assert dpt.abs(X).dtype == type_map[arg_dt] + else: + assert dpt.abs(X).dtype == arg_dt + + +@pytest.mark.parametrize("usm_type", _usm_types) +def 
test_abs_usm_type(usm_type): + q = get_queue_or_skip() + + arg_dt = np.dtype("i4") + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) + X[..., 0::2] = 1 + X[..., 1::2] = 0 + + Y = dpt.abs(X) + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == X.sycl_queue + assert Y.flags.c_contiguous + + expected_Y = dpt.asnumpy(X) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", _all_dtypes[1:]) +def test_abs_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = 1 + X[..., 1::2] = 0 + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.abs(U, order=ord) + expected_Y = np.ones(Y.shape, dtype=Y.dtype) + expected_Y[..., 1::2] = 0 + expected_Y = np.transpose(expected_Y, perms) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_abs_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + Xnp = np.random.standard_normal( + size=input_shape + ) + 1j * np.random.standard_normal(size=input_shape) + Xnp = Xnp.astype(arg_dt) + X[...] 
= Xnp + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.abs(U, order=ord) + expected_Y = np.abs(np.transpose(Xnp[:, ::-1, ::-1, :], perms)) + tol = dpt.finfo(Y.dtype).resolution + np.testing.assert_allclose( + dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol + ) + + +def _compare_dtypes(dt, ref_dt, sycl_queue=None): + assert isinstance(sycl_queue, dpctl.SyclQueue) + dev = sycl_queue.sycl_device + expected_dt = ref_dt + if not dev.has_aspect_fp64: + if expected_dt == dpt.float64: + expected_dt = dpt.float32 + elif expected_dt == dpt.complex128: + expected_dt = dpt.complex64 + if not dev.has_aspect_fp16: + if expected_dt == dpt.float16: + expected_dt = dpt.float32 + return dt == expected_dt + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_add_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype) + + r = dpt.add(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_dtype = np.add( + np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) + ).dtype + assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) + assert r.shape == ar1.shape + assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() + assert r.sycl_queue == ar1.sycl_queue + + ar3 = dpt.ones(sz, dtype=op1_dtype) + ar4 = dpt.ones(2 * sz, dtype=op2_dtype) + + r = dpt.add(ar3[::-1], ar4[::2]) + assert isinstance(r, dpt.usm_ndarray) + expected_dtype = np.add( + np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) + ).dtype + assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) + assert r.shape == ar3.shape + assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() + + +@pytest.mark.parametrize("op1_usm_type", 
_usm_types) +@pytest.mark.parametrize("op2_usm_type", _usm_types) +def test_add_usm_type_matrix(op1_usm_type, op2_usm_type): + get_queue_or_skip() + + sz = 128 + ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) + ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) + + r = dpt.add(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_usm_type = dpctl.utils.get_coerced_usm_type( + (op1_usm_type, op2_usm_type) + ) + assert r.usm_type == expected_usm_type + + +def test_add_order(): + get_queue_or_skip() + + ar1 = dpt.ones((20, 20), dtype="i4", order="C") + ar2 = dpt.ones((20, 20), dtype="i4", order="C") + r1 = dpt.add(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.add(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.add(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.add(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones((20, 20), dtype="i4", order="F") + ar2 = dpt.ones((20, 20), dtype="i4", order="F") + r1 = dpt.add(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.add(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.add(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.add(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + r4 = dpt.add(ar1, ar2, order="K") + assert r4.strides == (20, -1) + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + r4 = dpt.add(ar1, ar2, order="K") + assert r4.strides == (-1, 20) + + +def test_add_broadcasting(): + get_queue_or_skip() + + m = dpt.ones((100, 5), dtype="i4") + v = dpt.arange(5, dtype="i4") + + r = dpt.add(m, v) + + assert (dpt.asnumpy(r) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() + + r2 = dpt.add(v, m) + assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() + + +def 
_map_to_device_dtype(dt, dev): + if np.issubdtype(dt, np.integer): + return dt + if np.issubdtype(dt, np.floating): + dtc = np.dtype(dt).char + if dtc == "d": + return dt if dev.has_aspect_fp64 else dpt.float32 + elif dtc == "e": + return dt if dev.has_aspect_fp16 else dpt.float32 + return dt + if np.issubdtype(dt, np.complexfloating): + dtc = np.dtype(dt).char + if dtc == "D": + return dt if dev.has_aspect_fp64 else dpt.complex64 + return dt + return dt + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_cos_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype + expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) + assert dpt.cos(X).dtype == expected_dtype + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) +def test_cos_output(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 100 + n_rep = 137 + + Xnp = np.linspace(-np.pi / 4, np.pi / 4, num=n_seq, dtype=dtype) + X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) + + Y = dpt.cos(X) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose( + dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol + ) + + +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) +def test_cos_usm_type(usm_type): + q = get_queue_or_skip() + + arg_dt = np.dtype("f4") + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) + X[..., 0::2] = np.pi / 6 + X[..., 1::2] = np.pi / 3 + + Y = dpt.cos(X) + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == X.sycl_queue + assert Y.flags.c_contiguous + + expected_Y = np.empty(input_shape, dtype=arg_dt) + expected_Y[..., 0::2] = np.cos(np.float32(np.pi / 6)) + expected_Y[..., 1::2] = np.cos(np.float32(np.pi / 3)) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + 
+@pytest.mark.parametrize("dtype", _all_dtypes) +def test_cos_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = np.pi / 6 + X[..., 1::2] = np.pi / 3 + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.cos(U, order=ord) + expected_Y = np.cos(dpt.asnumpy(U)) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isnan_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + assert dpt.isnan(X).dtype == dpt.bool + + +def test_isnan_output(): + q = get_queue_or_skip() + + Xnp = np.asarray(np.nan) + X = dpt.asarray(np.nan, sycl_queue=q) + assert dpt.asnumpy(dpt.isnan(X))[()] == np.isnan(Xnp) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isnan_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.nan, np.nan) + y2 = complex(1, np.nan) + y3 = complex(np.nan, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 123) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isnan(Y))[()], np.isnan(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isnan_floats(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isnan(Y))[()], np.isnan(Ynp)) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isnan_order(dtype): + q = get_queue_or_skip() + 
skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) + Y = dpt.isnan(U, order=ord) + expected_Y = np.full(Y.shape, False, dtype=Y.dtype) + assert np.allclose(dpt.asnumpy(Y), expected_Y) From d7b42382390a93745a417a2cd2558404ba05b8d4 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 9 May 2023 09:43:39 -0500 Subject: [PATCH 02/48] Fixed add_contig_impl and add_matrix_vector_broadcasting_contig_impl Corrected/added checks for validity of sub-groups reads/writes. Added -fno-approx-func flag to compile element-wise functions, as well as -fno-finite-math-only flag. Fixed test_cos_order test to account for NumPy using float16 for intermediate computations for inputs of type "i1", but CPU RT does not support float16. --- dpctl/tensor/CMakeLists.txt | 2 +- .../kernels/elementwise_functions/add.hpp | 39 ++++++++++++------- dpctl/tests/test_tensor_elementwise.py | 12 +++++- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 40bb7cd7da..1b3a59f02f 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -53,7 +53,7 @@ if (WIN32) endif() set_source_files_properties( ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - PROPERTIES COMPILE_OPTIONS "${_clang_prefx}-fno-approx-func") + PROPERTIES COMPILE_OPTIONS "${_clang_prefx}-fno-approx-func;${_clang_prefx}-fno-finite-math-only") target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) if(UNIX) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 
d3f1d2fd82..5f008128d6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -66,7 +66,8 @@ struct AddContigFunctor (ndit.get_group(0) * ndit.get_local_range(0) + sg.get_group_id()[0] * maxsgSize); - if (base + n_vecs * vec_sz < nelems_) { + if ((base + n_vecs * vec_sz * sgSize < nelems_) && + (sgSize == maxsgSize)) { using in_ptrT1 = sycl::multi_ptr; @@ -428,7 +429,8 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl( cgh.depends_on(make_padded_vec_ev); auto lwsRange = sycl::range<1>(lws); - size_t n_groups = (n0 * n1 + lws - 1) / lws; + size_t n_elems = n0 * n1; + size_t n_groups = (n_elems + lws - 1) / lws; auto gwsRange = sycl::range<1>(n_groups * lws); cgh.parallel_for>( @@ -438,24 +440,31 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl( auto sg = ndit.get_sub_group(); size_t gid = ndit.get_global_linear_id(); + std::uint8_t sgSize = sg.get_local_range()[0]; size_t base = gid - sg.get_local_id()[0]; - using in_ptrT1 = - sycl::multi_ptr; - using in_ptrT2 = - sycl::multi_ptr; - using res_ptrT = - sycl::multi_ptr; + if (base + sgSize < n_elems) { + using in_ptrT1 = sycl::multi_ptr< + const argT1, sycl::access::address_space::global_space>; + using in_ptrT2 = sycl::multi_ptr< + const argT2, sycl::access::address_space::global_space>; + using res_ptrT = sycl::multi_ptr< + resT, sycl::access::address_space::global_space>; - const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); - const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1])); + const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); + const argT2 vec_el = + sg.load(in_ptrT2(&padded_vec[base % n1])); - resT res_el = mat_el + vec_el; + resT res_el = mat_el + vec_el; - sg.store(res_ptrT(&res[base]), res_el); + sg.store(res_ptrT(&res[base]), res_el); + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < n_elems; + k += sgSize) { + res[k] = mat[k] + padded_vec[k % n1]; + } + } } 
); }); diff --git a/dpctl/tests/test_tensor_elementwise.py b/dpctl/tests/test_tensor_elementwise.py index 4936574843..e81c24fecb 100644 --- a/dpctl/tests/test_tensor_elementwise.py +++ b/dpctl/tests/test_tensor_elementwise.py @@ -290,7 +290,9 @@ def test_cos_usm_type(usm_type): expected_Y = np.empty(input_shape, dtype=arg_dt) expected_Y[..., 0::2] = np.cos(np.float32(np.pi / 6)) expected_Y[..., 1::2] = np.cos(np.float32(np.pi / 3)) - assert np.allclose(dpt.asnumpy(Y), expected_Y) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol) @pytest.mark.parametrize("dtype", _all_dtypes) @@ -309,7 +311,13 @@ def test_cos_order(dtype): U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) Y = dpt.cos(U, order=ord) expected_Y = np.cos(dpt.asnumpy(U)) - assert np.allclose(dpt.asnumpy(Y), expected_Y) + tol = 8 * max( + dpt.finfo(Y.dtype).resolution, + np.finfo(expected_Y.dtype).resolution, + ) + np.testing.assert_allclose( + dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol + ) @pytest.mark.parametrize("dtype", _all_dtypes) From 70615d1573b06d5f7ec119637dcb50aa2f4a8f65 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 9 May 2023 10:55:26 -0500 Subject: [PATCH 03/48] Fixed typo in CMakeLists.txt that caused -fno-approx-func to be ignored --- dpctl/tensor/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 1b3a59f02f..56cd33b6b2 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -53,7 +53,7 @@ if (WIN32) endif() set_source_files_properties( ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - PROPERTIES COMPILE_OPTIONS "${_clang_prefx}-fno-approx-func;${_clang_prefx}-fno-finite-math-only") + PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-approx-func;${_clang_prefix}-fno-finite-math-only") target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) if(UNIX) From a947b3a11995a37c073960989891fdbcffd46b18 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 10 May 2023 13:37:52 -0500 Subject: [PATCH 04/48] Fixing typos discovered by added tests --- dpctl/tensor/_elementwise_common.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index b0e147c436..f878edb32f 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -191,17 +191,21 @@ def _weak_type_num_kind(o): return _map["i"] if isinstance(o, WeakInexactType): return _map["f"] - raise TypeError + raise TypeError( + f"Unexpected type {o} while expecting " + "`WeakBooleanType`, `WeakIntegralType`, or " + "`WeakInexactType`." + ) def _strong_dtype_num_kind(o): - _map = {"?": 0, "i": 1, "u": 1, "f": 2, "c": 2} + _map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 2} if not isinstance(o, dpt.dtype): raise TypeError k = o.kind if k in _map: return _map[k] - raise ValueError + raise ValueError(f"Unrecognized kind {k} for dtype {o}") def _resolve_weak_types(o1_dtype, o2_dtype, dev): From 872f372e2dd23990f1994ac9b39258a4fc49d225 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 10 May 2023 16:46:37 -0500 Subject: [PATCH 05/48] Added tests to increase coverage of elementwise functions --- dpctl/tensor/_elementwise_common.py | 68 +++---- dpctl/tests/test_tensor_elementwise.py | 234 ++++++++++++++++++++++--- 2 files changed, 230 insertions(+), 72 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index f878edb32f..f8cb197dc7 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -49,9 +49,7 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs): def __call__(self, x, order="K"): if not isinstance(x, dpt.usm_ndarray): - raise 
TypeError( - f"Expected :class:`dpctl.tensor.usm_ndarray`, got {type(x)}" - ) + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") if order not in ["C", "F", "K", "A"]: order = "K" buf_dt, res_dt = _find_buf_dtype( @@ -85,8 +83,6 @@ def __call__(self, x, order="K"): if order == "K": r = _empty_like_orderK(buf, res_dt) else: - if order == "A": - order = "F" if buf.flags.f_contiguous else "C" r = dpt.empty_like(buf, dtype=res_dt, order=order) ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev]) @@ -142,6 +138,8 @@ def get(self): def _get_dtype(o, dev): if isinstance(o, dpt.usm_ndarray): return o.dtype + if hasattr(o, "__sycl_usm_array_interface__"): + return dpt.asarray(o).dtype if _is_buffer(o): host_dt = np.array(o).dtype dev_dt = _to_device_supported_dtype(host_dt, dev) @@ -224,13 +222,12 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): return dpt.bool, o2_dtype if isinstance(o1_dtype, WeakIntegralType): return dpt.int64, o2_dtype - if isinstance(o1_dtype, WeakInexactType): - if isinstance(o1_dtype.get(), complex): - return ( - _to_device_supported_dtype(dpt.complex128, dev), - o2_dtype, - ) - return _to_device_supported_dtype(dpt.float64, dev), o2_dtype + if isinstance(o1_dtype.get(), complex): + return ( + _to_device_supported_dtype(dpt.complex128, dev), + o2_dtype, + ) + return _to_device_supported_dtype(dpt.float64, dev), o2_dtype else: return o2_dtype, o2_dtype elif isinstance( @@ -243,15 +240,12 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): return o1_dtype, dpt.bool if isinstance(o2_dtype, WeakIntegralType): return o1_dtype, dpt.int64 - if isinstance(o2_dtype, WeakInexactType): - if isinstance(o2_dtype.get(), complex): - return o1_dtype, _to_device_supported_dtype( - dpt.complex128, dev - ) - return ( - o1_dtype, - _to_device_supported_dtype(dpt.float64, dev), - ) + if isinstance(o2_dtype.get(), complex): + return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev) + return ( + o1_dtype, + 
_to_device_supported_dtype(dpt.float64, dev), + ) else: return o1_dtype, o1_dtype else: @@ -287,10 +281,14 @@ def __repr__(self): return f"" def __call__(self, o1, o2, order="K"): + if order not in ["K", "C", "F", "A"]: + order = "K" q1, o1_usm_type = _get_queue_usm_type(o1) q2, o2_usm_type = _get_queue_usm_type(o2) if q1 is None and q2 is None: - raise ValueError( + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " "One of the arguments must represent USM allocation and " "expose `__sycl_usm_array_interface__` property" ) @@ -415,18 +413,6 @@ def __call__(self, o1, o2, order="K"): src1, buf2, res_dt, res_usm_type, exec_q ) else: - if order == "A": - order = ( - "F" - if all( - arr.flags.f_contiguous - for arr in ( - src1, - buf2, - ) - ) - else "C" - ) r = dpt.empty( res_shape, dtype=res_dt, @@ -461,18 +447,6 @@ def __call__(self, o1, o2, order="K"): buf1, src2, res_dt, res_usm_type, exec_q ) else: - if order == "A": - order = ( - "F" - if all( - arr.flags.f_contiguous - for arr in ( - buf1, - src2, - ) - ) - else "C" - ) r = dpt.empty( res_shape, dtype=res_dt, @@ -493,7 +467,7 @@ def __call__(self, o1, o2, order="K"): ht_.wait() return r - if order in "KA": + if order in ["K", "A"]: if src1.flags.f_contiguous and src2.flags.f_contiguous: order = "F" else: diff --git a/dpctl/tests/test_tensor_elementwise.py b/dpctl/tests/test_tensor_elementwise.py index e81c24fecb..85daf797f0 100644 --- a/dpctl/tests/test_tensor_elementwise.py +++ b/dpctl/tests/test_tensor_elementwise.py @@ -1,3 +1,4 @@ +import ctypes import itertools import numpy as np @@ -5,6 +6,8 @@ import dpctl import dpctl.tensor as dpt +import dpctl.tensor._type_utils as tu +import dpctl.utils from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported _all_dtypes = [ @@ -26,6 +29,157 @@ _usm_types = ["device", "shared", "host"] +class MockDevice: + def __init__(self, fp16: bool, fp64: bool): + self.has_aspect_fp16 = fp16 + 
self.has_aspect_fp64 = fp64 + + +def _map_to_device_dtype(dt, dev): + return tu._to_device_supported_dtype(dt, dev) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_type_utils_map_to_device_type(dtype): + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + dev = MockDevice(fp16, fp64) + dt_in = dpt.dtype(dtype) + dt_out = _map_to_device_dtype(dt_in, dev) + assert isinstance(dt_out, dpt.dtype) + + +def test_type_util_all_data_types(): + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + r = tu._all_data_types(fp16, fp64) + assert isinstance(r, list) + # 11: bool + 4 signed + 4 unsigned inegral + float32 + complex64 + assert len(r) == 11 + int(fp16) + 2 * int(fp64) + + +def test_type_util_can_cast(): + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + for from_ in _all_dtypes: + for to_ in _all_dtypes: + r = tu._can_cast( + dpt.dtype(from_), dpt.dtype(to_), fp16, fp64 + ) + assert isinstance(r, bool) + + +def test_type_utils_empty_like_orderK(): + try: + a = dpt.empty((10, 10), dtype=dpt.int32, order="F") + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + X = tu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device) + assert X.flags["F"] + + +def test_type_utils_empty_like_orderK_invalid_args(): + with pytest.raises(TypeError): + tu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None) + with pytest.raises(TypeError): + tu._empty_like_pair_orderK( + [1, 2, 3], + ( + 1, + 2, + 3, + ), + dpt.int32, + "device", + None, + ) + try: + a = dpt.empty(10, dtype=dpt.int32) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + with pytest.raises(TypeError): + tu._empty_like_pair_orderK( + a, + ( + 1, + 2, + 3, + ), + dpt.int32, + "device", + None, + ) + + +def test_type_utils_find_buf_dtype(): + def _denier_fn(dt): + return False + + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + dev = MockDevice(fp16, fp64) + arg_dt = dpt.float64 
+ r = tu._find_buf_dtype(arg_dt, _denier_fn, dev) + assert r == ( + None, + None, + ) + + +def test_type_utils_find_buf_dtype2(): + def _denier_fn(dt1, dt2): + return False + + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + dev = MockDevice(fp16, fp64) + arg1_dt = dpt.float64 + arg2_dt = dpt.complex64 + r = tu._find_buf_dtype2(arg1_dt, arg2_dt, _denier_fn, dev) + assert r == ( + None, + None, + None, + ) + + +def test_unary_func_arg_validation(): + with pytest.raises(TypeError): + dpt.abs([1, 2, 3]) + try: + a = dpt.arange(8) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + dpt.abs(a, order="invalid") + + +def test_binary_func_arg_vaidation(): + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.add([1, 2, 3], 1) + try: + a = dpt.arange(8) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + with pytest.raises(ValueError): + dpt.add(a, Ellipsis) + dpt.add(a, a, order="invalid") + + @pytest.mark.parametrize("dtype", _all_dtypes) def test_abs_out_type(dtype): q = get_queue_or_skip() @@ -111,15 +265,7 @@ def test_abs_complex(dtype): def _compare_dtypes(dt, ref_dt, sycl_queue=None): assert isinstance(sycl_queue, dpctl.SyclQueue) dev = sycl_queue.sycl_device - expected_dt = ref_dt - if not dev.has_aspect_fp64: - if expected_dt == dpt.float64: - expected_dt = dpt.float32 - elif expected_dt == dpt.complex128: - expected_dt = dpt.complex64 - if not dev.has_aspect_fp16: - if expected_dt == dpt.float16: - expected_dt = dpt.float32 + expected_dt = _map_to_device_dtype(ref_dt, dev) return dt == expected_dt @@ -224,22 +370,60 @@ def test_add_broadcasting(): assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() -def _map_to_device_dtype(dt, dev): - if np.issubdtype(dt, np.integer): - return dt - if np.issubdtype(dt, np.floating): - dtc = np.dtype(dt).char - if dtc == "d": - return dt if dev.has_aspect_fp64 else dpt.float32 - elif dtc == "e": - return dt if 
dev.has_aspect_fp16 else dpt.float32 - return dt - if np.issubdtype(dt, np.complexfloating): - dtc = np.dtype(dt).char - if dtc == "D": - return dt if dev.has_aspect_fp64 else dpt.complex64 - return dt - return dt +@pytest.mark.parametrize("arr_dt", _all_dtypes) +def test_add_python_scalar(arr_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) + + X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q) + py_zeros = ( + bool(0), + int(0), + float(0), + complex(0), + np.float32(0), + ctypes.c_int(0), + ) + for sc in py_zeros: + R = dpt.add(X, sc) + assert isinstance(R, dpt.usm_ndarray) + R = dpt.add(sc, X) + assert isinstance(R, dpt.usm_ndarray) + + +class MockArray: + def __init__(self, arr): + self.data_ = arr + + @property + def __sycl_usm_array_interface__(self): + return self.data_.__sycl_usm_array_interface__ + + +def test_add_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + b = dpt.ones(10) + c = MockArray(b) + r = dpt.add(a, c) + assert isinstance(r, dpt.usm_ndarray) + + +def test_add_canary_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + + class Canary: + def __init__(self): + pass + + @property + def __sycl_usm_array_interface__(self): + return None + + c = Canary() + with pytest.raises(ValueError): + dpt.add(a, c) @pytest.mark.parametrize("dtype", _all_dtypes) From 6d71e46db345c00a622c57016ec7211ada6a1067 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 11 May 2023 16:12:24 -0500 Subject: [PATCH 06/48] Implemented isinf, isfinite, reused templates to define Contig and Strided unary functors Added tests, --- dpctl/tensor/__init__.py | 4 +- dpctl/tensor/_elementwise_funcs.py | 22 +- .../kernels/elementwise_functions/abs.hpp | 156 ++---------- .../kernels/elementwise_functions/common.hpp | 231 ++++++++++++++++++ .../kernels/elementwise_functions/cos.hpp | 161 ++++-------- .../elementwise_functions/isfinite.hpp | 209 ++++++++++++++++ .../kernels/elementwise_functions/isinf.hpp | 207 ++++++++++++++++ 
.../kernels/elementwise_functions/isnan.hpp | 70 +++++- .../source/elementwise_functions.cpp | 177 ++++++++++++-- dpctl/tests/test_tensor_elementwise.py | 138 ++++++++++- 10 files changed, 1091 insertions(+), 284 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 3112e9c7b6..48c8c045a8 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -90,7 +90,7 @@ from dpctl.tensor._usmarray import usm_ndarray from ._constants import e, inf, nan, newaxis, pi -from ._elementwise_funcs import abs, add, cos, isnan +from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan __all__ = [ "Device", @@ -168,5 +168,7 @@ "abs", "add", "cos", + "isinf", "isnan", + "isfinite", ] diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 5113ad0540..566fea9658 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -42,12 +42,32 @@ cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring) +# ISFINITE + +_isfinite_docstring_ = """ +Computes if every element of input array is a finite number. +""" + +isfinite = UnaryElementwiseFunc( + "isfinite", ti._isfinite_result_type, ti._isfinite, _isfinite_docstring_ +) + # ISNAN _isnan_docstring_ = """ -Computes if ever element of input array is a NaN. +Computes if every element of input array is a NaN. """ isnan = UnaryElementwiseFunc( "isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring_ ) + +# ISINF + +_isinf_docstring_ = """ +Computes if every element of input array is an infinity. 
+""" + +isinf = UnaryElementwiseFunc( + "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_ +) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index f7221cddb6..a95cf78d69 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -1,11 +1,15 @@ #pragma once #include +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" #include +#include + namespace dpctl { namespace tensor @@ -18,120 +22,40 @@ namespace abs namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; -template -struct AbsContigFunctor +using dpctl::tensor::type_utils::is_complex; + +template struct AbsFunctor { -private: - const argT *in = nullptr; - resT *out = nullptr; - const size_t nelems_; -public: - AbsContigFunctor(const argT *inp, resT *res, const size_t n_elems) - : in(inp), out(res), nelems_(n_elems) - { - } + using is_constant = typename std::false_type; + // constexpr resT constant_value = resT{}; + using supports_vec = typename std::false_type; + using supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; - void operator()(sycl::nd_item<1> ndit) const + resT operator()(const argT &x) { - /* Each work-item processes vec_sz elements, contiguous in memory */ - /* NOTE: vec_sz must divide sg.max_local_range()[0] */ if constexpr (std::is_same_v || (std::is_integral::value && std::is_unsigned::value)) { static_assert(std::is_same_v); - - auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t max_sgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * max_sgSize); - - if (base + n_vecs * vec_sz * sgSize < 
nelems_ && - sgSize == max_sgSize) { - using in_ptrT = - sycl::multi_ptr; - using out_ptrT = - sycl::multi_ptr; - sycl::vec arg_vec; - -#pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - arg_vec = sg.load(in_ptrT(&in[base + it * sgSize])); - sg.store(out_ptrT(&out[base + it * sgSize]), - arg_vec); - } - } - else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) { - out[k] = in[k]; - } - } + return x; } else { - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + - (base % sgSize); - for (size_t offset = base; - offset < - std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { - out[offset] = std::abs(in[offset]); - } - } - else { - auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * maxsgSize); - - if (base + n_vecs * vec_sz < nelems_) { - using in_ptrT = sycl::multi_ptr< - const argT, sycl::access::address_space::global_space>; - using out_ptrT = sycl::multi_ptr< - resT, sycl::access::address_space::global_space>; - sycl::vec arg_vec; - -#pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; - it += vec_sz) { - arg_vec = - sg.load(in_ptrT(&in[base + it * sgSize])); -#pragma unroll - for (std::uint8_t k = 0; k < vec_sz; ++k) { - arg_vec[k] = std::abs(arg_vec[k]); - } - sg.store(out_ptrT(&out[base + it * sgSize]), - arg_vec); - } - } - else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) { - out[k] = std::abs(in[k]); - } - } - } + return std::abs(x); } } }; +template +using AbsContigFunctor = elementwise_common:: + UnaryContigFunctor, vec_sz, n_vecs>; + template 
struct AbsOutputType { using value_type = typename std::disjunction< // disjunction is C++17 @@ -220,39 +144,9 @@ template struct AbsTypeMapFactory } }; -template -struct AbsStridedFunctor -{ -private: - const argT *in = nullptr; - resT *out = nullptr; - IndexerT inp_res_indexer_; - -public: - AbsStridedFunctor(const argT *inp_p, - resT *res_p, - IndexerT two_offsets_indexer) - : in(inp_p), out(res_p), inp_res_indexer_(two_offsets_indexer) - { - } - - void operator()(sycl::id<1> wid) const - { - auto offsets_ = inp_res_indexer_(static_cast(wid[0])); - const auto &inp_offset = offsets_.get_first_offset(); - const auto &out_offset = offsets_.get_second_offset(); - - if constexpr (std::is_same_v || - (std::is_integral::value && - std::is_unsigned::value)) - { - out[out_offset] = in[inp_offset]; - } - else { - out[out_offset] = std::abs(in[inp_offset]); - } - } -}; +template +using AbsStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; template class abs_strided_kernel; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index e69de29bb2..a69cac19f4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -0,0 +1,231 @@ +#pragma once +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace elementwise_common +{ + +/*! 
@brief Functor for unary function evaluation on contiguous array */ +template +struct UnaryContigFunctor +{ +private: + const argT *in = nullptr; + resT *out = nullptr; + const size_t nelems_; + +public: + UnaryContigFunctor(const argT *inp, resT *res, const size_t n_elems) + : in(inp), out(res), nelems_(n_elems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + UnaryOperatorT op{}; + /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NOTE: vec_sz must divide sg.max_local_range()[0] */ + + if constexpr (UnaryOperatorT::is_constant::value) { + // value of operator is known to be a known constant + constexpr resT const_val = UnaryOperatorT::constant_value; + using out_ptrT = + sycl::multi_ptr; + + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t max_sgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + if (base + n_vecs * vec_sz * sgSize < nelems_ && + max_sgSize == sgSize) { + sycl::vec res_vec(const_val); +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = const_val; + } + } + } + else if constexpr (UnaryOperatorT::supports_sg_loadstore::value && + UnaryOperatorT::supports_vec::value) + { + using in_ptrT = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + + auto sg = ndit.get_sub_group(); + std::uint16_t sgSize = sg.get_local_range()[0]; + std::uint16_t max_sgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * max_sgSize); + if (base + n_vecs * vec_sz * sgSize < nelems_ && + sgSize == max_sgSize) { + sycl::vec x; + +#pragma unroll + for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += 
vec_sz) { + x = sg.load(in_ptrT(&in[base + it * sgSize])); + sycl::vec res_vec = op(x); + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + // scalar call + out[k] = op(in[k]); + } + } + } + else if constexpr (UnaryOperatorT::supports_sg_loadstore::value && + std::is_same_v) + { + // default: use scalar-value function + + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * maxsgSize); + + if ((base + n_vecs * vec_sz * sgSize < nelems_) && + (maxsgSize == sgSize)) { + using in_ptrT = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + sycl::vec arg_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + arg_vec = sg.load(in_ptrT(&in[base + it * sgSize])); +#pragma unroll + for (std::uint8_t k = 0; k < vec_sz; ++k) { + arg_vec[k] = op(arg_vec[k]); + } + sg.store(out_ptrT(&out[base + it * sgSize]), + arg_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = op(in[k]); + } + } + } + else if constexpr (UnaryOperatorT::supports_sg_loadstore::value) { + // default: use scalar-value function + + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * maxsgSize); + + if ((base + n_vecs * vec_sz * sgSize < nelems_) && + (maxsgSize == sgSize)) { + using in_ptrT = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + sycl::vec arg_vec; + sycl::vec res_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + arg_vec = sg.load(in_ptrT(&in[base + it * sgSize])); 
+#pragma unroll + for (std::uint8_t k = 0; k < vec_sz; ++k) { + res_vec[k] = op(arg_vec[k]); + } + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = op(in[k]); + } + } + } + else { + std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); + for (size_t offset = base; + offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); + offset += sgSize) + { + out[offset] = op(in[offset]); + } + } + } +}; + +template +struct UnaryStridedFunctor +{ +private: + const argT *inp_ = nullptr; + resT *res_ = nullptr; + IndexerT inp_out_indexer_; + +public: + UnaryStridedFunctor(const argT *inp_p, + resT *res_p, + IndexerT inp_out_indexer) + : inp_(inp_p), res_(res_p), inp_out_indexer_(inp_out_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + const argT *const &in = inp_; + resT *const &out = res_; + + auto offsets_ = inp_out_indexer_(wid.get(0)); + const ssize_t &inp_offset = offsets_.get_first_offset(); + const ssize_t &out_offset = offsets_.get_second_offset(); + + UnaryOpT op{}; + + out[out_offset] = op(in[inp_offset]); + } +}; + +} // namespace elementwise_common +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 97f0fd1f26..e036505ce9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -2,6 +2,8 @@ #include #include +#include "kernels/elementwise_functions/common.hpp" + #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" @@ -19,88 +21,56 @@ namespace cos namespace py = pybind11; namespace 
td_ns = dpctl::tensor::type_dispatch; -template -struct CosContigFunctor -{ -private: - const argT *in = nullptr; - resT *out = nullptr; - const size_t nelems_; +using dpctl::tensor::type_utils::is_complex; -public: - CosContigFunctor(const argT *inp, resT *res, const size_t nelems) - : in(inp), out(res), nelems_(nelems) - { - } +template struct CosFunctor +{ - void operator()(sycl::nd_item<1> ndit) const + // is function constant for given argT + using is_constant = typename std::false_type; + // constant value, if constant + // constexpr resT constant_value = resT{}; + // is function defined for sycl::vec + using supports_vec = typename std::false_type; + // do both argTy and resTy support sugroup store/load operation + using supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; + + resT operator()(const argT &in) { - auto sg = ndit.get_sub_group(); - - using dpctl::tensor::type_utils::is_complex; if constexpr (is_complex::value) { - std::uint8_t sgSize = sg.get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(base + sgSize * (n_vecs * vec_sz), nelems_); - offset += sgSize) - { - using realT = typename argT::value_type; - // cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y) - auto v = std::real(in[offset]); - realT cosX_val; - const realT sinX_val = sycl::sincos(-v, &cosX_val); - v = std::imag(in[offset]); - const realT sinhY_val = sycl::sinh(v); - const realT coshY_val = sycl::cosh(v); + using realT = typename argT::value_type; + // cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y) + auto v = std::real(in); + realT cosX_val; + const realT sinX_val = sycl::sincos(-v, &cosX_val); + v = std::imag(in); + const realT sinhY_val = sycl::sinh(v); + const realT coshY_val = sycl::cosh(v); - const realT res_re = coshY_val * cosX_val; - const realT res_im = sinX_val * sinhY_val; - out[offset] = resT{res_re, 
res_im}; - } + const realT res_re = coshY_val * cosX_val; + const realT res_im = sinX_val * sinhY_val; + return resT{res_re, res_im}; } else { - using dpctl::tensor::type_utils::vec_cast; - - std::uint8_t sgSize = sg.get_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * sgSize); - if (base + n_vecs * vec_sz * sg.get_max_local_range()[0] < nelems_) - { - using in_ptrT = - sycl::multi_ptr; - using out_ptrT = - sycl::multi_ptr; - -#pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - sycl::vec x = - sg.load(in_ptrT(&in[base + it * sgSize])); - - sycl::vec res_vec = sycl::cos( - vec_cast(x)); - sg.store(out_ptrT(&out[base + it * sgSize]), - res_vec); - } - } - else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) - out[k] = sycl::cos(static_cast(in[k])); - } + static_assert(std::is_floating_point_v || + std::is_same_v); + return std::cos(in); } } }; +template +using CosContigFunctor = elementwise_common:: + UnaryContigFunctor, vec_sz, n_vecs>; + +template +using CosStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; + template struct CosOutputType { using value_type = typename std::disjunction< // disjunction is C++17 @@ -180,49 +150,6 @@ template struct CosTypeMapFactory } }; -template -struct CosStridedFunctor -{ -private: - const argT *in = nullptr; - resT *out = nullptr; - IndexerT inp_out_indexer_; - -public: - CosStridedFunctor(const argT *inp_tp, - resT *res_tp, - IndexerT arg_res_indexer) - : in(inp_tp), out(res_tp), inp_out_indexer_(arg_res_indexer) - { - } - - void operator()(sycl::id<1> wid) const - { - auto offsets_ = inp_out_indexer_(static_cast(wid.get(0))); - const py::ssize_t &inp_offset = offsets_.get_first_offset(); - const py::ssize_t &out_offset = offsets_.get_second_offset(); - - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - using realT = typename argT::value_type; - // 
cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y) - auto v = std::real(in[inp_offset]); - realT cosX_val; - const realT sinX_val = sycl::sincos(-v, &cosX_val); - v = std::imag(in[inp_offset]); - const realT sinhY_val = sycl::sinh(v); - const realT coshY_val = sycl::cosh(v); - - const realT res_re = coshY_val * cosX_val; - const realT res_im = sinX_val * sinhY_val; - out[out_offset] = resT{res_re, res_im}; - } - else { - out[out_offset] = std::cos(static_cast(in[inp_offset])); - } - } -}; - template class cos_strided_kernel; typedef sycl::event (*cos_strided_impl_fn_ptr_t)( @@ -262,9 +189,11 @@ sycl::event cos_strided_impl(sycl::queue exec_q, const argTy *arg_tp = reinterpret_cast(arg_p); resTy *res_tp = reinterpret_cast(res_p); + sycl::range<1> gRange{nelems}; + cgh.parallel_for>( - {nelems}, CosStridedFunctor( - arg_tp, res_tp, arg_res_indexer)); + gRange, CosStridedFunctor(arg_tp, res_tp, + arg_res_indexer)); }); return comp_ev; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp new file mode 100644 index 0000000000..f5a9ec0527 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -0,0 +1,209 @@ +#pragma once +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace isfinite +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; + +template struct IsFiniteFunctor +{ + static_assert(std::is_same_v); + + /* + std::is_same::value || + std::is_integral::value + */ + using is_constant = typename std::disjunction, + std::is_integral>; + static constexpr resT constant_value = true; + using supports_vec = typename std::false_type; + using 
supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; + + resT operator()(const argT &in) const + { + if constexpr (is_complex::value) { + const bool real_isfinite = std::isfinite(std::real(in)); + const bool imag_isfinite = std::isfinite(std::imag(in)); + return (real_isfinite && imag_isfinite); + } + else if constexpr (std::is_same::value || + std::is_integral::value) + { + return constant_value; + } + else if constexpr (std::is_same_v) { + return sycl::isfinite(in); + } + else { + return std::isfinite(in); + } + } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::isfinite(in); + + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + + return vec_cast(res_vec); + } +}; + +template +using IsFiniteContigFunctor = elementwise_common:: + UnaryContigFunctor, vec_sz, n_vecs>; + +template +using IsFiniteStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; + +template struct IsFiniteOutputType +{ + using value_type = bool; +}; + +typedef sycl::event (*isfinite_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const std::vector &); + +template +class isfinite_contig_kernel; + +template +sycl::event isfinite_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + sycl::event isfinite_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + constexpr size_t lws = 64; + constexpr std::uint8_t vec_sz = 4; + constexpr std::uint8_t n_vecs = 2; + static_assert(lws % vec_sz == 0); + size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + auto gws_range = sycl::range<1>(n_groups * lws); + auto lws_range = sycl::range<1>(lws); + + using resTy = typename IsFiniteOutputType::value_type; + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + class 
isfinite_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + IsFiniteContigFunctor(arg_tp, res_tp, + nelems)); + }); + return isfinite_ev; +} + +template struct IsFiniteContigFactory +{ + fnT get() + { + fnT fn = isfinite_contig_impl; + return fn; + } +}; + +template struct IsFiniteTypeMapFactory +{ + /*! @brief get typeid for output type of sycl::isfinite(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename IsFiniteOutputType::value_type; + ; + return td_ns::GetTypeid{}.get(); + } +}; + +template class isfinite_strided_kernel; + +typedef sycl::event (*isfinite_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event +isfinite_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event isfinite_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename IsFiniteOutputType::value_type; + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + IndexerT arg_res_indexer{nd, arg_offset, res_offset, shape_and_strides}; + + const argTy *arg_tptr = reinterpret_cast(arg_p); + resTy *res_tptr = reinterpret_cast(res_p); + + cgh.parallel_for>( + {nelems}, IsFiniteStridedFunctor( + arg_tptr, res_tptr, arg_res_indexer)); + }); + return isfinite_ev; +} + +template struct IsFiniteStridedFactory +{ + fnT get() + { + fnT fn = isfinite_strided_impl; + return fn; + } +}; + +} // namespace isfinite +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp 
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp new file mode 100644 index 0000000000..656f867f5d --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -0,0 +1,207 @@ +#pragma once +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace isinf +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; + +template struct IsInfFunctor +{ + static_assert(std::is_same_v); + + using is_constant = typename std::disjunction, + std::is_integral>; + static constexpr resT constant_value = false; + using supports_vec = typename std::false_type; + using supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; + + resT operator()(const argT &in) const + { + if constexpr (is_complex::value) { + const bool real_isinf = std::isinf(std::real(in)); + const bool imag_isinf = std::isinf(std::imag(in)); + return (real_isinf || imag_isinf); + } + else if constexpr (std::is_same::value || + std::is_integral::value) + { + return constant_value; + } + else if constexpr (std::is_same_v) { + return sycl::isinf(in); + } + else { + return std::isinf(in); + } + } + + // unused (since support_vec is set to false_type) due to bug in sycl::isinf + // implementation in OpenCL CPU RT + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::isinf(in); + + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + + return vec_cast(res_vec); + } +}; + +template +using IsInfContigFunctor = elementwise_common:: + UnaryContigFunctor, vec_sz, n_vecs>; + +template +using IsInfStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; + +template struct IsInfOutputType +{ + using value_type = 
bool; +}; + +typedef sycl::event (*isinf_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const std::vector &); + +template +class isinf_contig_kernel; + +template +sycl::event isinf_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + sycl::event isinf_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + constexpr size_t lws = 64; + constexpr std::uint8_t vec_sz = 4; + constexpr std::uint8_t n_vecs = 2; + static_assert(lws % vec_sz == 0); + auto gws_range = sycl::range<1>( + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)) * + lws); + auto lws_range = sycl::range<1>(lws); + + using resTy = typename IsInfOutputType::value_type; + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + class isinf_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + IsInfContigFunctor(arg_tp, res_tp, + nelems)); + }); + return isinf_ev; +} + +template struct IsInfContigFactory +{ + fnT get() + { + fnT fn = isinf_contig_impl; + return fn; + } +}; + +template struct IsInfTypeMapFactory +{ + /*! 
@brief get typeid for output type of sycl::isinf(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename IsInfOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template class isinf_strided_kernel; + +typedef sycl::event (*isinf_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event +isinf_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename IsInfOutputType::value_type; + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + IndexerT arg_res_indexer{nd, arg_offset, res_offset, shape_and_strides}; + + const argTy *arg_tptr = reinterpret_cast(arg_p); + resTy *res_tptr = reinterpret_cast(res_p); + + sycl::range<1> gRange{nelems}; + + cgh.parallel_for>( + gRange, IsInfStridedFunctor( + arg_tptr, res_tptr, arg_res_indexer)); + }); + return abs_ev; +} + +template struct IsInfStridedFactory +{ + fnT get() + { + fnT fn = isinf_strided_impl; + return fn; + } +}; + +} // namespace isinf +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index aa7abecb98..3d72b6e707 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -19,11 +19,69 @@ namespace isnan namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; 
+using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::vec_cast; + +template struct IsNanFunctor +{ + static_assert(std::is_same_v); + + /* + std::is_same::value || + std::is_integral::value + */ + using is_constant = typename std::disjunction, + std::is_integral>; + static constexpr resT constant_value = false; + using supports_vec = typename std::true_type; + using supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; + + resT operator()(const argT &in) const + { + if constexpr (is_complex::value) { + const bool real_isnan = sycl::isnan(std::real(in)); + const bool imag_isnan = sycl::isnan(std::imag(in)); + return (real_isnan || imag_isnan); + } + else if constexpr (std::is_same::value || + std::is_integral::value) + { + return constant_value; + } + else { + return sycl::isnan(in); + } + } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::isnan(in); + + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + + return vec_cast(res_vec); + } +}; + +template +using IsNanContigFunctor = elementwise_common:: + UnaryContigFunctor, vec_sz, n_vecs>; + +template +using IsNanStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; + template -struct IsNanContigFunctor +struct IsNanContigFunctorOld { private: const argT *in = nullptr; @@ -31,7 +89,7 @@ struct IsNanContigFunctor const size_t nelems_; public: - IsNanContigFunctor(const argT *inp, resT *res, const size_t nelems) + IsNanContigFunctorOld(const argT *inp, resT *res, const size_t nelems) : in(inp), out(res), nelems_(nelems) { } @@ -193,7 +251,7 @@ template struct IsNanTypeMapFactory }; template -struct IsNanStridedFunctor +struct IsNanStridedFunctorOld { private: const argT *inp_ = nullptr; @@ -201,9 +259,9 @@ struct IsNanStridedFunctor IndexerT inp_out_indexer_; public: - IsNanStridedFunctor(const argT *inp_p, - resT *res_p, - IndexerT inp_out_indexer) + IsNanStridedFunctorOld(const 
argT *inp_p, + resT *res_p, + IndexerT inp_out_indexer) : inp_(inp_p), res_(res_p), inp_out_indexer_(inp_out_indexer) { } diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 8d7283bfaa..65b2b1d9a1 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -35,6 +35,8 @@ #include "kernels/elementwise_functions/abs.hpp" #include "kernels/elementwise_functions/add.hpp" #include "kernels/elementwise_functions/cos.hpp" +#include "kernels/elementwise_functions/isfinite.hpp" +#include "kernels/elementwise_functions/isinf.hpp" #include "kernels/elementwise_functions/isnan.hpp" namespace dpctl @@ -96,8 +98,9 @@ int _result_typeid(int arg_typeid, const int *fn_output_id) namespace impl { -using dpctl::tensor::kernels::abs::abs_contig_impl_fn_ptr_t; -using dpctl::tensor::kernels::abs::abs_strided_impl_fn_ptr_t; +namespace abs_fn_ns = dpctl::tensor::kernels::abs; +using abs_fn_ns::abs_contig_impl_fn_ptr_t; +using abs_fn_ns::abs_strided_impl_fn_ptr_t; static abs_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; static int abs_output_typeid_vector[td_ns::num_types]; @@ -106,31 +109,106 @@ static abs_strided_impl_fn_ptr_t abs_strided_dispatch_vector[td_ns::num_types]; void populate_abs_dispatch_vectors(void) { using namespace td_ns; + namespace fn_ns = abs_fn_ns; - using dpctl::tensor::kernels::abs::AbsContigFactory; + using fn_ns::AbsContigFactory; DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(abs_contig_dispatch_vector); - using dpctl::tensor::kernels::abs::AbsStridedFactory; + using fn_ns::AbsStridedFactory; DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(abs_strided_dispatch_vector); - using dpctl::tensor::kernels::abs::AbsTypeMapFactory; + using fn_ns::AbsTypeMapFactory; DispatchVectorBuilder dvb3; dvb3.populate_dispatch_vector(abs_output_typeid_vector); }; } // namespace impl -// 
ISNAN +// ISFINITE namespace impl { +namespace isfinite_fn_ns = dpctl::tensor::kernels::isfinite; +using isfinite_fn_ns::isfinite_contig_impl_fn_ptr_t; +using isfinite_fn_ns::isfinite_strided_impl_fn_ptr_t; + +static isfinite_contig_impl_fn_ptr_t + isfinite_contig_dispatch_vector[td_ns::num_types]; +static int isfinite_output_typeid_vector[td_ns::num_types]; +static isfinite_strided_impl_fn_ptr_t + isfinite_strided_dispatch_vector[td_ns::num_types]; + +void populate_isfinite_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = isfinite_fn_ns; + + using fn_ns::IsFiniteContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isfinite_contig_dispatch_vector); + + using fn_ns::IsFiniteStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isfinite_strided_dispatch_vector); -using dpctl::tensor::kernels::isnan::isnan_contig_impl_fn_ptr_t; -using dpctl::tensor::kernels::isnan::isnan_strided_impl_fn_ptr_t; + using fn_ns::IsFiniteTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(isfinite_output_typeid_vector); +} + +} // namespace impl + +// ISINF +namespace impl +{ +namespace isinf_fn_ns = dpctl::tensor::kernels::isinf; +using isinf_fn_ns::isinf_contig_impl_fn_ptr_t; +using isinf_fn_ns::isinf_strided_impl_fn_ptr_t; + +static isinf_contig_impl_fn_ptr_t + isinf_contig_dispatch_vector[td_ns::num_types]; +static int isinf_output_typeid_vector[td_ns::num_types]; +static isinf_strided_impl_fn_ptr_t + isinf_strided_dispatch_vector[td_ns::num_types]; + +void populate_isinf_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = isinf_fn_ns; + + using fn_ns::IsInfContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isinf_contig_dispatch_vector); + + using fn_ns::IsInfStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isinf_strided_dispatch_vector); + + using fn_ns::IsInfTypeMapFactory; + DispatchVectorBuilder dvb3; + 
dvb3.populate_dispatch_vector(isinf_output_typeid_vector); +} + +} // namespace impl + +// ISNAN +namespace impl +{ +namespace isnan_fn_ns = dpctl::tensor::kernels::isnan; +using isnan_fn_ns::isnan_contig_impl_fn_ptr_t; +using isnan_fn_ns::isnan_strided_impl_fn_ptr_t; static isnan_contig_impl_fn_ptr_t isnan_contig_dispatch_vector[td_ns::num_types]; @@ -141,20 +219,21 @@ static isnan_strided_impl_fn_ptr_t void populate_isnan_dispatch_vectors(void) { using namespace td_ns; + namespace fn_ns = isnan_fn_ns; - using dpctl::tensor::kernels::isnan::IsNanContigFactory; + using fn_ns::IsNanContigFactory; DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(isnan_contig_dispatch_vector); - using dpctl::tensor::kernels::isnan::IsNanStridedFactory; + using fn_ns::IsNanStridedFactory; DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(isnan_strided_dispatch_vector); - using dpctl::tensor::kernels::isnan::IsNanTypeMapFactory; + using fn_ns::IsNanTypeMapFactory; DispatchVectorBuilder dvb3; dvb3.populate_dispatch_vector(isnan_output_typeid_vector); } @@ -165,8 +244,9 @@ void populate_isnan_dispatch_vectors(void) namespace impl { -using dpctl::tensor::kernels::cos::cos_contig_impl_fn_ptr_t; -using dpctl::tensor::kernels::cos::cos_strided_impl_fn_ptr_t; +namespace cos_fn_ns = dpctl::tensor::kernels::cos; +using cos_fn_ns::cos_contig_impl_fn_ptr_t; +using cos_fn_ns::cos_strided_impl_fn_ptr_t; static cos_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; static int cos_output_typeid_vector[td_ns::num_types]; @@ -175,19 +255,20 @@ static cos_strided_impl_fn_ptr_t cos_strided_dispatch_vector[td_ns::num_types]; void populate_cos_dispatch_vectors(void) { using namespace td_ns; + namespace fn_ns = cos_fn_ns; - using dpctl::tensor::kernels::cos::CosContigFactory; + using fn_ns::CosContigFactory; DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); - using dpctl::tensor::kernels::cos::CosStridedFactory; + using 
fn_ns::CosStridedFactory; DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); - using dpctl::tensor::kernels::cos::CosTypeMapFactory; + using fn_ns::CosTypeMapFactory; DispatchVectorBuilder dvb3; dvb3.populate_dispatch_vector(cos_output_typeid_vector); } @@ -198,11 +279,11 @@ void populate_cos_dispatch_vectors(void) namespace impl { +namespace fn_ns = dpctl::tensor::kernels::add; -using dpctl::tensor::kernels::add::add_contig_impl_fn_ptr_t; -using dpctl::tensor::kernels::add:: - add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using dpctl::tensor::kernels::add::add_strided_impl_fn_ptr_t; +using fn_ns::add_contig_impl_fn_ptr_t; +using fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using fn_ns::add_strided_impl_fn_ptr_t; static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; @@ -219,22 +300,22 @@ void populate_add_dispatch_tables(void) { using namespace td_ns; - using dpctl::tensor::kernels::add::AddContigFactory; + using fn_ns::AddContigFactory; DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(add_contig_dispatch_table); - using dpctl::tensor::kernels::add::AddStridedFactory; + using fn_ns::AddStridedFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(add_strided_dispatch_table); - using dpctl::tensor::kernels::add::AddTypeMapFactory; + using fn_ns::AddTypeMapFactory; DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(add_output_id_table); - using dpctl::tensor::kernels::add::AddContigMatrixContigRowBroadcastFactory; + using fn_ns::AddContigMatrixContigRowBroadcastFactory; DispatchTableBuilder dtb4; @@ -275,6 +356,7 @@ void init_elementwise_functions(py::module_ m) // U02: ==== ACOS (x) // FIXME: + // U03: ===== ACOSH (x) // FIXME: @@ -404,10 +486,53 @@ void init_elementwise_functions(py::module_ m) // FIXME: // U17: ==== ISFINITE (x) - // FIXME: + { + impl::populate_isfinite_dispatch_vectors(); + + using impl::isfinite_contig_dispatch_vector; + 
using impl::isfinite_output_typeid_vector; + using impl::isfinite_strided_dispatch_vector; + auto isfinite_pyapi = + [&](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc(src, dst, exec_q, depends, + isfinite_output_typeid_vector, + isfinite_contig_dispatch_vector, + isfinite_strided_dispatch_vector); + }; + auto isfinite_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, + isfinite_output_typeid_vector); + }; + m.def("_isfinite", isfinite_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_isfinite_result_type", isfinite_result_type_pyapi, ""); + } // U18: ==== ISINF (x) - // FIXME: + { + impl::populate_isinf_dispatch_vectors(); + + using impl::isinf_contig_dispatch_vector; + using impl::isinf_output_typeid_vector; + using impl::isinf_strided_dispatch_vector; + auto isinf_pyapi = [&](dpctl::tensor::usm_ndarray src, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, isinf_output_typeid_vector, + isinf_contig_dispatch_vector, isinf_strided_dispatch_vector); + }; + auto isinf_result_type_pyapi = [&](py::dtype dtype) { + return py_unary_ufunc_result_type(dtype, + isinf_output_typeid_vector); + }; + m.def("_isinf", isinf_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_isinf_result_type", isinf_result_type_pyapi, ""); + } // U19: ==== ISNAN (x) { diff --git a/dpctl/tests/test_tensor_elementwise.py b/dpctl/tests/test_tensor_elementwise.py index 85daf797f0..1a5ecc0dd9 100644 --- a/dpctl/tests/test_tensor_elementwise.py +++ b/dpctl/tests/test_tensor_elementwise.py @@ -518,7 +518,7 @@ def test_isnan_output(): Xnp = np.asarray(np.nan) X = dpt.asarray(np.nan, sycl_queue=q) - assert dpt.asnumpy(dpt.isnan(X))[()] == np.isnan(Xnp) + 
assert dpt.asnumpy(dpt.isnan(X)) == np.isnan(Xnp) @pytest.mark.parametrize("dtype", ["c8", "c16"]) @@ -534,7 +534,7 @@ def test_isnan_complex(dtype): Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 123) Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isnan(Y))[()], np.isnan(Ynp)) + assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) @pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) @@ -549,7 +549,7 @@ def test_isnan_floats(dtype): for mult in [123, 137, 255, 271, 272]: Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isnan(Y))[()], np.isnan(Ynp)) + assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) @pytest.mark.parametrize("dtype", _all_dtypes) @@ -567,3 +567,135 @@ def test_isnan_order(dtype): Y = dpt.isnan(U, order=ord) expected_Y = np.full(Y.shape, False, dtype=Y.dtype) assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isfinite_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + assert dpt.isfinite(X).dtype == dpt.bool + + +def test_isfinite_output(): + q = get_queue_or_skip() + + Xnp = np.asarray(np.nan) + X = dpt.asarray(np.nan, sycl_queue=q) + assert dpt.asnumpy(dpt.isfinite(X)) == np.isfinite(Xnp) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isfinite_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.nan, np.nan) + y2 = complex(1, np.nan) + y3 = complex(np.nan, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 12) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isfinite_floats(dtype): + q = 
get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isfinite_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) + Y = dpt.isfinite(U, order=ord) + expected_Y = np.full(Y.shape, True, dtype=Y.dtype) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isinf_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + assert dpt.isinf(X).dtype == dpt.bool + + +def test_isinf_output(): + q = get_queue_or_skip() + + Xnp = np.asarray(np.inf) + X = dpt.asarray(np.inf, sycl_queue=q) + assert dpt.asnumpy(dpt.isinf(X)) == np.isinf(Xnp) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isinf_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.inf, np.inf) + y2 = complex(1, np.inf) + y3 = complex(np.inf, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + y6 = complex(np.inf, np.nan) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5, y6], dtype=dtype), 123) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isinf(Y)), np.isinf(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isinf_floats(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + y4 = -np.inf + + for mult in [123, 
137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3, y4], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isinf(Y)), np.isinf(Ynp)) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isinf_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) + Y = dpt.isinf(U, order=ord) + expected_Y = np.full(Y.shape, False, dtype=Y.dtype) + assert np.allclose(dpt.asnumpy(Y), expected_Y) From 049b523b5f5f8ce7daa7858a0d49fdeb7ac05628 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 11 May 2023 21:39:24 -0500 Subject: [PATCH 07/48] Added missing include statements --- .../kernels/elementwise_functions/abs.hpp | 4 +++ .../kernels/elementwise_functions/add.hpp | 2 ++ .../kernels/elementwise_functions/common.hpp | 32 +++++++++++++++++-- .../kernels/elementwise_functions/cos.hpp | 3 ++ .../elementwise_functions/isfinite.hpp | 3 ++ .../kernels/elementwise_functions/isinf.hpp | 3 ++ .../kernels/elementwise_functions/isnan.hpp | 2 ++ 7 files changed, 47 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index a95cf78d69..1ad27037b3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -1,5 +1,9 @@ #pragma once #include +#include +#include +#include +#include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 5f008128d6..f26cfefbc2 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -1,6 +1,8 @@ #pragma once #include +#include #include +#include #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index a69cac19f4..16b2deb914 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -1,5 +1,33 @@ +//=== common.hpp - -----------------------------------*-C++-*--/===// +//= Implementation of tensor elementwise operation kernels ------===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise operations over tensor . 
+//===----------------------------------------------------------------------===// + #pragma once #include +#include +#include +#include namespace dpctl { @@ -216,8 +244,8 @@ struct UnaryStridedFunctor resT *const &out = res_; auto offsets_ = inp_out_indexer_(wid.get(0)); - const ssize_t &inp_offset = offsets_.get_first_offset(); - const ssize_t &out_offset = offsets_.get_second_offset(); + const py::ssize_t &inp_offset = offsets_.get_first_offset(); + const py::ssize_t &out_offset = offsets_.get_second_offset(); UnaryOpT op{}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index e036505ce9..4cc6454404 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -1,6 +1,9 @@ #pragma once #include +#include +#include #include +#include #include "kernels/elementwise_functions/common.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index f5a9ec0527..258993c6a5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -1,6 +1,9 @@ #pragma once #include +#include +#include #include +#include #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 656f867f5d..98302617e3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -1,6 +1,9 @@ #pragma once #include +#include +#include #include +#include #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" 
diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index 3d72b6e707..73522a4b7e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -1,6 +1,8 @@ #pragma once #include +#include #include +#include #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" From b49744ab1712917f4067a82c5f5236e70ff3fac6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 11 May 2023 22:13:30 -0500 Subject: [PATCH 08/48] Debug: Set -H when compiling elementwise_functions.cpp --- dpctl/tensor/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 56cd33b6b2..0748ec5871 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -53,7 +53,7 @@ if (WIN32) endif() set_source_files_properties( ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-approx-func;${_clang_prefix}-fno-finite-math-only") + PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math;-H") target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) if(UNIX) From 56b134ccde3aaa804b466005fdd3e077076d923c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 12 May 2023 09:38:21 -0500 Subject: [PATCH 09/48] Added -fno-approx-func and -fno-finite-math-only --- dpctl/tensor/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 0748ec5871..f74ae7ccc5 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -53,7 +53,7 @@ if (WIN32) endif() set_source_files_properties( 
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math;-H") + PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math;${_clang_prefix}-fno-finite-math-only;${_clang_prefix}-fno-approx-func;-H") target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) if(UNIX) From 4952132bc477a4454beeda64f58518ca73ec8287 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 12 May 2023 10:54:18 -0700 Subject: [PATCH 10/48] Implements dpctl.tensor.sqrt --- dpctl/tensor/__init__.py | 3 +- dpctl/tensor/_elementwise_funcs.py | 10 + .../kernels/elementwise_functions/sqrt.hpp | 207 ++++++++++++++++++ .../source/elementwise_functions.cpp | 59 ++++- 4 files changed, 277 insertions(+), 2 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 48c8c045a8..c427dc99ab 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -90,7 +90,7 @@ from dpctl.tensor._usmarray import usm_ndarray from ._constants import e, inf, nan, newaxis, pi -from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan +from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan, sqrt __all__ = [ "Device", @@ -171,4 +171,5 @@ "isinf", "isnan", "isfinite", + "sqrt", ] diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 566fea9658..97b873773b 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -71,3 +71,13 @@ isinf = UnaryElementwiseFunc( "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_ ) + +# SQRT + +_sqrt_docstring_ = """ +Computes sqrt for each element `x_i` for input array `x`. 
+""" + +sqrt = UnaryElementwiseFunc( + "sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_ +) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp new file mode 100644 index 0000000000..719670cea0 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -0,0 +1,207 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "kernels/elementwise_functions/common.hpp" + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace sqrt +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::type_utils::is_complex; + +template struct SqrtFunctor +{ + + // is function constant for given argT + using is_constant = typename std::false_type; + // constant value, if constant + // constexpr resT constant_value = resT{}; + // is function defined for sycl::vec + using supports_vec = typename std::false_type; + // do both argTy and resTy support sugroup store/load operation + using supports_sg_loadstore = typename std::negation< + std::disjunction, is_complex>>; + + resT operator()(const argT &in) + { + return std::sqrt(in); + } +}; + +template +using SqrtContigFunctor = elementwise_common:: + UnaryContigFunctor, vec_sz, n_vecs>; + +template +using SqrtStridedFunctor = elementwise_common:: + UnaryStridedFunctor>; + +template struct SqrtOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, + td_ns::TypeMapEntry, std::complex>, + td_ns::TypeMapEntry, std::complex>, + td_ns::DefaultEntry>::result_type; +}; + +typedef sycl::event (*sqrt_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const 
std::vector &); + +template +class sqrt_contig_kernel; + +template +sycl::event sqrt_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg_p, + char *res_p, + const std::vector &depends = {}) +{ + sycl::event sqrt_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + constexpr size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + static_assert(lws % vec_sz == 0); + auto gws_range = sycl::range<1>( + ((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) * + lws); + auto lws_range = sycl::range<1>(lws); + + using resTy = typename SqrtOutputType::value_type; + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + class sqrt_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + SqrtContigFunctor(arg_tp, res_tp, + nelems)); + }); + return sqrt_ev; +} + +template struct SqrtContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = sqrt_contig_impl; + return fn; + } + } +}; + +template struct SqrtTypeMapFactory +{ + /*! 
@brief get typeid for output type of std::sqrt(T x) */ + std::enable_if_t::value, int> get() + { + using rT = typename SqrtOutputType::value_type; + ; + return td_ns::GetTypeid{}.get(); + } +}; + +template class sqrt_strided_kernel; + +typedef sycl::event (*sqrt_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event +sqrt_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename SqrtOutputType::value_type; + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides); + + const argTy *arg_tp = reinterpret_cast(arg_p); + resTy *res_tp = reinterpret_cast(res_p); + + sycl::range<1> gRange{nelems}; + + cgh.parallel_for>( + gRange, SqrtStridedFunctor( + arg_tp, res_tp, arg_res_indexer)); + }); + return comp_ev; +} + +template struct SqrtStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = sqrt_strided_impl; + return fn; + } + } +}; + +} // namespace sqrt +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 65b2b1d9a1..acae73541b 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -38,6 +38,7 @@ #include "kernels/elementwise_functions/isfinite.hpp" 
#include "kernels/elementwise_functions/isinf.hpp" #include "kernels/elementwise_functions/isnan.hpp" +#include "kernels/elementwise_functions/sqrt.hpp" namespace dpctl { @@ -325,6 +326,43 @@ void populate_add_dispatch_tables(void) } // namespace impl +// SQRT +namespace impl +{ + +namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt; +using sqrt_fn_ns::sqrt_contig_impl_fn_ptr_t; +using sqrt_fn_ns::sqrt_strided_impl_fn_ptr_t; + +static sqrt_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types]; +static int sqrt_output_typeid_vector[td_ns::num_types]; +static sqrt_strided_impl_fn_ptr_t + sqrt_strided_dispatch_vector[td_ns::num_types]; + +void populate_sqrt_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = sqrt_fn_ns; + + using fn_ns::SqrtContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector); + + using fn_ns::SqrtStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector); + + using fn_ns::SqrtTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(sqrt_output_typeid_vector); +} + +} // namespace impl + namespace py = pybind11; void init_elementwise_functions(py::module_ m) @@ -628,7 +666,26 @@ void init_elementwise_functions(py::module_ m) // FIXME: // U33: ==== SQRT (x) - // FIXME: + { + impl::populate_sqrt_dispatch_vectors(); + using impl::sqrt_contig_dispatch_vector; + using impl::sqrt_output_typeid_vector; + using impl::sqrt_strided_dispatch_vector; + + auto sqrt_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q, + const event_vecT &depends = {}) { + return py_unary_ufunc( + src, dst, exec_q, depends, sqrt_output_typeid_vector, + sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector); + }; + m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto sqrt_result_type_pyapi = [&](py::dtype dtype) { + return 
py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector); + }; + m.def("_sqrt_result_type", sqrt_result_type_pyapi); + } // B23: ==== SUBTRACT (x1, x2) // FIXME: From 1798bf27c14fd0708989b839a561e82a9cf955bd Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 12 May 2023 13:57:03 -0500 Subject: [PATCH 11/48] Removed unneeded temporaries --- .../include/kernels/elementwise_functions/common.hpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index 16b2deb914..352dda5fdc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -240,16 +240,13 @@ struct UnaryStridedFunctor void operator()(sycl::id<1> wid) const { - const argT *const &in = inp_; - resT *const &out = res_; - - auto offsets_ = inp_out_indexer_(wid.get(0)); + const auto &offsets_ = inp_out_indexer_(wid.get(0)); const py::ssize_t &inp_offset = offsets_.get_first_offset(); - const py::ssize_t &out_offset = offsets_.get_second_offset(); + const py::ssize_t &res_offset = offsets_.get_second_offset(); UnaryOpT op{}; - out[out_offset] = op(in[inp_offset]); + res_[res_offset] = op(inp_[inp_offset]); } }; From d2c9aa4e7cb9d7dd0ba2ae25b149b3b2c7300f78 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 15 May 2023 05:25:26 -0500 Subject: [PATCH 12/48] Use sycl::isinf for vec, but std::isinf for scalars --- .../include/kernels/elementwise_functions/isinf.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 98302617e3..0da88e687e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++
b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -32,7 +32,9 @@ template struct IsInfFunctor using is_constant = typename std::disjunction, std::is_integral>; static constexpr resT constant_value = false; - using supports_vec = typename std::false_type; + using supports_vec = + typename std::disjunction, + std::is_floating_point>; using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -56,8 +58,6 @@ template struct IsInfFunctor } } - // unused (since support_vec is set to false_type) due to bug in sycl::isinf - // implementation in OpenCL CPU RT template sycl::vec operator()(const sycl::vec &in) { From ca46b1b34118f1c5783409036bcf63f0f948bbc8 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 15 May 2023 06:09:32 -0500 Subject: [PATCH 13/48] Removed -H as it is not helpful, but verbose --- dpctl/tensor/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index f74ae7ccc5..80eb254237 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -53,7 +53,7 @@ if (WIN32) endif() set_source_files_properties( ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp - PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math;${_clang_prefix}-fno-finite-math-only;${_clang_prefix}-fno-approx-func;-H") + PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math") target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int) target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel) if(UNIX) From 88e89315d5d64bcf0d6e698b824530d473786cf2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 May 2023 08:34:03 -0500 Subject: [PATCH 14/48] Update meta.yaml to require sysroot_linux-64 >=2.28, and use dppy/label/tools to get it until it comes online from conda-forge --- .github/workflows/conda-package.yml | 2 +- conda-recipe/meta.yaml | 
2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 6e5ab73645..3e602e31e3 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -44,7 +44,7 @@ jobs: run: conda install conda-build - name: Build conda package run: | - CHANNELS="-c intel -c main --override-channels" + CHANNELS="-c dppy/label/tools -c intel -c main --override-channels" VERSIONS="--python ${{ matrix.python }}" TEST="--no-test" conda build \ diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index a2bc599e38..0bcfe56c18 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -14,7 +14,7 @@ requirements: build: - {{ compiler('cxx') }} - {{ compiler('dpcpp') }} >=2023.1 # [not osx] - - sysroot_linux-64 >=2.17 # [linux] + - sysroot_linux-64 >=2.28 # [linux] host: - setuptools - cmake >=3.21 From 95a8142e5aa8b490686565afcac4b940ee40faf0 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 May 2023 11:11:19 -0500 Subject: [PATCH 15/48] Use verbose lsplatform in run_test.sh --- conda-recipe/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conda-recipe/run_test.sh b/conda-recipe/run_test.sh index 66193be58c..4b3566d7bd 100644 --- a/conda-recipe/run_test.sh +++ b/conda-recipe/run_test.sh @@ -3,5 +3,5 @@ set -e ${PYTHON} -c "import dpctl; print(dpctl.__version__)" -${PYTHON} -c "import dpctl; dpctl.lsplatform()" +${PYTHON} -c "import dpctl; dpctl.lsplatform(verbosity=2)" ${PYTHON} -m pytest -q -ra --disable-warnings -p no:faulthandler --cov dpctl --cov-report term-missing --pyargs dpctl -vv From b1495a266bdfb30fc048e7766f384746f57a2d00 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 May 2023 09:41:48 -0500 Subject: [PATCH 16/48] Added gdb call in test_linux as a separate step This calls crashing test_tensor_elementwise under gdb in batch mode in CI. 
gdb call exit code is ignored --- .github/workflows/conda-package.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 3e602e31e3..6bd208a326 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -161,13 +161,20 @@ jobs: run: | . $CONDA/etc/profile.d/conda.sh conda activate test_dpctl - python -c "import dpctl; dpctl.lsplatform()" + python -c "import dpctl; dpctl.lsplatform(verbosity=2)" + - name: Install gdb + run: | + sudo apt-get install -y gdb + - name: Run test_elementwise under gdb + run: | + . $CONDA/etc/profile.d/conda.sh + conda activate test_dpctl + gdb --batch -ex r -ex 'info sharedlibrary' -ex 'set print elements 1000' -ex bt --args ${CONDA_PREFIX}/bin/python -m pytest -q -ra --disable-warnings --pyargs dpctl.tests.test_tensor_elementwise::test_cos_order -vv || true - name: Run tests run: | . $CONDA/etc/profile.d/conda.sh conda activate test_dpctl - # clinfo -l - python -m pytest --pyargs $MODULE_NAME + python -m pytest -v --pyargs $MODULE_NAME test_windows: needs: build_windows From f3535f1d5f57cea0afba61ad14075bb445ba5e5c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 May 2023 19:24:06 -0500 Subject: [PATCH 17/48] Created templates for binary functions too, applied for addition --- .../kernels/elementwise_functions/add.hpp | 217 ++++++------------ .../kernels/elementwise_functions/common.hpp | 170 ++++++++++++++ 2 files changed, 242 insertions(+), 145 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index f26cfefbc2..9529b73d6c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -7,6 +7,8 @@ #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" 
#include "utils/type_utils.hpp" + +#include "kernels/elementwise_functions/common.hpp" #include namespace dpctl @@ -20,101 +22,60 @@ namespace add namespace py = pybind11; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; -template -struct AddContigFunctor +template struct AddFunctor { -private: - const argT1 *in1 = nullptr; - const argT2 *in2 = nullptr; - resT *out = nullptr; - const size_t nelems_; - -public: - AddContigFunctor(const argT1 *inp1, - const argT2 *inp2, - resT *res, - const size_t n_elems) - : in1(inp1), in2(inp2), out(res), nelems_(n_elems) + + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::negation< + std::disjunction, tu_ns::is_complex>>; + + resT operator()(const argT1 &in1, const argT2 &in2) { + return in1 + in2; } - void operator()(sycl::nd_item<1> ndit) const + template + sycl::vec operator()(const sycl::vec &in1, + const sycl::vec &in2) { - /* Each work-item processes vec_sz elements, contiguous in memory */ - - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value || is_complex::value) { - std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; - size_t base = ndit.get_global_linear_id(); - - base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); - for (size_t offset = base; - offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); - offset += sgSize) - { - out[offset] = in1[offset] + in2[offset]; - } + auto tmp = in1 + in2; + if constexpr (std::is_same_v) { + return tmp; } else { - auto sg = ndit.get_sub_group(); - std::uint8_t sgSize = sg.get_local_range()[0]; - std::uint8_t maxsgSize = sg.get_max_local_range()[0]; - size_t base = n_vecs * vec_sz * - (ndit.get_group(0) * ndit.get_local_range(0) + - sg.get_group_id()[0] * maxsgSize); + using dpctl::tensor::type_utils::vec_cast; - if ((base + n_vecs * vec_sz * sgSize < nelems_) && - (sgSize == maxsgSize)) { - using in_ptrT1 = - 
sycl::multi_ptr; - using in_ptrT2 = - sycl::multi_ptr; - using out_ptrT = - sycl::multi_ptr; - sycl::vec arg1_vec; - sycl::vec arg2_vec; - sycl::vec res_vec; - -#pragma unroll - for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { - arg1_vec = - sg.load(in_ptrT1(&in1[base + it * sgSize])); - arg2_vec = - sg.load(in_ptrT2(&in2[base + it * sgSize])); - if constexpr (std::is_same_v && - std::is_same_v) { - res_vec = arg1_vec + arg2_vec; - } - else { - using dpctl::tensor::type_utils::vec_cast; - - auto tmp = arg1_vec + arg2_vec; - res_vec = std::move( - vec_cast(tmp)); - } - sg.store(out_ptrT(&out[base + it * sgSize]), - res_vec); - } - } - else { - for (size_t k = base + sg.get_local_id()[0]; k < nelems_; - k += sgSize) { - out[k] = in1[k] + in2[k]; - } - } + return vec_cast( + tmp); } } }; +template +using AddContigFunctor = + elementwise_common::BinaryContigFunctor, + vec_sz, + n_vecs>; + +template +using AddStridedFunctor = + elementwise_common::BinaryStridedFunctor>; + template struct AddOutputType { using value_type = typename std::disjunction< // disjunction is C++17 @@ -257,41 +218,6 @@ template struct AddTypeMapFactory } }; -template -struct AddStridedFunctor -{ -private: - const argT1 *in1 = nullptr; - const argT2 *in2 = nullptr; - resT *out = nullptr; - ThreeOffsets_IndexerT three_offsets_indexer_; - -public: - AddStridedFunctor(const argT1 *inp1_tp, - const argT2 *inp2_tp, - resT *res_tp, - ThreeOffsets_IndexerT inps_res_indexer) - : in1(inp1_tp), in2(inp2_tp), out(res_tp), - three_offsets_indexer_(inps_res_indexer) - { - } - - void operator()(sycl::id<1> wid) const - { - const auto &three_offsets_ = - three_offsets_indexer_(static_cast(wid.get(0))); - - const auto &inp1_offset = three_offsets_.get_first_offset(); - const auto &inp2_offset = three_offsets_.get_second_offset(); - const auto &out_offset = three_offsets_.get_third_offset(); - - out[out_offset] = in1[inp1_offset] + in2[inp2_offset]; - } -}; - template class 
add_strided_strided_kernel; @@ -435,40 +361,41 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl( size_t n_groups = (n_elems + lws - 1) / lws; auto gwsRange = sycl::range<1>(n_groups * lws); - cgh.parallel_for>( - sycl::nd_range<1>(gwsRange, lwsRange), - [=](sycl::nd_item<1> ndit) - { - auto sg = ndit.get_sub_group(); - size_t gid = ndit.get_global_linear_id(); + cgh.parallel_for>( + sycl::nd_range<1>(gwsRange, lwsRange), + [=](sycl::nd_item<1> ndit) + { + auto sg = ndit.get_sub_group(); + size_t gid = ndit.get_global_linear_id(); - std::uint8_t sgSize = sg.get_local_range()[0]; - size_t base = gid - sg.get_local_id()[0]; + std::uint8_t sgSize = sg.get_local_range()[0]; + size_t base = gid - sg.get_local_id()[0]; - if (base + sgSize < n_elems) { - using in_ptrT1 = sycl::multi_ptr< - const argT1, sycl::access::address_space::global_space>; - using in_ptrT2 = sycl::multi_ptr< - const argT2, sycl::access::address_space::global_space>; - using res_ptrT = sycl::multi_ptr< - resT, sycl::access::address_space::global_space>; + if (base + sgSize < n_elems) { + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using res_ptrT = + sycl::multi_ptr; - const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); - const argT2 vec_el = - sg.load(in_ptrT2(&padded_vec[base % n1])); + const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); + const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1])); - resT res_el = mat_el + vec_el; + resT res_el = mat_el + vec_el; - sg.store(res_ptrT(&res[base]), res_el); - } - else { - for (size_t k = base + sg.get_local_id()[0]; k < n_elems; - k += sgSize) { - res[k] = mat[k] + padded_vec[k % n1]; - } - } + sg.store(res_ptrT(&res[base]), res_el); + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < n_elems; + k += sgSize) { + res[k] = mat[k] + padded_vec[k % n1]; } - ); + } + }); }); sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { diff --git 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index 352dda5fdc..37978ab1eb 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -250,6 +250,176 @@ struct UnaryStridedFunctor } }; +template +struct BinaryContigFunctor +{ +private: + const argT1 *in1 = nullptr; + const argT2 *in2 = nullptr; + resT *out = nullptr; + const size_t nelems_; + +public: + BinaryContigFunctor(const argT1 *inp1, + const argT2 *inp2, + resT *res, + const size_t n_elems) + : in1(inp1), in2(inp2), out(res), nelems_(n_elems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + BinaryOperatorT op{}; + /* Each work-item processes vec_sz elements, contiguous in memory */ + + if constexpr (BinaryOperatorT::supports_sg_loadstore::value && + BinaryOperatorT::supports_vec::value) + { + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if ((base + n_vecs * vec_sz * sgSize < nelems_) && + (sgSize == maxsgSize)) { + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + sycl::vec arg1_vec; + sycl::vec arg2_vec; + sycl::vec res_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + arg1_vec = + sg.load(in_ptrT1(&in1[base + it * sgSize])); + arg2_vec = + sg.load(in_ptrT2(&in2[base + it * sgSize])); + res_vec = op(arg1_vec, arg2_vec); + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = op(in1[k], in2[k]); + } + } + } + else if constexpr 
(BinaryOperatorT::supports_sg_loadstore::value) { + auto sg = ndit.get_sub_group(); + std::uint8_t sgSize = sg.get_local_range()[0]; + std::uint8_t maxsgSize = sg.get_max_local_range()[0]; + + size_t base = n_vecs * vec_sz * + (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if ((base + n_vecs * vec_sz * sgSize < nelems_) && + (sgSize == maxsgSize)) { + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using out_ptrT = + sycl::multi_ptr; + sycl::vec arg1_vec; + sycl::vec arg2_vec; + sycl::vec res_vec; + +#pragma unroll + for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) { + arg1_vec = + sg.load(in_ptrT1(&in1[base + it * sgSize])); + arg2_vec = + sg.load(in_ptrT2(&in2[base + it * sgSize])); +#pragma unroll + for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { + res_vec[vec_id] = + op(arg1_vec[vec_id], arg2_vec[vec_id]); + } + sg.store(out_ptrT(&out[base + it * sgSize]), + res_vec); + } + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < nelems_; + k += sgSize) { + out[k] = op(in1[k], in2[k]); + } + } + } + else { + std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0]; + size_t base = ndit.get_global_linear_id(); + + base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize); + for (size_t offset = base; + offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz)); + offset += sgSize) + { + out[offset] = op(in1[offset], in2[offset]); + } + } + } +}; + +template +struct BinaryStridedFunctor +{ +private: + const argT1 *in1 = nullptr; + const argT2 *in2 = nullptr; + resT *out = nullptr; + ThreeOffsets_IndexerT three_offsets_indexer_; + +public: + BinaryStridedFunctor(const argT1 *inp1_tp, + const argT2 *inp2_tp, + resT *res_tp, + ThreeOffsets_IndexerT inps_res_indexer) + : in1(inp1_tp), in2(inp2_tp), out(res_tp), + three_offsets_indexer_(inps_res_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + const auto &three_offsets_ = + 
three_offsets_indexer_(static_cast(wid.get(0))); + + const auto &inp1_offset = three_offsets_.get_first_offset(); + const auto &inp2_offset = three_offsets_.get_second_offset(); + const auto &out_offset = three_offsets_.get_third_offset(); + + BinaryOperatorT op{}; + out[out_offset] = op(in1[inp1_offset], in2[inp2_offset]); + } +}; + } // namespace elementwise_common } // namespace kernels } // namespace tensor From 73e979d3682803e375c598598fee00c29e6130ab Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 16 May 2023 19:24:19 -0500 Subject: [PATCH 18/48] Added NullPtrTable and NullPtrVector classes (yet unused) --- .../libtensor/include/utils/type_dispatch.hpp | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp index 07fbbc6baf..378173bb4b 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp @@ -333,6 +333,96 @@ template struct GetTypeid } }; +/*! @brief Class to generate vector of null function pointers */ +template struct NullPtrVector +{ + + using iterator_category = std::forward_iterator_tag; + using different_type = std::ptrdiff_t; + using value_type = FunPtrT; + using pointer = value_type *; + using reference = value_type &; + + NullPtrVector() : val(nullptr) {} + + reference operator*() + { + return val; + } + + reference operator[](int) + { + return val; + } + + NullPtrVector &operator++() + { + return *this; + } + NullPtrVector operator++(int) + { + return *this; + } + + friend bool operator==(const NullPtrVector &a, + const NullPtrVector &b) + { + return true; + } + friend bool operator!=(const NullPtrVector &a, + const NullPtrVector &b) + { + return false; + } + +private: + value_type val; +}; + +/*! 
@brief Class to generate table of null function pointers */ +template struct NullPtrTable +{ + using iterator_category = std::forward_iterator_tag; + using different_type = std::ptrdiff_t; + using value_type = NullPtrVector; + using pointer = value_type *; + using reference = value_type &; + + NullPtrTable() : val() {} + + reference operator*() + { + return val; + } + reference operator[](int) + { + return val; + } + + NullPtrTable &operator++() + { + return *this; + } + NullPtrTable operator++(int) + { + return *this; + } + + friend bool operator==(const NullPtrTable &a, + const NullPtrTable &b) + { + return true; + } + friend bool operator!=(const NullPtrTable &a, + const NullPtrTable &b) + { + return false; + } + +private: + value_type val; +}; + } // namespace type_dispatch } // namespace tensor From 8077aac6c685f2af48e53a93f89b06cb3eae0762 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 10:14:19 -0500 Subject: [PATCH 19/48] Not all binary functions are symmetric, modified py_binary_func accordingly Introduced contig_row_contig_matrix_broadcasting_impl_fn_ptr_t and corresponding table. Implemented that for Add to fall back on contig_matrix_contig_row_broadcasting_fn. It would be good to have a specialization for symmetric variants. 
--- .../kernels/elementwise_functions/add.hpp | 61 +++++++++++++++++++ .../source/elementwise_functions.cpp | 37 ++++++++--- .../source/elementwise_functions.hpp | 22 ++++--- 3 files changed, 103 insertions(+), 17 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 9529b73d6c..3d9c0aaaea 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -436,6 +436,67 @@ struct AddContigMatrixContigRowBroadcastFactory } }; +typedef sycl::event (*add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event add_contig_row_contig_matrix_broadcast_impl( + sycl::queue exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + py::ssize_t mat_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, + // res[i,j] = mat[i,j] + vec[j] + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + return add_contig_matrix_contig_row_broadcast_impl( + exec_q, host_tasks, n0, n1, mat_p, mat_offset, vec_p, vec_offset, res_p, + res_offset, depends); +}; + +template +struct AddContigRowContigMatrixBroadcastFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + fnT fn = nullptr; + return fn; + } + else { + using resT = typename AddOutputType::value_type; + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + add_contig_row_contig_matrix_broadcast_impl; + return fn; + } + } + } +}; + } // namespace add } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 65b2b1d9a1..de68f74c0d 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -283,6 +283,7 @@ namespace fn_ns = dpctl::tensor::kernels::add; using fn_ns::add_contig_impl_fn_ptr_t; using fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; using fn_ns::add_strided_impl_fn_ptr_t; static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] @@ -292,28 +293,37 @@ static int add_output_id_table[td_ns::num_types][td_ns::num_types]; static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types] [td_ns::num_types]; +// add(matrix, row) static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] [td_ns::num_types]; +// add(row, matrix) +static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + 
add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + void populate_add_dispatch_tables(void) { using namespace td_ns; - using fn_ns::AddContigFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(add_contig_dispatch_table); + // which input types are supported, and what is the type of the result + using fn_ns::AddTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(add_output_id_table); + // function pointers for operation on general strided arrays using fn_ns::AddStridedFactory; DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(add_strided_dispatch_table); - using fn_ns::AddTypeMapFactory; - DispatchTableBuilder dtb3; - dtb3.populate_dispatch_table(add_output_id_table); + // function pointers for operation on contiguous inputs and outputs + using fn_ns::AddContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(add_contig_dispatch_table); using fn_ns::AddContigMatrixContigRowBroadcastFactory; DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table( + add_contig_row_contig_matrix_broadcast_dispatch_table); }; } // namespace impl @@ -365,6 +382,7 @@ void init_elementwise_functions(py::module_ m) impl::populate_add_dispatch_tables(); using impl::add_contig_dispatch_table; using impl::add_contig_matrix_contig_row_broadcast_dispatch_table; + using impl::add_contig_row_contig_matrix_broadcast_dispatch_table; using impl::add_output_id_table; using impl::add_strided_dispatch_table; @@ -382,7 +400,10 @@ void init_elementwise_functions(py::module_ m) add_strided_dispatch_table, // function pointers to handle operation of c-contig matrix and // c-contig row with broadcasting (may be nullptr) - add_contig_matrix_contig_row_broadcast_dispatch_table); + add_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + 
add_contig_row_contig_matrix_broadcast_dispatch_table); }; auto add_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) { return py_binary_ufunc_result_type(dtype1, dtype2, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.hpp b/dpctl/tensor/libtensor/source/elementwise_functions.hpp index c29ced5e1d..b76e3da297 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.hpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.hpp @@ -297,7 +297,8 @@ bool isEqual(Container const &c, std::initializer_list const &l) template + typename contig_matrix_row_dispatchT, + typename contig_row_matrix_dispatchT> std::pair py_binary_ufunc( dpctl::tensor::usm_ndarray src1, dpctl::tensor::usm_ndarray src2, @@ -308,7 +309,10 @@ std::pair py_binary_ufunc( const output_typesT &output_type_table, const contig_dispatchT &contig_dispatch_table, const strided_dispatchT &strided_dispatch_table, - const matrix_row_dispatchT &contig_matrix_row_broadcast_dispatch_table) + const contig_matrix_row_dispatchT + &contig_matrix_row_broadcast_dispatch_table, + const contig_row_matrix_dispatchT + &contig_row_matrix_broadcast_dispatch_table) { // check type_nums int src1_typenum = src1.get_typenum(); @@ -507,15 +511,15 @@ std::pair py_binary_ufunc( isEqual(simplified_src2_strides, {one, simplified_shape[0]}) && isEqual(simplified_dst_strides, {one, simplified_shape[0]})) { - auto matrix_row_broadcast_fn = - contig_matrix_row_broadcast_dispatch_table[src2_typeid] - [src1_typeid]; - if (matrix_row_broadcast_fn != nullptr) { + auto row_matrix_broadcast_fn = + contig_row_matrix_broadcast_dispatch_table[src1_typeid] + [src2_typeid]; + if (row_matrix_broadcast_fn != nullptr) { size_t n0 = simplified_shape[1]; size_t n1 = simplified_shape[0]; - sycl::event comp_ev = matrix_row_broadcast_fn( - exec_q, host_tasks, n0, n1, src2_data, src2_offset, - src1_data, src1_offset, dst_data, dst_offset, depends); + sycl::event comp_ev = row_matrix_broadcast_fn( + exec_q, host_tasks, 
n0, n1, src1_data, src1_offset, + src2_data, src2_offset, dst_data, dst_offset, depends); return std::make_pair( dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, From d699a5f8628a3ae31928622a3f3d2be0cc8ebd09 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 10:19:38 -0500 Subject: [PATCH 20/48] Fixed typo in the test name: vaidation->validation --- dpctl/tests/test_tensor_elementwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpctl/tests/test_tensor_elementwise.py b/dpctl/tests/test_tensor_elementwise.py index 1a5ecc0dd9..0e84b4db73 100644 --- a/dpctl/tests/test_tensor_elementwise.py +++ b/dpctl/tests/test_tensor_elementwise.py @@ -168,7 +168,7 @@ def test_unary_func_arg_validation(): dpt.abs(a, order="invalid") -def test_binary_func_arg_vaidation(): +def test_binary_func_arg_validation(): with pytest.raises(dpctl.utils.ExecutionPlacementError): dpt.add([1, 2, 3], 1) try: From 5d37d7466af18b0bbf1d87c3a9aad10d37e3c2c2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 17 May 2023 13:33:59 -0700 Subject: [PATCH 21/48] Added tests for dpctl.tensor.sqrt --- dpctl/tests/test_tensor_sqrt.py | 133 ++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 dpctl/tests/test_tensor_sqrt.py diff --git a/dpctl/tests/test_tensor_sqrt.py b/dpctl/tests/test_tensor_sqrt.py new file mode 100644 index 0000000000..2f924027dc --- /dev/null +++ b/dpctl/tests/test_tensor_sqrt.py @@ -0,0 +1,133 @@ +import itertools + +import numpy as np +import pytest +from numpy.testing import assert_equal + +import dpctl.tensor as dpt +import dpctl.tensor._type_utils as tu +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +_all_dtypes = [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] +_usm_types = ["device", "shared", "host"] + + +def _map_to_device_dtype(dt, dev): + return tu._to_device_supported_dtype(dt, dev) + 
+ +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_sqrt_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + expected_dtype = np.sqrt(np.array(0, dtype=dtype)).dtype + expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) + assert dpt.sqrt(X).dtype == expected_dtype + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) +def test_sqrt_output_contig(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 1027 + + X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q) + Xnp = dpt.asnumpy(X) + + Y = dpt.sqrt(X) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) +def test_sqrt_output_strided(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 2054 + + X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2] + Xnp = dpt.asnumpy(X) + + Y = dpt.sqrt(X) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol) + + +@pytest.mark.parametrize("usm_type", _usm_types) +def test_sqrt_usm_type(usm_type): + q = get_queue_or_skip() + + arg_dt = np.dtype("f4") + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) + X[..., 0::2] = 16.0 + X[..., 1::2] = 23.0 + + Y = dpt.sqrt(X) + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == X.sycl_queue + assert Y.flags.c_contiguous + + expected_Y = np.empty(input_shape, dtype=arg_dt) + expected_Y[..., 0::2] = np.sqrt(np.float32(16.0)) + expected_Y[..., 1::2] = np.sqrt(np.float32(23.0)) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def 
test_sqrt_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = 16.0 + X[..., 1::2] = 23.0 + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.sqrt(U, order=ord) + expected_Y = np.sqrt(dpt.asnumpy(U)) + tol = 8 * max( + dpt.finfo(Y.dtype).resolution, + np.finfo(expected_Y.dtype).resolution, + ) + np.testing.assert_allclose( + dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol + ) + + +def test_sqrt_special_cases(): + q = get_queue_or_skip() + + X = dpt.asarray( + [dpt.nan, -1.0, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q + ) + Xnp = dpt.asnumpy(X) + + assert_equal(dpt.asnumpy(dpt.sqrt(X)), np.sqrt(Xnp)) From 863ba3b0625548fac3924915d6f3f514ff83c911 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 15:30:25 -0500 Subject: [PATCH 22/48] Renamed private sycl::event variable --- .../libtensor/include/kernels/elementwise_functions/abs.hpp | 4 ++-- .../libtensor/include/kernels/elementwise_functions/isinf.hpp | 4 ++-- .../libtensor/include/kernels/elementwise_functions/isnan.hpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 1ad27037b3..266f0bed8c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -178,7 +178,7 @@ sycl::event abs_strided_impl(sycl::queue exec_q, const std::vector &depends, const std::vector &additional_depends) { - sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.depends_on(additional_depends); 
@@ -195,7 +195,7 @@ sycl::event abs_strided_impl(sycl::queue exec_q, {nelems}, AbsStridedFunctor(arg_tp, res_tp, indexer)); }); - return abs_ev; + return comp_ev; } template struct AbsStridedFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 0da88e687e..74f10fdc4e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -173,7 +173,7 @@ isinf_strided_impl(sycl::queue exec_q, const std::vector &depends, const std::vector &additional_depends) { - sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.depends_on(additional_depends); @@ -192,7 +192,7 @@ isinf_strided_impl(sycl::queue exec_q, gRange, IsInfStridedFunctor( arg_tptr, res_tptr, arg_res_indexer)); }); - return abs_ev; + return comp_ev; } template struct IsInfStridedFactory diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index 73522a4b7e..e0bdae1ca6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -321,7 +321,7 @@ isnan_strided_impl(sycl::queue exec_q, const std::vector &depends, const std::vector &additional_depends) { - sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.depends_on(additional_depends); @@ -338,7 +338,7 @@ isnan_strided_impl(sycl::queue exec_q, {nelems}, IsNanStridedFunctor( arg_tptr, res_tptr, arg_res_indexer)); }); - return abs_ev; + return comp_ev; } template struct IsNanStridedFactory From b088a523eeff26e01fcff33d2481eafef526b603 Mon Sep 17 
00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 15:30:51 -0500 Subject: [PATCH 23/48] Use std::cos for complex types as well. --- .../kernels/elementwise_functions/cos.hpp | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 4cc6454404..2dbfde12aa 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -41,25 +41,7 @@ template struct CosFunctor resT operator()(const argT &in) { - if constexpr (is_complex::value) { - using realT = typename argT::value_type; - // cos(x + I*y) = cos(x)*cosh(y) - I*sin(x)*sinh(y) - auto v = std::real(in); - realT cosX_val; - const realT sinX_val = sycl::sincos(-v, &cosX_val); - v = std::imag(in); - const realT sinhY_val = sycl::sinh(v); - const realT coshY_val = sycl::cosh(v); - - const realT res_re = coshY_val * cosX_val; - const realT res_im = sinX_val * sinhY_val; - return resT{res_re, res_im}; - } - else { - static_assert(std::is_floating_point_v || - std::is_same_v); - return std::cos(in); - } + return std::cos(in); } }; From 6b56e9983714ca446606f05c986134c1a30cd97a Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 15:32:24 -0500 Subject: [PATCH 24/48] Added BinaryContigMatrixContigRowBroadcastFunctor and RowMatrix variant Added to common.hpp templated callable to generate kernels for arbitrary binary operator. Applied that to addition code. 
Also implemented true_divide operator, exported as _tensor_impl._divide --- .../kernels/elementwise_functions/add.hpp | 66 +-- .../kernels/elementwise_functions/common.hpp | 124 ++++ .../elementwise_functions/true_divide.hpp | 544 ++++++++++++++++++ .../source/elementwise_functions.cpp | 135 ++++- 4 files changed, 817 insertions(+), 52 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 3d9c0aaaea..2884e06f19 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -163,7 +163,7 @@ sycl::event add_contig_impl(sycl::queue exec_q, py::ssize_t res_offset, const std::vector &depends = {}) { - sycl::event add_ev = exec_q.submit([&](sycl::handler &cgh) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); size_t lws = 64; @@ -188,7 +188,7 @@ sycl::event add_contig_impl(sycl::queue exec_q, AddContigFunctor( arg1_tp, arg2_tp, res_tp, nelems)); }); - return add_ev; + return comp_ev; } template struct AddContigFactory @@ -249,7 +249,7 @@ sycl::event add_strided_impl(sycl::queue exec_q, const std::vector &depends, const std::vector &additional_depends) { - sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); cgh.depends_on(additional_depends); @@ -270,7 +270,7 @@ sycl::event add_strided_impl(sycl::queue exec_q, {nelems}, AddStridedFunctor( arg1_tp, arg2_tp, res_tp, indexer)); }); - return abs_ev; + return comp_ev; } template struct AddStridedFactory @@ -290,7 +290,7 @@ template struct AddStridedFactory }; template -class add_matrix_vector_broadcast_sg_krn; +class add_matrix_row_broadcast_sg_krn; typedef sycl::event 
(*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( sycl::queue, @@ -305,6 +305,14 @@ typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( py::ssize_t, const std::vector &); +template +using AddContigMatrixContigRowBroadcastingFunctor = + elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor< + argT1, + argT2, + resT, + AddFunctor>; + template sycl::event add_contig_matrix_contig_row_broadcast_impl( sycl::queue exec_q, @@ -361,41 +369,11 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl( size_t n_groups = (n_elems + lws - 1) / lws; auto gwsRange = sycl::range<1>(n_groups * lws); - cgh.parallel_for>( + cgh.parallel_for< + class add_matrix_row_broadcast_sg_krn>( sycl::nd_range<1>(gwsRange, lwsRange), - [=](sycl::nd_item<1> ndit) - { - auto sg = ndit.get_sub_group(); - size_t gid = ndit.get_global_linear_id(); - - std::uint8_t sgSize = sg.get_local_range()[0]; - size_t base = gid - sg.get_local_id()[0]; - - if (base + sgSize < n_elems) { - using in_ptrT1 = - sycl::multi_ptr; - using in_ptrT2 = - sycl::multi_ptr; - using res_ptrT = - sycl::multi_ptr; - - const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); - const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1])); - - resT res_el = mat_el + vec_el; - - sg.store(res_ptrT(&res[base]), res_el); - } - else { - for (size_t k = base + sg.get_local_id()[0]; k < n_elems; - k += sgSize) { - res[k] = mat[k] + padded_vec[k % n1]; - } - } - }); + AddContigMatrixContigRowBroadcastingFunctor( + mat, padded_vec, res, n_elems, n1)); }); sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { @@ -413,13 +391,12 @@ struct AddContigMatrixContigRowBroadcastFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) { + using resT = typename AddOutputType::value_type; + if constexpr (std::is_same_v) { fnT fn = nullptr; return fn; } else { - using resT = typename AddOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || 
dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) @@ -474,13 +451,12 @@ struct AddContigRowContigMatrixBroadcastFactory { fnT get() { - if constexpr (std::is_same_v::value_type, - void>) { + using resT = typename AddOutputType::value_type; + if constexpr (std::is_same_v) { fnT fn = nullptr; return fn; } else { - using resT = typename AddOutputType::value_type; if constexpr (dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value || dpctl::tensor::type_utils::is_complex::value) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index 37978ab1eb..7c567e00c6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -420,6 +420,130 @@ struct BinaryStridedFunctor } }; +template +struct BinaryContigMatrixContigRowBroadcastingFunctor +{ +private: + const argT1 *mat; + const argT2 *padded_vec; + resT *res; + size_t n_elems; + size_t n1; + +public: + BinaryContigMatrixContigRowBroadcastingFunctor(const argT1 *mat_tp, + const argT2 *row_tp, + resT *res_tp, + size_t n_elems_in_mat, + size_t n_elems_in_row) + : mat(mat_tp), padded_vec(row_tp), res(res_tp), n_elems(n_elems_in_mat), + n1(n_elems_in_row) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + BinaryOperatorT op{}; + static_assert(BinaryOperatorT::supports_sg_loadstore::value); + + auto sg = ndit.get_sub_group(); + size_t gid = ndit.get_global_linear_id(); + + std::uint8_t sgSize = sg.get_local_range()[0]; + size_t base = gid - sg.get_local_id()[0]; + + if (base + sgSize < n_elems) { + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using res_ptrT = + sycl::multi_ptr; + + const argT1 mat_el = sg.load(in_ptrT1(&mat[base])); + const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1])); 
+ + resT res_el = op(mat_el, vec_el); + + sg.store(res_ptrT(&res[base]), res_el); + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < n_elems; + k += sgSize) { + res[k] = op(mat[k], padded_vec[k % n1]); + } + } + } +}; + +template +struct BinaryContigRowContigMatrixBroadcastingFunctor +{ +private: + const argT1 *padded_vec; + const argT2 *mat; + resT *res; + size_t n_elems; + size_t n1; + +public: + BinaryContigRowContigMatrixBroadcastingFunctor(const argT1 *row_tp, + const argT2 *mat_tp, + resT *res_tp, + size_t n_elems_in_mat, + size_t n_elems_in_row) + : padded_vec(row_tp), mat(mat_tp), res(res_tp), n_elems(n_elems_in_mat), + n1(n_elems_in_row) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + BinaryOperatorT op{}; + static_assert(BinaryOperatorT::supports_sg_loadstore::value); + + auto sg = ndit.get_sub_group(); + size_t gid = ndit.get_global_linear_id(); + + std::uint8_t sgSize = sg.get_local_range()[0]; + size_t base = gid - sg.get_local_id()[0]; + + if (base + sgSize < n_elems) { + using in_ptrT1 = + sycl::multi_ptr; + using in_ptrT2 = + sycl::multi_ptr; + using res_ptrT = + sycl::multi_ptr; + + const argT2 mat_el = sg.load(in_ptrT2(&mat[base])); + const argT1 vec_el = sg.load(in_ptrT1(&padded_vec[base % n1])); + + resT res_el = op(vec_el, mat_el); + + sg.store(res_ptrT(&res[base]), res_el); + } + else { + for (size_t k = base + sg.get_local_id()[0]; k < n_elems; + k += sgSize) { + res[k] = op(padded_vec[k % n1], mat[k]); + } + } + } +}; + } // namespace elementwise_common } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp new file mode 100644 index 0000000000..58abf0cfd5 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -0,0 +1,544 @@ +#pragma once +#include +#include +#include +#include + +#include "utils/offset_utils.hpp" 
+#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "kernels/elementwise_functions/common.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace true_divide +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template +struct TrueDivideFunctor +{ + + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::negation< + std::disjunction, tu_ns::is_complex>>; + + resT operator()(const argT1 &in1, const argT2 &in2) + { + return in1 / in2; + } + + template + sycl::vec operator()(const sycl::vec &in1, + const sycl::vec &in2) + { + auto tmp = in1 / in2; + if constexpr (std::is_same_v) { + return tmp; + } + else { + using dpctl::tensor::type_utils::vec_cast; + + return vec_cast( + tmp); + } + } +}; + +template +using TrueDivideContigFunctor = elementwise_common::BinaryContigFunctor< + argT1, + argT2, + resT, + TrueDivideFunctor, + vec_sz, + n_vecs>; + +template +using TrueDivideStridedFunctor = elementwise_common::BinaryStridedFunctor< + argT1, + argT2, + resT, + IndexerT, + TrueDivideFunctor>; + +template struct TrueDivideOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapEntry, + T2, + float, + std::complex>, + td_ns::BinaryTypeMapEntry, + std::complex>, + td_ns::BinaryTypeMapEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapEntry, + std::complex>, + td_ns::BinaryTypeMapEntry, + T2, + double, + std::complex>, + td_ns::DefaultEntry>::result_type; +}; + +template +class true_divide_contig_kernel; + +typedef sycl::event (*true_divide_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + py::ssize_t, + const 
char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event +true_divide_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + using resTy = typename TrueDivideOutputType::value_type; + + const argTy1 *arg1_tp = + reinterpret_cast(arg1_p) + arg1_offset; + const argTy2 *arg2_tp = + reinterpret_cast(arg2_p) + arg2_offset; + resTy *res_tp = reinterpret_cast(res_p) + res_offset; + + cgh.parallel_for< + true_divide_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + TrueDivideContigFunctor( + arg1_tp, arg2_tp, res_tp, nelems)); + }); + return comp_ev; +} + +template struct TrueDivideContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename TrueDivideOutputType::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = true_divide_contig_impl; + return fn; + } + } +}; + +template +struct TrueDivideTypeMapFactory +{ + /*! 
@brief get typeid for output type of divide(T1 x, T2 y) */ + std::enable_if_t::value, int> get() + { + using rT = typename TrueDivideOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class true_divide_strided_strided_kernel; + +typedef sycl::event (*true_divide_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event +true_divide_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename TrueDivideOutputType::value_type; + + using IndexerT = + typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + + IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset, + shape_and_strides}; + + const argTy1 *arg1_tp = reinterpret_cast(arg1_p); + const argTy2 *arg2_tp = reinterpret_cast(arg2_p); + resTy *res_tp = reinterpret_cast(res_p); + + sycl::range<1> gRange(nelems); + + cgh.parallel_for>( + gRange, TrueDivideStridedFunctor( + arg1_tp, arg2_tp, res_tp, indexer)); + }); + return comp_ev; +} + +template +struct TrueDivideStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename TrueDivideOutputType::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = true_divide_strided_impl; + return fn; + } + } +}; + +template +using TrueDivideContigMatrixContigRowBroadcastingFunctor = + elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor< + argT1, + argT2, + resT, + TrueDivideFunctor>; + +template 
+using TrueDivideContigRowContigMatrixBroadcastingFunctor = + elementwise_common::BinaryContigRowContigMatrixBroadcastingFunctor< + argT1, + argT2, + resT, + TrueDivideFunctor>; + +template +class true_divide_matrix_row_broadcast_sg_krn; + +template +class true_divide_row_matrix_broadcast_sg_krn; + +typedef sycl::event ( + *true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event true_divide_contig_matrix_contig_row_broadcast_impl( + sycl::queue exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + py::ssize_t mat_offset, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, + // res[i,j] = mat[i,j] / vec[j] + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + const argT1 *mat = reinterpret_cast(mat_p) + mat_offset; + const argT2 *vec = reinterpret_cast(vec_p) + vec_offset; + resT *res = reinterpret_cast(res_p) + res_offset; + + const auto &dev = exec_q.get_device(); + const auto &sg_sizes = dev.get_info(); + // Get device-specific kernel info max_sub_group_size + size_t max_sgSize = + *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + size_t n1_padded = n1 + max_sgSize; + argT2 *padded_vec = sycl::malloc_device(n1_padded, exec_q); + + if (padded_vec == nullptr) { + throw std::runtime_error("Could not allocate memory on the device"); + } + sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); // ensure vec contains actual data + cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) { + auto i = id[0]; + padded_vec[i] = vec[i % n1]; + }); + }); + + // sub-group spans work-items [I, I + sgSize) + // base = 
ndit.get_global_linear_id() - sg.get_local_id()[0] + // Generically, sg.load( &mat[base]) may load arrays from + // different rows of mat. The start corresponds to row (base / n1) + // We read sg.load(&padded_vec[(base % n1)]). The vector is padded to + // ensure that reads are accessible + + size_t lws = 64; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(make_padded_vec_ev); + + auto lwsRange = sycl::range<1>(lws); + size_t n_elems = n0 * n1; + size_t n_groups = (n_elems + lws - 1) / lws; + auto gwsRange = sycl::range<1>(n_groups * lws); + + cgh.parallel_for< + class true_divide_matrix_row_broadcast_sg_krn>( + sycl::nd_range<1>(gwsRange, lwsRange), + TrueDivideContigMatrixContigRowBroadcastingFunctor( + mat, padded_vec, res, n_elems, n1)); + }); + + sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + sycl::context ctx = exec_q.get_context(); + cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); }); + }); + host_tasks.push_back(tmp_cleanup_ev); + + return comp_ev; +} + +template +struct TrueDivideContigMatrixContigRowBroadcastFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename TrueDivideOutputType::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + using resT = typename TrueDivideOutputType::value_type; + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + true_divide_contig_matrix_contig_row_broadcast_impl; + return fn; + } + } + } +}; + +typedef sycl::event ( + *true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event true_divide_contig_row_contig_matrix_broadcast_impl( +
sycl::queue exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + py::ssize_t mat_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, + // res[i,j] = vec[j] / mat[i,j] + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + const argT1 *vec = reinterpret_cast(vec_p) + vec_offset; + const argT2 *mat = reinterpret_cast(mat_p) + mat_offset; + resT *res = reinterpret_cast(res_p) + res_offset; + + const auto &dev = exec_q.get_device(); + const auto &sg_sizes = dev.get_info(); + // Get device-specific kernel info max_sub_group_size + size_t max_sgSize = + *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + size_t n1_padded = n1 + max_sgSize; + argT2 *padded_vec = sycl::malloc_device(n1_padded, exec_q); + + if (padded_vec == nullptr) { + throw std::runtime_error("Could not allocate memory on the device"); + } + sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); // ensure vec contains actual data + cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) { + auto i = id[0]; + padded_vec[i] = vec[i % n1]; + }); + }); + + // sub-group spans work-items [I, I + sgSize) + // base = ndit.get_global_linear_id() - sg.get_local_id()[0] + // Generically, sg.load( &mat[base]) may load arrays from + // different rows of mat. The start corresponds to row (base / n1) + // We read sg.load(&padded_vec[(base % n1)]).
The vector is padded to + // ensure that reads are accessible + + size_t lws = 64; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(make_padded_vec_ev); + + auto lwsRange = sycl::range<1>(lws); + size_t n_elems = n0 * n1; + size_t n_groups = (n_elems + lws - 1) / lws; + auto gwsRange = sycl::range<1>(n_groups * lws); + + cgh.parallel_for< + class true_divide_row_matrix_broadcast_sg_krn>( + sycl::nd_range<1>(gwsRange, lwsRange), + TrueDivideContigRowContigMatrixBroadcastingFunctor( + padded_vec, mat, res, n_elems, n1)); + }); + + sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + sycl::context ctx = exec_q.get_context(); + cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); }); + }); + host_tasks.push_back(tmp_cleanup_ev); + + return comp_ev; +}; + +template +struct TrueDivideContigRowContigMatrixBroadcastFactory +{ + fnT get() + { + using resT = typename TrueDivideOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + true_divide_contig_row_contig_matrix_broadcast_impl; + return fn; + } + } + } +}; + +} // namespace true_divide +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 45a5c91cf6..42e5943e6d 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -39,6 +39,7 @@ #include "kernels/elementwise_functions/isinf.hpp" #include "kernels/elementwise_functions/isnan.hpp" #include "kernels/elementwise_functions/sqrt.hpp" +#include "kernels/elementwise_functions/true_divide.hpp" 
namespace dpctl { @@ -280,12 +281,12 @@ void populate_cos_dispatch_vectors(void) namespace impl { -namespace fn_ns = dpctl::tensor::kernels::add; +namespace add_fn_ns = dpctl::tensor::kernels::add; -using fn_ns::add_contig_impl_fn_ptr_t; -using fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using fn_ns::add_strided_impl_fn_ptr_t; +using add_fn_ns::add_contig_impl_fn_ptr_t; +using add_fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using add_fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using add_fn_ns::add_strided_impl_fn_ptr_t; static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] [td_ns::num_types]; @@ -307,6 +308,7 @@ static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t void populate_add_dispatch_tables(void) { using namespace td_ns; + namespace fn_ns = add_fn_ns; // which input types are supported, and what is the type of the result using fn_ns::AddTypeMapFactory; @@ -320,12 +322,14 @@ void populate_add_dispatch_tables(void) dtb2; dtb2.populate_dispatch_table(add_strided_dispatch_table); - // function pointers for operation on contiguous inputs and outputs + // function pointers for operation on contiguous inputs and output using fn_ns::AddContigFactory; DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(add_contig_dispatch_table); + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output using fn_ns::AddContigMatrixContigRowBroadcastFactory; DispatchTableBuilder @@ -333,6 +337,8 @@ void populate_add_dispatch_tables(void) dtb4.populate_dispatch_table( add_contig_matrix_contig_row_broadcast_dispatch_table); + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output using fn_ns::AddContigRowContigMatrixBroadcastFactory; DispatchTableBuilder @@ -380,6 +386,82 @@ void populate_sqrt_dispatch_vectors(void) } // namespace 
impl +// DIVIDE +namespace impl +{ +namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; + +using true_divide_fn_ns::true_divide_contig_impl_fn_ptr_t; +using true_divide_fn_ns:: + true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using true_divide_fn_ns:: + true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using true_divide_fn_ns::true_divide_strided_impl_fn_ptr_t; + +static true_divide_contig_impl_fn_ptr_t + true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; + +static true_divide_strided_impl_fn_ptr_t + true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// divide(matrix, row) +static true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + true_divide_contig_matrix_contig_row_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +// divide(row, matrix) +static true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + true_divide_contig_row_contig_matrix_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +void populate_true_divide_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = true_divide_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::TrueDivideTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(true_divide_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::TrueDivideStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::TrueDivideContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using 
fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + TrueDivideContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + true_divide_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + TrueDivideContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + true_divide_contig_row_contig_matrix_broadcast_dispatch_table); +}; + +} // namespace impl + namespace py = pybind11; void init_elementwise_functions(py::module_ m) @@ -518,7 +600,46 @@ void init_elementwise_functions(py::module_ m) // FIXME: // B08: ==== DIVIDE (x1, x2) - // FIXME: + { + impl::populate_true_divide_dispatch_tables(); + using impl::true_divide_contig_dispatch_table; + using impl:: + true_divide_contig_matrix_contig_row_broadcast_dispatch_table; + using impl:: + true_divide_contig_row_contig_matrix_broadcast_dispatch_table; + using impl::true_divide_output_id_table; + using impl::true_divide_strided_dispatch_table; + + auto divide_pyapi = [&](dpctl::tensor::usm_ndarray src1, + dpctl::tensor::usm_ndarray src2, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, true_divide_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + true_divide_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + true_divide_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + 
true_divide_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig row and + // c-contig matrix with broadcasting (may be nullptr) + true_divide_contig_row_contig_matrix_broadcast_dispatch_table); + }; + auto divide_result_type_pyapi = [&](py::dtype dtype1, + py::dtype dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + true_divide_output_id_table); + }; + m.def("_divide", divide_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_divide_result_type", divide_result_type_pyapi, ""); + } // B09: ==== EQUAL (x1, x2) // FIXME: From 8a8411a3e0f949fa232b21640de8f6489597f086 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 15:35:12 -0500 Subject: [PATCH 25/48] Exposed dpctl.tensor.divide(x, y) --- dpctl/tensor/__init__.py | 12 +++++++++++- dpctl/tensor/_elementwise_funcs.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 91c7bfdfac..c3f2b9c6dd 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -91,7 +91,16 @@ from dpctl.tensor._utility_functions import all, any from ._constants import e, inf, nan, newaxis, pi -from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan, sqrt +from ._elementwise_funcs import ( + abs, + add, + cos, + divide, + isfinite, + isinf, + isnan, + sqrt, +) __all__ = [ "Device", @@ -175,4 +184,5 @@ "isnan", "isfinite", "sqrt", + "divide", ] diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 97b873773b..d90ba5a39a 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -31,6 +31,28 @@ "add", ti._add_result_type, ti._add, _add_docstring_ ) +# DIVIDE + +_divide_docstring_ = """ +divide(x1, x2, order='K') + +Calculates the ratio for each element `x1_i` of the input array `x1` with +the
respective element `x2_i` of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have numeric data type. + x2 (usm_ndarray): + Second input array, also expected to have numeric data type. +Returns: + usm_ndarray: + an array containing the result of element-wise division. The data type + of the returned array is determined by the Type Promotion Rules. +""" +divide = BinaryElementwiseFunc( + "divide", ti._divide_result_type, ti._divide, _divide_docstring_ +) + # COS From dc386eb08f5f01680416ef49e7206189778972d0 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 18:29:56 -0500 Subject: [PATCH 26/48] Split tests of elementwise functions into separate files --- dpctl/tests/elementwise/__init__.py | 0 dpctl/tests/elementwise/test_abs.py | 91 +++ dpctl/tests/elementwise/test_add.py | 167 +++++ dpctl/tests/elementwise/test_cos.py | 87 +++ dpctl/tests/elementwise/test_isfinite.py | 74 +++ dpctl/tests/elementwise/test_isinf.py | 76 +++ dpctl/tests/elementwise/test_isnan.py | 74 +++ dpctl/tests/elementwise/test_type_utils.py | 154 +++++ dpctl/tests/elementwise/utils.py | 39 ++ dpctl/tests/test_tensor_elementwise.py | 701 --------------------- 10 files changed, 762 insertions(+), 701 deletions(-) create mode 100644 dpctl/tests/elementwise/__init__.py create mode 100644 dpctl/tests/elementwise/test_abs.py create mode 100644 dpctl/tests/elementwise/test_add.py create mode 100644 dpctl/tests/elementwise/test_cos.py create mode 100644 dpctl/tests/elementwise/test_isfinite.py create mode 100644 dpctl/tests/elementwise/test_isinf.py create mode 100644 dpctl/tests/elementwise/test_isnan.py create mode 100644 dpctl/tests/elementwise/test_type_utils.py create mode 100644 dpctl/tests/elementwise/utils.py delete mode 100644 dpctl/tests/test_tensor_elementwise.py diff --git a/dpctl/tests/elementwise/__init__.py b/dpctl/tests/elementwise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git
a/dpctl/tests/elementwise/test_abs.py b/dpctl/tests/elementwise/test_abs.py new file mode 100644 index 0000000000..275be0d573 --- /dev/null +++ b/dpctl/tests/elementwise/test_abs.py @@ -0,0 +1,91 @@ +import itertools + +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _usm_types + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_abs_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + X = dpt.asarray(0, dtype=arg_dt, sycl_queue=q) + if np.issubdtype(arg_dt, np.complexfloating): + type_map = { + np.dtype("c8"): np.dtype("f4"), + np.dtype("c16"): np.dtype("f8"), + } + assert dpt.abs(X).dtype == type_map[arg_dt] + else: + assert dpt.abs(X).dtype == arg_dt + + +@pytest.mark.parametrize("usm_type", _usm_types) +def test_abs_usm_type(usm_type): + q = get_queue_or_skip() + + arg_dt = np.dtype("i4") + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) + X[..., 0::2] = 1 + X[..., 1::2] = 0 + + Y = dpt.abs(X) + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == X.sycl_queue + assert Y.flags.c_contiguous + + expected_Y = dpt.asnumpy(X) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", _all_dtypes[1:]) +def test_abs_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = 1 + X[..., 1::2] = 0 + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.abs(U, order=ord) + expected_Y = np.ones(Y.shape, dtype=Y.dtype) + expected_Y[..., 1::2] = 0 + expected_Y = np.transpose(expected_Y, perms) + assert np.allclose(dpt.asnumpy(Y), expected_Y) + + 
+@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_abs_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + Xnp = np.random.standard_normal( + size=input_shape + ) + 1j * np.random.standard_normal(size=input_shape) + Xnp = Xnp.astype(arg_dt) + X[...] = Xnp + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.abs(U, order=ord) + expected_Y = np.abs(np.transpose(Xnp[:, ::-1, ::-1, :], perms)) + tol = dpt.finfo(Y.dtype).resolution + np.testing.assert_allclose( + dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol + ) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py new file mode 100644 index 0000000000..81176f4f45 --- /dev/null +++ b/dpctl/tests/elementwise/test_add.py @@ -0,0 +1,167 @@ +import ctypes + +import numpy as np +import pytest + +import dpctl +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _compare_dtypes, _usm_types + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_add_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype) + + r = dpt.add(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_dtype = np.add( + np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) + ).dtype + assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) + assert r.shape == ar1.shape + assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() + assert r.sycl_queue == ar1.sycl_queue + + ar3 = dpt.ones(sz, dtype=op1_dtype) + ar4 = 
dpt.ones(2 * sz, dtype=op2_dtype) + + r = dpt.add(ar3[::-1], ar4[::2]) + assert isinstance(r, dpt.usm_ndarray) + expected_dtype = np.add( + np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) + ).dtype + assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) + assert r.shape == ar3.shape + assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() + + +@pytest.mark.parametrize("op1_usm_type", _usm_types) +@pytest.mark.parametrize("op2_usm_type", _usm_types) +def test_add_usm_type_matrix(op1_usm_type, op2_usm_type): + get_queue_or_skip() + + sz = 128 + ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) + ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) + + r = dpt.add(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_usm_type = dpctl.utils.get_coerced_usm_type( + (op1_usm_type, op2_usm_type) + ) + assert r.usm_type == expected_usm_type + + +def test_add_order(): + get_queue_or_skip() + + ar1 = dpt.ones((20, 20), dtype="i4", order="C") + ar2 = dpt.ones((20, 20), dtype="i4", order="C") + r1 = dpt.add(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.add(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.add(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.add(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones((20, 20), dtype="i4", order="F") + ar2 = dpt.ones((20, 20), dtype="i4", order="F") + r1 = dpt.add(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.add(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.add(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.add(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + r4 = dpt.add(ar1, ar2, order="K") + assert r4.strides == (20, -1) + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, 
::-2].mT + r4 = dpt.add(ar1, ar2, order="K") + assert r4.strides == (-1, 20) + + +def test_add_broadcasting(): + get_queue_or_skip() + + m = dpt.ones((100, 5), dtype="i4") + v = dpt.arange(5, dtype="i4") + + r = dpt.add(m, v) + + assert (dpt.asnumpy(r) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() + + r2 = dpt.add(v, m) + assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() + + +@pytest.mark.parametrize("arr_dt", _all_dtypes) +def test_add_python_scalar(arr_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) + + X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q) + py_zeros = ( + bool(0), + int(0), + float(0), + complex(0), + np.float32(0), + ctypes.c_int(0), + ) + for sc in py_zeros: + R = dpt.add(X, sc) + assert isinstance(R, dpt.usm_ndarray) + R = dpt.add(sc, X) + assert isinstance(R, dpt.usm_ndarray) + + +class MockArray: + def __init__(self, arr): + self.data_ = arr + + @property + def __sycl_usm_array_interface__(self): + return self.data_.__sycl_usm_array_interface__ + + +def test_add_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + b = dpt.ones(10) + c = MockArray(b) + r = dpt.add(a, c) + assert isinstance(r, dpt.usm_ndarray) + + +def test_add_canary_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + + class Canary: + def __init__(self): + pass + + @property + def __sycl_usm_array_interface__(self): + return None + + c = Canary() + with pytest.raises(ValueError): + dpt.add(a, c) diff --git a/dpctl/tests/elementwise/test_cos.py b/dpctl/tests/elementwise/test_cos.py new file mode 100644 index 0000000000..22588aea44 --- /dev/null +++ b/dpctl/tests/elementwise/test_cos.py @@ -0,0 +1,87 @@ +import itertools + +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _map_to_device_dtype + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_cos_out_type(dtype): + q = 
get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype + expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) + assert dpt.cos(X).dtype == expected_dtype + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) +def test_cos_output(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 100 + n_rep = 137 + + Xnp = np.linspace(-np.pi / 4, np.pi / 4, num=n_seq, dtype=dtype) + X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) + + Y = dpt.cos(X) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose( + dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol + ) + + +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) +def test_cos_usm_type(usm_type): + q = get_queue_or_skip() + + arg_dt = np.dtype("f4") + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) + X[..., 0::2] = np.pi / 6 + X[..., 1::2] = np.pi / 3 + + Y = dpt.cos(X) + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == X.sycl_queue + assert Y.flags.c_contiguous + + expected_Y = np.empty(input_shape, dtype=arg_dt) + expected_Y[..., 0::2] = np.cos(np.float32(np.pi / 6)) + expected_Y[..., 1::2] = np.cos(np.float32(np.pi / 3)) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_cos_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = np.pi / 6 + X[..., 1::2] = np.pi / 3 + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) + Y = dpt.cos(U, order=ord) + 
expected_Y = np.cos(dpt.asnumpy(U)) + tol = 8 * max( + dpt.finfo(Y.dtype).resolution, + np.finfo(expected_Y.dtype).resolution, + ) + np.testing.assert_allclose( + dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol + ) diff --git a/dpctl/tests/elementwise/test_isfinite.py b/dpctl/tests/elementwise/test_isfinite.py new file mode 100644 index 0000000000..5cc9699cf8 --- /dev/null +++ b/dpctl/tests/elementwise/test_isfinite.py @@ -0,0 +1,74 @@ +import itertools + +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isfinite_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + assert dpt.isfinite(X).dtype == dpt.bool + + +def test_isfinite_output(): + q = get_queue_or_skip() + + Xnp = np.asarray(np.nan) + X = dpt.asarray(np.nan, sycl_queue=q) + assert dpt.asnumpy(dpt.isfinite(X)) == np.isfinite(Xnp) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isfinite_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.nan, np.nan) + y2 = complex(1, np.nan) + y3 = complex(np.nan, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 12) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isfinite_floats(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) + + +@pytest.mark.parametrize("dtype", _all_dtypes) 
+def test_isfinite_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) + Y = dpt.isfinite(U, order=ord) + expected_Y = np.full(Y.shape, True, dtype=Y.dtype) + assert np.allclose(dpt.asnumpy(Y), expected_Y) diff --git a/dpctl/tests/elementwise/test_isinf.py b/dpctl/tests/elementwise/test_isinf.py new file mode 100644 index 0000000000..3ce1c74f36 --- /dev/null +++ b/dpctl/tests/elementwise/test_isinf.py @@ -0,0 +1,76 @@ +import itertools + +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isinf_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + assert dpt.isinf(X).dtype == dpt.bool + + +def test_isinf_output(): + q = get_queue_or_skip() + + Xnp = np.asarray(np.inf) + X = dpt.asarray(np.inf, sycl_queue=q) + assert dpt.asnumpy(dpt.isinf(X)) == np.isinf(Xnp) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isinf_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.inf, np.inf) + y2 = complex(1, np.inf) + y3 = complex(np.inf, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + y6 = complex(np.inf, np.nan) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5, y6], dtype=dtype), 123) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isinf(Y)), np.isinf(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isinf_floats(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + 
y3 = np.inf + y4 = -np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3, y4], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isinf(Y)), np.isinf(Ynp)) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isinf_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) + Y = dpt.isinf(U, order=ord) + expected_Y = np.full(Y.shape, False, dtype=Y.dtype) + assert np.allclose(dpt.asnumpy(Y), expected_Y) diff --git a/dpctl/tests/elementwise/test_isnan.py b/dpctl/tests/elementwise/test_isnan.py new file mode 100644 index 0000000000..8e983cb2dc --- /dev/null +++ b/dpctl/tests/elementwise/test_isnan.py @@ -0,0 +1,74 @@ +import itertools + +import numpy as np +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isnan_out_type(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + assert dpt.isnan(X).dtype == dpt.bool + + +def test_isnan_output(): + q = get_queue_or_skip() + + Xnp = np.asarray(np.nan) + X = dpt.asarray(np.nan, sycl_queue=q) + assert dpt.asnumpy(dpt.isnan(X)) == np.isnan(Xnp) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isnan_complex(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.nan, np.nan) + y2 = complex(1, np.nan) + y3 = complex(np.nan, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 123) + Y = dpt.asarray(Ynp, 
sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isnan_floats(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_isnan_order(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) + + for ord in ["C", "F", "A", "K"]: + for perms in itertools.permutations(range(4)): + U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) + Y = dpt.isnan(U, order=ord) + expected_Y = np.full(Y.shape, False, dtype=Y.dtype) + assert np.allclose(dpt.asnumpy(Y), expected_Y) diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py new file mode 100644 index 0000000000..cc1166e966 --- /dev/null +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -0,0 +1,154 @@ +import pytest + +import dpctl +import dpctl.tensor as dpt +import dpctl.tensor._type_utils as tu + +from .utils import _all_dtypes, _map_to_device_dtype + + +class MockDevice: + def __init__(self, fp16: bool, fp64: bool): + self.has_aspect_fp16 = fp16 + self.has_aspect_fp64 = fp64 + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_type_utils_map_to_device_type(dtype): + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + dev = MockDevice(fp16, fp64) + dt_in = dpt.dtype(dtype) + dt_out = _map_to_device_dtype(dt_in, dev) + assert isinstance(dt_out, dpt.dtype) + + +def test_type_util_all_data_types(): + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + r = tu._all_data_types(fp16, fp64) + assert 
isinstance(r, list) + # 11: bool + 4 signed + 4 unsigned integral + float32 + complex64 + assert len(r) == 11 + int(fp16) + 2 * int(fp64) + + +def test_type_util_can_cast(): + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + for from_ in _all_dtypes: + for to_ in _all_dtypes: + r = tu._can_cast( + dpt.dtype(from_), dpt.dtype(to_), fp16, fp64 + ) + assert isinstance(r, bool) + + +def test_type_utils_empty_like_orderK(): + try: + a = dpt.empty((10, 10), dtype=dpt.int32, order="F") + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + X = tu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device) + assert X.flags["F"] + + +def test_type_utils_empty_like_orderK_invalid_args(): + with pytest.raises(TypeError): + tu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None) + with pytest.raises(TypeError): + tu._empty_like_pair_orderK( + [1, 2, 3], + ( + 1, + 2, + 3, + ), + dpt.int32, + "device", + None, + ) + try: + a = dpt.empty(10, dtype=dpt.int32) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + with pytest.raises(TypeError): + tu._empty_like_pair_orderK( + a, + ( + 1, + 2, + 3, + ), + dpt.int32, + "device", + None, + ) + + +def test_type_utils_find_buf_dtype(): + def _denier_fn(dt): + return False + + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + dev = MockDevice(fp16, fp64) + arg_dt = dpt.float64 + r = tu._find_buf_dtype(arg_dt, _denier_fn, dev) + assert r == ( + None, + None, + ) + + +def test_type_utils_find_buf_dtype2(): + def _denier_fn(dt1, dt2): + return False + + for fp64 in [ + True, + False, + ]: + for fp16 in [True, False]: + dev = MockDevice(fp16, fp64) + arg1_dt = dpt.float64 + arg2_dt = dpt.complex64 + r = tu._find_buf_dtype2(arg1_dt, arg2_dt, _denier_fn, dev) + assert r == ( + None, + None, + None, + ) + + +def test_unary_func_arg_validation(): + with pytest.raises(TypeError): + dpt.abs([1, 2, 3]) + try: + a = dpt.arange(8) + except
dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + dpt.abs(a, order="invalid") + + +def test_binary_func_arg_validation(): + with pytest.raises(dpctl.utils.ExecutionPlacementError): + dpt.add([1, 2, 3], 1) + try: + a = dpt.arange(8) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + with pytest.raises(ValueError): + dpt.add(a, Ellipsis) + dpt.add(a, a, order="invalid") diff --git a/dpctl/tests/elementwise/utils.py b/dpctl/tests/elementwise/utils.py new file mode 100644 index 0000000000..b4e71f14ad --- /dev/null +++ b/dpctl/tests/elementwise/utils.py @@ -0,0 +1,39 @@ +import dpctl +import dpctl.tensor._type_utils as tu + +_all_dtypes = [ + "b1", + "i1", + "u1", + "i2", + "u2", + "i4", + "u4", + "i8", + "u8", + "f2", + "f4", + "f8", + "c8", + "c16", +] +_usm_types = ["device", "shared", "host"] + + +def _map_to_device_dtype(dt, dev): + return tu._to_device_supported_dtype(dt, dev) + + +def _compare_dtypes(dt, ref_dt, sycl_queue=None): + assert isinstance(sycl_queue, dpctl.SyclQueue) + dev = sycl_queue.sycl_device + expected_dt = _map_to_device_dtype(ref_dt, dev) + return dt == expected_dt + + +__all__ = [ + "_all_dtypes", + "_usm_types", + "_map_to_device_dtype", + "_compare_dtypes", +] diff --git a/dpctl/tests/test_tensor_elementwise.py b/dpctl/tests/test_tensor_elementwise.py deleted file mode 100644 index 0e84b4db73..0000000000 --- a/dpctl/tests/test_tensor_elementwise.py +++ /dev/null @@ -1,701 +0,0 @@ -import ctypes -import itertools - -import numpy as np -import pytest - -import dpctl -import dpctl.tensor as dpt -import dpctl.tensor._type_utils as tu -import dpctl.utils -from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported - -_all_dtypes = [ - "b1", - "i1", - "u1", - "i2", - "u2", - "i4", - "u4", - "i8", - "u8", - "f2", - "f4", - "f8", - "c8", - "c16", -] -_usm_types = ["device", "shared", "host"] - - -class MockDevice: - def __init__(self, fp16: bool, fp64: bool): - 
self.has_aspect_fp16 = fp16 - self.has_aspect_fp64 = fp64 - - -def _map_to_device_dtype(dt, dev): - return tu._to_device_supported_dtype(dt, dev) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_type_utils_map_to_device_type(dtype): - for fp64 in [ - True, - False, - ]: - for fp16 in [True, False]: - dev = MockDevice(fp16, fp64) - dt_in = dpt.dtype(dtype) - dt_out = _map_to_device_dtype(dt_in, dev) - assert isinstance(dt_out, dpt.dtype) - - -def test_type_util_all_data_types(): - for fp64 in [ - True, - False, - ]: - for fp16 in [True, False]: - r = tu._all_data_types(fp16, fp64) - assert isinstance(r, list) - # 11: bool + 4 signed + 4 unsigned inegral + float32 + complex64 - assert len(r) == 11 + int(fp16) + 2 * int(fp64) - - -def test_type_util_can_cast(): - for fp64 in [ - True, - False, - ]: - for fp16 in [True, False]: - for from_ in _all_dtypes: - for to_ in _all_dtypes: - r = tu._can_cast( - dpt.dtype(from_), dpt.dtype(to_), fp16, fp64 - ) - assert isinstance(r, bool) - - -def test_type_utils_empty_like_orderK(): - try: - a = dpt.empty((10, 10), dtype=dpt.int32, order="F") - except dpctl.SyclDeviceCreationError: - pytest.skip("No SYCL devices available") - X = tu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device) - assert X.flags["F"] - - -def test_type_utils_empty_like_orderK_invalid_args(): - with pytest.raises(TypeError): - tu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None) - with pytest.raises(TypeError): - tu._empty_like_pair_orderK( - [1, 2, 3], - ( - 1, - 2, - 3, - ), - dpt.int32, - "device", - None, - ) - try: - a = dpt.empty(10, dtype=dpt.int32) - except dpctl.SyclDeviceCreationError: - pytest.skip("No SYCL devices available") - with pytest.raises(TypeError): - tu._empty_like_pair_orderK( - a, - ( - 1, - 2, - 3, - ), - dpt.int32, - "device", - None, - ) - - -def test_type_utils_find_buf_dtype(): - def _denier_fn(dt): - return False - - for fp64 in [ - True, - False, - ]: - for fp16 in [True, False]: - dev = MockDevice(fp16, 
fp64) - arg_dt = dpt.float64 - r = tu._find_buf_dtype(arg_dt, _denier_fn, dev) - assert r == ( - None, - None, - ) - - -def test_type_utils_find_buf_dtype2(): - def _denier_fn(dt1, dt2): - return False - - for fp64 in [ - True, - False, - ]: - for fp16 in [True, False]: - dev = MockDevice(fp16, fp64) - arg1_dt = dpt.float64 - arg2_dt = dpt.complex64 - r = tu._find_buf_dtype2(arg1_dt, arg2_dt, _denier_fn, dev) - assert r == ( - None, - None, - None, - ) - - -def test_unary_func_arg_validation(): - with pytest.raises(TypeError): - dpt.abs([1, 2, 3]) - try: - a = dpt.arange(8) - except dpctl.SyclDeviceCreationError: - pytest.skip("No SYCL devices available") - dpt.abs(a, order="invalid") - - -def test_binary_func_arg_validation(): - with pytest.raises(dpctl.utils.ExecutionPlacementError): - dpt.add([1, 2, 3], 1) - try: - a = dpt.arange(8) - except dpctl.SyclDeviceCreationError: - pytest.skip("No SYCL devices available") - with pytest.raises(ValueError): - dpt.add(a, Ellipsis) - dpt.add(a, a, order="invalid") - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_abs_out_type(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - X = dpt.asarray(0, dtype=arg_dt, sycl_queue=q) - if np.issubdtype(arg_dt, np.complexfloating): - type_map = { - np.dtype("c8"): np.dtype("f4"), - np.dtype("c16"): np.dtype("f8"), - } - assert dpt.abs(X).dtype == type_map[arg_dt] - else: - assert dpt.abs(X).dtype == arg_dt - - -@pytest.mark.parametrize("usm_type", _usm_types) -def test_abs_usm_type(usm_type): - q = get_queue_or_skip() - - arg_dt = np.dtype("i4") - input_shape = (10, 10, 10, 10) - X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) - X[..., 0::2] = 1 - X[..., 1::2] = 0 - - Y = dpt.abs(X) - assert Y.usm_type == X.usm_type - assert Y.sycl_queue == X.sycl_queue - assert Y.flags.c_contiguous - - expected_Y = dpt.asnumpy(X) - assert np.allclose(dpt.asnumpy(Y), expected_Y) - - 
-@pytest.mark.parametrize("dtype", _all_dtypes[1:]) -def test_abs_order(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) - X[..., 0::2] = 1 - X[..., 1::2] = 0 - - for ord in ["C", "F", "A", "K"]: - for perms in itertools.permutations(range(4)): - U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) - Y = dpt.abs(U, order=ord) - expected_Y = np.ones(Y.shape, dtype=Y.dtype) - expected_Y[..., 1::2] = 0 - expected_Y = np.transpose(expected_Y, perms) - assert np.allclose(dpt.asnumpy(Y), expected_Y) - - -@pytest.mark.parametrize("dtype", ["c8", "c16"]) -def test_abs_complex(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) - Xnp = np.random.standard_normal( - size=input_shape - ) + 1j * np.random.standard_normal(size=input_shape) - Xnp = Xnp.astype(arg_dt) - X[...] 
= Xnp - - for ord in ["C", "F", "A", "K"]: - for perms in itertools.permutations(range(4)): - U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) - Y = dpt.abs(U, order=ord) - expected_Y = np.abs(np.transpose(Xnp[:, ::-1, ::-1, :], perms)) - tol = dpt.finfo(Y.dtype).resolution - np.testing.assert_allclose( - dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol - ) - - -def _compare_dtypes(dt, ref_dt, sycl_queue=None): - assert isinstance(sycl_queue, dpctl.SyclQueue) - dev = sycl_queue.sycl_device - expected_dt = _map_to_device_dtype(ref_dt, dev) - return dt == expected_dt - - -@pytest.mark.parametrize("op1_dtype", _all_dtypes) -@pytest.mark.parametrize("op2_dtype", _all_dtypes) -def test_add_dtype_matrix(op1_dtype, op2_dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(op1_dtype, q) - skip_if_dtype_not_supported(op2_dtype, q) - - sz = 127 - ar1 = dpt.ones(sz, dtype=op1_dtype) - ar2 = dpt.ones_like(ar1, dtype=op2_dtype) - - r = dpt.add(ar1, ar2) - assert isinstance(r, dpt.usm_ndarray) - expected_dtype = np.add( - np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) - ).dtype - assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) - assert r.shape == ar1.shape - assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() - assert r.sycl_queue == ar1.sycl_queue - - ar3 = dpt.ones(sz, dtype=op1_dtype) - ar4 = dpt.ones(2 * sz, dtype=op2_dtype) - - r = dpt.add(ar3[::-1], ar4[::2]) - assert isinstance(r, dpt.usm_ndarray) - expected_dtype = np.add( - np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) - ).dtype - assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) - assert r.shape == ar3.shape - assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() - - -@pytest.mark.parametrize("op1_usm_type", _usm_types) -@pytest.mark.parametrize("op2_usm_type", _usm_types) -def test_add_usm_type_matrix(op1_usm_type, op2_usm_type): - get_queue_or_skip() - - sz = 128 - ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) - ar2 = 
dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) - - r = dpt.add(ar1, ar2) - assert isinstance(r, dpt.usm_ndarray) - expected_usm_type = dpctl.utils.get_coerced_usm_type( - (op1_usm_type, op2_usm_type) - ) - assert r.usm_type == expected_usm_type - - -def test_add_order(): - get_queue_or_skip() - - ar1 = dpt.ones((20, 20), dtype="i4", order="C") - ar2 = dpt.ones((20, 20), dtype="i4", order="C") - r1 = dpt.add(ar1, ar2, order="C") - assert r1.flags.c_contiguous - r2 = dpt.add(ar1, ar2, order="F") - assert r2.flags.f_contiguous - r3 = dpt.add(ar1, ar2, order="A") - assert r3.flags.c_contiguous - r4 = dpt.add(ar1, ar2, order="K") - assert r4.flags.c_contiguous - - ar1 = dpt.ones((20, 20), dtype="i4", order="F") - ar2 = dpt.ones((20, 20), dtype="i4", order="F") - r1 = dpt.add(ar1, ar2, order="C") - assert r1.flags.c_contiguous - r2 = dpt.add(ar1, ar2, order="F") - assert r2.flags.f_contiguous - r3 = dpt.add(ar1, ar2, order="A") - assert r3.flags.f_contiguous - r4 = dpt.add(ar1, ar2, order="K") - assert r4.flags.f_contiguous - - ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] - ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] - r4 = dpt.add(ar1, ar2, order="K") - assert r4.strides == (20, -1) - - ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT - ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT - r4 = dpt.add(ar1, ar2, order="K") - assert r4.strides == (-1, 20) - - -def test_add_broadcasting(): - get_queue_or_skip() - - m = dpt.ones((100, 5), dtype="i4") - v = dpt.arange(5, dtype="i4") - - r = dpt.add(m, v) - - assert (dpt.asnumpy(r) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() - - r2 = dpt.add(v, m) - assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() - - -@pytest.mark.parametrize("arr_dt", _all_dtypes) -def test_add_python_scalar(arr_dt): - q = get_queue_or_skip() - skip_if_dtype_not_supported(arr_dt, q) - - X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q) - py_zeros = ( - 
bool(0), - int(0), - float(0), - complex(0), - np.float32(0), - ctypes.c_int(0), - ) - for sc in py_zeros: - R = dpt.add(X, sc) - assert isinstance(R, dpt.usm_ndarray) - R = dpt.add(sc, X) - assert isinstance(R, dpt.usm_ndarray) - - -class MockArray: - def __init__(self, arr): - self.data_ = arr - - @property - def __sycl_usm_array_interface__(self): - return self.data_.__sycl_usm_array_interface__ - - -def test_add_mock_array(): - get_queue_or_skip() - a = dpt.arange(10) - b = dpt.ones(10) - c = MockArray(b) - r = dpt.add(a, c) - assert isinstance(r, dpt.usm_ndarray) - - -def test_add_canary_mock_array(): - get_queue_or_skip() - a = dpt.arange(10) - - class Canary: - def __init__(self): - pass - - @property - def __sycl_usm_array_interface__(self): - return None - - c = Canary() - with pytest.raises(ValueError): - dpt.add(a, c) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_cos_out_type(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - X = dpt.asarray(0, dtype=dtype, sycl_queue=q) - expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype - expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) - assert dpt.cos(X).dtype == expected_dtype - - -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) -def test_cos_output(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - n_seq = 100 - n_rep = 137 - - Xnp = np.linspace(-np.pi / 4, np.pi / 4, num=n_seq, dtype=dtype) - X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) - - Y = dpt.cos(X) - tol = 8 * dpt.finfo(Y.dtype).resolution - - np.testing.assert_allclose( - dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol - ) - - -@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) -def test_cos_usm_type(usm_type): - q = get_queue_or_skip() - - arg_dt = np.dtype("f4") - input_shape = (10, 10, 10, 10) - X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q) - X[..., 0::2] = 
np.pi / 6 - X[..., 1::2] = np.pi / 3 - - Y = dpt.cos(X) - assert Y.usm_type == X.usm_type - assert Y.sycl_queue == X.sycl_queue - assert Y.flags.c_contiguous - - expected_Y = np.empty(input_shape, dtype=arg_dt) - expected_Y[..., 0::2] = np.cos(np.float32(np.pi / 6)) - expected_Y[..., 1::2] = np.cos(np.float32(np.pi / 3)) - tol = 8 * dpt.finfo(Y.dtype).resolution - - np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_cos_order(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) - X[..., 0::2] = np.pi / 6 - X[..., 1::2] = np.pi / 3 - - for ord in ["C", "F", "A", "K"]: - for perms in itertools.permutations(range(4)): - U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms) - Y = dpt.cos(U, order=ord) - expected_Y = np.cos(dpt.asnumpy(U)) - tol = 8 * max( - dpt.finfo(Y.dtype).resolution, - np.finfo(expected_Y.dtype).resolution, - ) - np.testing.assert_allclose( - dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol - ) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_isnan_out_type(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - X = dpt.asarray(0, dtype=dtype, sycl_queue=q) - assert dpt.isnan(X).dtype == dpt.bool - - -def test_isnan_output(): - q = get_queue_or_skip() - - Xnp = np.asarray(np.nan) - X = dpt.asarray(np.nan, sycl_queue=q) - assert dpt.asnumpy(dpt.isnan(X)) == np.isnan(Xnp) - - -@pytest.mark.parametrize("dtype", ["c8", "c16"]) -def test_isnan_complex(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = complex(np.nan, np.nan) - y2 = complex(1, np.nan) - y3 = complex(np.nan, 1) - y4 = complex(2, 1) - y5 = complex(np.inf, 1) - - Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 123) - Y = dpt.asarray(Ynp, sycl_queue=q) - assert 
np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) - - -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) -def test_isnan_floats(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = np.nan - y2 = 1 - y3 = np.inf - - for mult in [123, 137, 255, 271, 272]: - Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) - Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_isnan_order(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) - - for ord in ["C", "F", "A", "K"]: - for perms in itertools.permutations(range(4)): - U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) - Y = dpt.isnan(U, order=ord) - expected_Y = np.full(Y.shape, False, dtype=Y.dtype) - assert np.allclose(dpt.asnumpy(Y), expected_Y) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_isfinite_out_type(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - X = dpt.asarray(0, dtype=dtype, sycl_queue=q) - assert dpt.isfinite(X).dtype == dpt.bool - - -def test_isfinite_output(): - q = get_queue_or_skip() - - Xnp = np.asarray(np.nan) - X = dpt.asarray(np.nan, sycl_queue=q) - assert dpt.asnumpy(dpt.isfinite(X)) == np.isfinite(Xnp) - - -@pytest.mark.parametrize("dtype", ["c8", "c16"]) -def test_isfinite_complex(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = complex(np.nan, np.nan) - y2 = complex(1, np.nan) - y3 = complex(np.nan, 1) - y4 = complex(2, 1) - y5 = complex(np.inf, 1) - - Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 12) - Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) - - -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) -def test_isfinite_floats(dtype): - q = 
get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = np.nan - y2 = 1 - y3 = np.inf - - for mult in [123, 137, 255, 271, 272]: - Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) - Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_isfinite_order(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) - - for ord in ["C", "F", "A", "K"]: - for perms in itertools.permutations(range(4)): - U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) - Y = dpt.isfinite(U, order=ord) - expected_Y = np.full(Y.shape, True, dtype=Y.dtype) - assert np.allclose(dpt.asnumpy(Y), expected_Y) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_isinf_out_type(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - X = dpt.asarray(0, dtype=dtype, sycl_queue=q) - assert dpt.isinf(X).dtype == dpt.bool - - -def test_isinf_output(): - q = get_queue_or_skip() - - Xnp = np.asarray(np.inf) - X = dpt.asarray(np.inf, sycl_queue=q) - assert dpt.asnumpy(dpt.isinf(X)) == np.isinf(Xnp) - - -@pytest.mark.parametrize("dtype", ["c8", "c16"]) -def test_isinf_complex(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = complex(np.inf, np.inf) - y2 = complex(1, np.inf) - y3 = complex(np.inf, 1) - y4 = complex(2, 1) - y5 = complex(np.inf, 1) - y6 = complex(np.inf, np.nan) - - Ynp = np.repeat(np.array([y1, y2, y3, y4, y5, y6], dtype=dtype), 123) - Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isinf(Y)), np.isinf(Ynp)) - - -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) -def test_isinf_floats(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = np.nan - y2 = 1 - y3 = np.inf - y4 = -np.inf - - for mult in [123, 
137, 255, 271, 272]: - Ynp = np.repeat(np.array([y1, y2, y3, y4], dtype=dtype), mult) - Y = dpt.asarray(Ynp, sycl_queue=q) - assert np.array_equal(dpt.asnumpy(dpt.isinf(Y)), np.isinf(Ynp)) - - -@pytest.mark.parametrize("dtype", _all_dtypes) -def test_isinf_order(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.ones(input_shape, dtype=arg_dt, sycl_queue=q) - - for ord in ["C", "F", "A", "K"]: - for perms in itertools.permutations(range(4)): - U = dpt.permute_dims(X[::2, ::-1, ::-1, ::5], perms) - Y = dpt.isinf(U, order=ord) - expected_Y = np.full(Y.shape, False, dtype=Y.dtype) - assert np.allclose(dpt.asnumpy(Y), expected_Y) From 85d468c5cd71ce98dde500d06fd8033fc89f17d0 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Wed, 17 May 2023 19:33:48 -0500 Subject: [PATCH 27/48] Migrated tests for sqrt into elementwise folder --- .../test_sqrt.py} | 23 +------------------ 1 file changed, 1 insertion(+), 22 deletions(-) rename dpctl/tests/{test_tensor_sqrt.py => elementwise/test_sqrt.py} (90%) diff --git a/dpctl/tests/test_tensor_sqrt.py b/dpctl/tests/elementwise/test_sqrt.py similarity index 90% rename from dpctl/tests/test_tensor_sqrt.py rename to dpctl/tests/elementwise/test_sqrt.py index 2f924027dc..e957807c5c 100644 --- a/dpctl/tests/test_tensor_sqrt.py +++ b/dpctl/tests/elementwise/test_sqrt.py @@ -5,30 +5,9 @@ from numpy.testing import assert_equal import dpctl.tensor as dpt -import dpctl.tensor._type_utils as tu from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported -_all_dtypes = [ - "b1", - "i1", - "u1", - "i2", - "u2", - "i4", - "u4", - "i8", - "u8", - "f2", - "f4", - "f8", - "c8", - "c16", -] -_usm_types = ["device", "shared", "host"] - - -def _map_to_device_dtype(dt, dev): - return tu._to_device_supported_dtype(dt, dev) +from .utils import _all_dtypes, _map_to_device_dtype, _usm_types @pytest.mark.parametrize("dtype", _all_dtypes) From 
7453bf7ad1965ffa6266adbb5bb75914fb32befb Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 06:30:55 -0500 Subject: [PATCH 28/48] Corrected order="K" handling for binary function in some cases When both inputs must be promoted, e.g. `divide(boolean, integral)`, order=K can create temporary buffers using empty_likeK, and then the result could be created using _empty_pair_likeK utilities. This resolves the test failure for `divide`. --- dpctl/tensor/_elementwise_common.py | 33 ++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index f8cb197dc7..cc2f1ff679 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -470,23 +470,36 @@ def __call__(self, o1, o2, order="K"): if order in ["K", "A"]: if src1.flags.f_contiguous and src2.flags.f_contiguous: order = "F" - else: + elif src1.flags.c_contiguous and src2.flags.c_contiguous: order = "C" - buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order) + else: + order = "C" if order == "A" else "K" + if order == "K": + buf1 = _empty_like_orderK(src1, buf1_dt) + else: + buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order) ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=src1, dst=buf1, sycl_queue=exec_q ) - buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order) + if order == "K": + buf2 = _empty_like_orderK(src2, buf2_dt) + else: + buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order) ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=src2, dst=buf2, sycl_queue=exec_q ) - r = dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, - order=order, - ) + if order == "K": + r = _empty_like_pair_orderK( + buf1, buf2, res_dt, res_usm_type, exec_q + ) + else: + r = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) buf1 = 
dpt.broadcast_to(buf1, res_shape) buf2 = dpt.broadcast_to(buf2, res_shape) ht_, _ = self.binary_fn_( From 89f8fe94ac5cb0f9b634fbe5f11f00176ef877e6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 07:59:16 -0500 Subject: [PATCH 29/48] Tweak to find_buf_dtype2 If both input types must be promoted outside of their kind, use default device data type of the kind of the result array data type. E.g. divide( int8_array, bool_array ) must return float32/float64 depending on the device capabilities, not float16. --- dpctl/tensor/_type_utils.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index ac82c67722..d33f2eba06 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -17,6 +17,7 @@ import builtins import dpctl.tensor as dpt +import dpctl.tensor._tensor_impl as ti def _all_data_types(_fp16, _fp64): @@ -237,6 +238,20 @@ def _find_buf_dtype(arg_dtype, query_fn, sycl_dev): return None, None +def _get_device_default_dtype(dt_kind, sycl_dev): + if dt_kind == "b": + return dpt.dtype(ti.default_device_bool_type(sycl_dev)) + elif dt_kind == "i": + return dpt.dtype(ti.default_device_int_type(sycl_dev)) + elif dt_kind == "u": + return dpt.dtype(ti.default_device_int_type(sycl_dev).upper()) + elif dt_kind == "f": + return dpt.dtype(ti.default_device_fp_type(sycl_dev)) + elif dt_kind == "c": + return dpt.dtype(ti.default_device_complex_type(sycl_dev)) + raise RuntimeError + + def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev): res_dt = query_fn(arg1_dtype, arg2_dtype) if res_dt: @@ -254,7 +269,24 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev): if res_dt: ret_buf1_dt = None if buf1_dt == arg1_dtype else buf1_dt ret_buf2_dt = None if buf2_dt == arg2_dtype else buf2_dt - return ret_buf1_dt, ret_buf2_dt, res_dt + if ret_buf1_dt is None or ret_buf2_dt is None: + return ret_buf1_dt, ret_buf2_dt, 
res_dt + else: + # both are being promoted, if the kind of result is + # different than the kind of original input dtypes, + # we must use default dtype for the resulting kind. + if (res_dt.kind != arg1_dtype.kind) and ( + res_dt.kind != arg2_dtype.kind + ): + default_dt = _get_device_default_dtype( + res_dt.kind, sycl_dev + ) + if res_dt == default_dt: + return ret_buf1_dt, ret_buf2_dt, res_dt + else: + continue + else: + return ret_buf1_dt, ret_buf2_dt, res_dt return None, None, None From 2be09e5bb92a349c6d9104a5da0b8c4465fc9184 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 08:01:31 -0500 Subject: [PATCH 30/48] Adding tests for dpt.divide --- dpctl/tests/elementwise/test_divide.py | 173 +++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 dpctl/tests/elementwise/test_divide.py diff --git a/dpctl/tests/elementwise/test_divide.py b/dpctl/tests/elementwise/test_divide.py new file mode 100644 index 0000000000..168803f945 --- /dev/null +++ b/dpctl/tests/elementwise/test_divide.py @@ -0,0 +1,173 @@ +import ctypes + +import numpy as np +import pytest + +import dpctl +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _compare_dtypes, _usm_types + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_divide_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype) + + r = dpt.divide(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected = np.divide( + np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + assert r.shape == ar1.shape + assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() + assert 
r.sycl_queue == ar1.sycl_queue + + ar3 = dpt.ones(sz, dtype=op1_dtype) + ar4 = dpt.ones(2 * sz, dtype=op2_dtype) + + r = dpt.divide(ar3[::-1], ar4[::2]) + assert isinstance(r, dpt.usm_ndarray) + expected = np.divide( + np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ) + assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q) + assert r.shape == ar3.shape + assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() + + +@pytest.mark.parametrize("op1_usm_type", _usm_types) +@pytest.mark.parametrize("op2_usm_type", _usm_types) +def test_divide_usm_type_matrix(op1_usm_type, op2_usm_type): + get_queue_or_skip() + + sz = 128 + ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) + ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) + + r = dpt.divide(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_usm_type = dpctl.utils.get_coerced_usm_type( + (op1_usm_type, op2_usm_type) + ) + assert r.usm_type == expected_usm_type + + +def test_divide_order(): + get_queue_or_skip() + + ar1 = dpt.ones((20, 20), dtype="i4", order="C") + ar2 = dpt.ones((20, 20), dtype="i4", order="C") + r1 = dpt.divide(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.divide(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.divide(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.divide(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones((20, 20), dtype="i4", order="F") + ar2 = dpt.ones((20, 20), dtype="i4", order="F") + r1 = dpt.divide(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.divide(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.divide(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.divide(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + r4 = dpt.divide(ar1, ar2, order="K") + assert r4.strides == (20, -1) + + ar1 = dpt.ones((40, 
40), dtype="i4", order="C")[:20, ::-2].mT + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + r4 = dpt.divide(ar1, ar2, order="K") + assert r4.strides == (-1, 20) + + +def test_divide_broadcasting(): + get_queue_or_skip() + + m = dpt.ones((100, 5), dtype="i4") + v = dpt.arange(1, 6, dtype="i4") + + r = dpt.divide(m, v) + + expected = np.divide( + np.ones((100, 5), dtype="i4"), np.arange(1, 6, dtype="i4") + ) + assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all() + + r2 = dpt.divide(v, m) + expected2 = np.divide( + np.arange(1, 6, dtype="i4"), np.ones((100, 5), dtype="i4") + ) + assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all() + + +@pytest.mark.parametrize("arr_dt", _all_dtypes) +def test_divide_python_scalar(arr_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) + + X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q) + py_ones = ( + bool(1), + int(1), + float(1), + complex(1), + np.float32(1), + ctypes.c_int(1), + ) + for sc in py_ones: + R = dpt.divide(X, sc) + assert isinstance(R, dpt.usm_ndarray) + R = dpt.divide(sc, X) + assert isinstance(R, dpt.usm_ndarray) + + +class MockArray: + def __init__(self, arr): + self.data_ = arr + + @property + def __sycl_usm_array_interface__(self): + return self.data_.__sycl_usm_array_interface__ + + +def test_divide_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + b = dpt.ones(10) + c = MockArray(b) + r = dpt.divide(a, c) + assert isinstance(r, dpt.usm_ndarray) + + +def test_divide_canary_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + + class Canary: + def __init__(self): + pass + + @property + def __sycl_usm_array_interface__(self): + return None + + c = Canary() + with pytest.raises(ValueError): + dpt.divide(a, c) From 7a1256511643becd0b8b68aba1ba6ecfaa6d2195 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 08:40:01 -0500 Subject: [PATCH 31/48] Add tests for _get_default_device_type utility --- 
dpctl/tests/elementwise/test_type_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index cc1166e966..65bc8ba895 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -112,6 +112,17 @@ def _denier_fn(dt): ) +def test_type_utils_get_device_default_type(): + try: + dev = dpctl.SyclDevice() + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + for k in ["b", "i", "u", "f", "c"]: + dt = tu._get_device_default_dtype(k, dev) + assert isinstance(dt, dpt.dtype) + assert dt.kind == k + + def test_type_utils_find_buf_dtype2(): def _denier_fn(dt1, dt2): return False From c3a3f0163785ba82242b5e6ec7e1c93d35a06e58 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 14:13:05 -0500 Subject: [PATCH 32/48] Removed superfluous semicolon --- .../libtensor/include/kernels/elementwise_functions/abs.hpp | 1 - .../libtensor/include/kernels/elementwise_functions/cos.hpp | 1 - .../libtensor/include/kernels/elementwise_functions/isfinite.hpp | 1 - .../libtensor/include/kernels/elementwise_functions/isnan.hpp | 1 - .../libtensor/include/kernels/elementwise_functions/sqrt.hpp | 1 - 5 files changed, 5 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 266f0bed8c..b233674f67 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -143,7 +143,6 @@ template struct AbsTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename AbsOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 
2dbfde12aa..7121c594ec 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -130,7 +130,6 @@ template struct CosTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename CosOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index 258993c6a5..a157f42376 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -145,7 +145,6 @@ template struct IsFiniteTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename IsFiniteOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index e0bdae1ca6..d592e884eb 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -247,7 +247,6 @@ template struct IsNanTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename IsNanOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index 719670cea0..536d6446c9 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -131,7 +131,6 @@ template struct SqrtTypeMapFactory std::enable_if_t::value, int> get() { using rT = typename SqrtOutputType::value_type; - ; return td_ns::GetTypeid{}.get(); } }; From 
1aa37dbdeb8952a7ec0c3c06b801f8a1bb972966 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 14:13:21 -0500 Subject: [PATCH 33/48] Changes to NullPtrTable, NullPtrVector to remove unneeded methods. Also added const qualifier to enable deployment of NullPtrTable to work. --- .../libtensor/include/utils/type_dispatch.hpp | 63 ++----------------- 1 file changed, 4 insertions(+), 59 deletions(-) diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp index 378173bb4b..25f06acc75 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp @@ -337,44 +337,16 @@ template struct GetTypeid template struct NullPtrVector { - using iterator_category = std::forward_iterator_tag; - using different_type = std::ptrdiff_t; using value_type = FunPtrT; - using pointer = value_type *; - using reference = value_type &; + using const_reference = value_type const &; NullPtrVector() : val(nullptr) {} - reference operator*() + const_reference operator[](int) const { return val; } - reference operator[](int) - { - return val; - } - - NullPtrVector &operator++() - { - return *this; - } - NullPtrVector operator++(int) - { - return *this; - } - - friend bool operator==(const NullPtrVector &a, - const NullPtrVector &b) - { - return true; - } - friend bool operator!=(const NullPtrVector &a, - const NullPtrVector &b) - { - return false; - } - private: value_type val; }; @@ -382,42 +354,15 @@ template struct NullPtrVector /*! 
@brief Class to generate table of null function pointers */ template struct NullPtrTable { - using iterator_category = std::forward_iterator_tag; - using different_type = std::ptrdiff_t; using value_type = NullPtrVector; - using pointer = value_type *; - using reference = value_type &; + using const_reference = value_type const &; NullPtrTable() : val() {} - reference operator*() + const_reference operator[](int) const { return val; } - reference operator[](int) - { - return val; - } - - NullPtrTable &operator++() - { - return *this; - } - NullPtrTable operator++(int) - { - return *this; - } - - friend bool operator==(const NullPtrTable &a, - const NullPtrTable &b) - { - return true; - } - friend bool operator!=(const NullPtrTable &a, - const NullPtrTable &b) - { - return false; - } private: value_type val; From 61100e4d29e9c10801b079ed8d9b4dfacd088e01 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 14:15:04 -0500 Subject: [PATCH 34/48] Also try _get_device_default_type with invalid kind argument --- dpctl/tests/elementwise/test_type_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index 65bc8ba895..415b70e1f8 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -113,6 +113,8 @@ def _denier_fn(dt): def test_type_utils_get_device_default_type(): + with pytest.raises(RuntimeError): + tu._get_device_default_dtype("-", MockDevice(True, True)) try: dev = dpctl.SyclDevice() except dpctl.SyclDeviceCreationError: From cde757c0b1290cc69da7b8e1e3efca8b71b58041 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 14:16:02 -0500 Subject: [PATCH 35/48] Added implementation of dpctl.tensor.equal ``` In [4]: if dpt.all(dpt.equal( dpt.arange(30), dpt.arange(50)[:30])): print("Equal") Equal ``` --- dpctl/tensor/__init__.py | 2 + dpctl/tensor/_elementwise_funcs.py | 205 +++++- 
.../kernels/elementwise_functions/equal.hpp | 455 ++++++++++++ .../source/elementwise_functions.cpp | 686 ++++++++++++++---- 4 files changed, 1178 insertions(+), 170 deletions(-) create mode 100644 dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index c3f2b9c6dd..be9426a834 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -96,6 +96,7 @@ add, cos, divide, + equal, isfinite, isinf, isnan, @@ -185,4 +186,5 @@ "isfinite", "sqrt", "divide", + "equal", ] diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index d90ba5a39a..3068486827 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -2,14 +2,20 @@ from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc -# ABS +# U01: ==== ABS (x) _abs_docstring_ = """ Calculate the absolute value element-wise. """ abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_) -# ADD +# U02: ==== ACOS (x) +# FIXME: implement U02 + +# U03: ===== ACOSH (x) +# FIXME: implement U03 + +# B01: ===== ADD (x1, x2) _add_docstring_ = """ add(x1, x2, order='K') @@ -31,8 +37,58 @@ "add", ti._add_result_type, ti._add, _add_docstring_ ) -# DIVIDE +# U04: ===== ASIN (x) +# FIXME: implement U04 + +# U05: ===== ASINH (x) +# FIXME: implement U05 + +# U06: ===== ATAN (x) +# FIXME: implement U06 + +# B02: ===== ATAN2 (x1, x2) +# FIXME: implement B02 + +# U07: ===== ATANH (x) +# FIXME: implement U07 + +# B03: ===== BITWISE_AND (x1, x2) +# FIXME: implement B03 + +# B04: ===== BITWISE_LEFT_SHIFT (x1, x2) +# FIXME: implement B04 + +# U08: ===== BITWISE_INVERT (x) +# FIXME: implement U08 + +# B05: ===== BITWISE_OR (x1, x2) +# FIXME: implement B05 + +# B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) +# FIXME: implement B06 + +# B07: ===== BITWISE_XOR (x1, x2) +# FIXME: implement B07 + +# U09: ==== CEIL (x) +# FIXME: implement U09 + +# U10: ==== CONJ
(x) +# FIXME: implement U10 + +# U11: ==== COS (x) +_cos_docstring = """ +cos(x, order='K') + +Computes cosine for each element `x_i` for input array `x`. +""" + +cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring) + +# U12: ==== COSH (x) +# FIXME: implement U12 +# B08: ==== DIVIDE (x1, x2) _divide_docstring_ = """ divide(x1, x2, order='K') @@ -49,23 +105,56 @@ an array containing the result of element-wise division. The data type of the returned array is determined by the Type Promotion Rules. """ + divide = BinaryElementwiseFunc( "divide", ti._divide_result_type, ti._divide, _divide_docstring_ ) +# B09: ==== EQUAL (x1, x2) +_equal_docstring_ = """ +equal(x1, x2, order='K') -# COS - -_cos_docstring = """ -cos(x, order='K') +Calculates equality test results for each element `x1_i` of the input array `x1` +with the respective element `x2_i` of the input array `x2`. -Computes cosine for each element `x_i` for input array `x`. +Args: + x1 (usm_ndarray): + First input array, expected to have numeric data type. + x2 (usm_ndarray): + Second input array, also expected to have numeric data type. +Returns: + usm_ndarray: + an array containing the result of element-wise equality comparison. + The data type of the returned array is determined by the + Type Promotion Rules. """ -cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring) +equal = BinaryElementwiseFunc( + "equal", ti._equal_result_type, ti._equal, _equal_docstring_ +) + +# U13: ==== EXP (x) +# FIXME: implement U13 + +# U14: ==== EXPM1 (x) +# FIXME: implement U14 + +# U15: ==== FLOOR (x) +# FIXME: implement U15 -# ISFINITE +# B10: ==== FLOOR_DIVIDE (x1, x2) +# FIXME: implement B10 +# B11: ==== GREATER (x1, x2) +# FIXME: implement B11 + +# B12: ==== GREATER_EQUAL (x1, x2) +# FIXME: implement B12 + +# U16: ==== IMAG (x) +# FIXME: implement U16 + +# U17: ==== ISFINITE (x) _isfinite_docstring_ = """ Computes if every element of input array is a finite number.
""" @@ -74,8 +163,16 @@ "isfinite", ti._isfinite_result_type, ti._isfinite, _isfinite_docstring_ ) -# ISNAN +# U18: ==== ISINF (x) +_isinf_docstring_ = """ +Computes if every element of input array is an infinity. +""" + +isinf = UnaryElementwiseFunc( + "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_ +) +# U19: ==== ISNAN (x) _isnan_docstring_ = """ Computes if every element of input array is a NaN. """ @@ -84,18 +181,76 @@ "isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring_ ) -# ISINF +# B13: ==== LESS (x1, x2) +# FIXME: implement B13 -_isinf_docstring_ = """ -Computes if every element of input array is an infinity. -""" +# B14: ==== LESS_EQUAL (x1, x2) +# FIXME: implement B14 -isinf = UnaryElementwiseFunc( - "isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_ -) +# U20: ==== LOG (x) +# FIXME: implement U20 + +# U21: ==== LOG1P (x) +# FIXME: implement U21 + +# U22: ==== LOG2 (x) +# FIXME: implement U22 + +# U23: ==== LOG10 (x) +# FIXME: implement U23 + +# B15: ==== LOGADDEXP (x1, x2) +# FIXME: implement B15 -# SQRT +# B16: ==== LOGICAL_AND (x1, x2) +# FIXME: implement B16 +# U24: ==== LOGICAL_NOT (x) +# FIXME: implement U24 + +# B17: ==== LOGICAL_OR (x1, x2) +# FIXME: implement B17 + +# B18: ==== LOGICAL_XOR (x1, x2) +# FIXME: implement B18 + +# B19: ==== MULTIPLY (x1, x2) +# FIXME: implement B19 + +# U25: ==== NEGATIVE (x) +# FIXME: implement U25 + +# B20: ==== NOT_EQUAL (x1, x2) +# FIXME: implement B20 + +# U26: ==== POSITIVE (x) +# FIXME: implement U26 + +# B21: ==== POW (x1, x2) +# FIXME: implement B21 + +# U27: ==== REAL (x) +# FIXME: implement U27 + +# B22: ==== REMAINDER (x1, x2) +# FIXME: implement B22 + +# U28: ==== ROUND (x) +# FIXME: implement U28 + +# U29: ==== SIGN (x) +# FIXME: implement U29 + +# U30: ==== SIN (x) +# FIXME: implement U30 + +# U31: ==== SINH (x) +# FIXME: implement U31 + +# U32: ==== SQUARE (x) +# FIXME: implement U32 + +# U33: ==== SQRT (x) _sqrt_docstring_ = """ Computes sqrt for each element `x_i` for 
input array `x`. """ @@ -103,3 +258,15 @@ sqrt = UnaryElementwiseFunc( "sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_ ) + +# B23: ==== SUBTRACT (x1, x2) +# FIXME: implement B23 + +# U34: ==== TAN (x) +# FIXME: implement U34 + +# U35: ==== TANH (x) +# FIXME: implement U35 + +# U36: ==== TRUNC (x) +# FIXME: implement U36 diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp new file mode 100644 index 0000000000..01b9393d06 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -0,0 +1,455 @@ +#pragma once +#include +#include +#include +#include + +#include "utils/offset_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "kernels/elementwise_functions/common.hpp" +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ +namespace equal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template struct EqualFunctor +{ + static_assert(std::is_same_v); + + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::conjunction< + std::is_same, + std::negation, + tu_ns::is_complex>>>; + + resT operator()(const argT1 &in1, const argT2 &in2) + { + return (in1 == in2); + } + + template + sycl::vec operator()(const sycl::vec &in1, + const sycl::vec &in2) + { + auto tmp = (in1 == in2); + if constexpr (std::is_same_v) { + return tmp; + } + else { + using dpctl::tensor::type_utils::vec_cast; + + return vec_cast( + tmp); + } + } +}; + +template +using EqualContigFunctor = + elementwise_common::BinaryContigFunctor, + vec_sz, + n_vecs>; + +template +using EqualStridedFunctor = + elementwise_common::BinaryStridedFunctor>; + +template struct EqualOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, 
supported by DPC++ + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + td_ns::BinaryTypeMapEntry, + T2, + std::complex, + bool>, + td_ns::BinaryTypeMapEntry, + T2, + std::complex, + bool>, + td_ns::DefaultEntry>::result_type; +}; + +template +class equal_contig_kernel; + +typedef sycl::event (*equal_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event equal_contig_impl(sycl::queue exec_q, + size_t nelems, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + size_t lws = 64; + constexpr unsigned int vec_sz = 4; + constexpr unsigned int n_vecs = 2; + const size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + using resTy = typename EqualOutputType::value_type; + + const argTy1 *arg1_tp = + reinterpret_cast(arg1_p) + arg1_offset; + const argTy2 *arg2_tp = + reinterpret_cast(arg2_p) + arg2_offset; + resTy *res_tp = reinterpret_cast(res_p) + res_offset; + + cgh.parallel_for< + equal_contig_kernel>( + sycl::nd_range<1>(gws_range, lws_range), + EqualContigFunctor( + arg1_tp, arg2_tp, res_tp, nelems)); + }); + return comp_ev; +} + +template struct EqualContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename EqualOutputType::value_type, void>) + { + fnT fn = nullptr; + return fn; + 
} + else { + fnT fn = equal_contig_impl; + return fn; + } + } +}; + +template struct EqualTypeMapFactory +{ + /*! @brief get typeid for output type of operator()==(x, y), always bool */ + std::enable_if_t::value, int> get() + { + using rT = typename EqualOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class equal_strided_strided_kernel; + +typedef sycl::event (*equal_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +template +sycl::event +equal_strided_impl(sycl::queue exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg1_p, + py::ssize_t arg1_offset, + const char *arg2_p, + py::ssize_t arg2_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using resTy = typename EqualOutputType::value_type; + + using IndexerT = + typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + + IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset, + shape_and_strides}; + + const argTy1 *arg1_tp = reinterpret_cast(arg1_p); + const argTy2 *arg2_tp = reinterpret_cast(arg2_p); + resTy *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for< + equal_strided_strided_kernel>( + {nelems}, EqualStridedFunctor( + arg1_tp, arg2_tp, res_tp, indexer)); + }); + return comp_ev; +} + +template struct EqualStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename EqualOutputType::value_type, void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = equal_strided_impl; + return fn; + } + } +}; + +template +class equal_matrix_row_broadcast_sg_krn; + +typedef sycl::event 
(*equal_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +using EqualContigMatrixContigRowBroadcastingFunctor = + elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor< + argT1, + argT2, + resT, + EqualFunctor>; + +template +sycl::event equal_contig_matrix_contig_row_broadcast_impl( + sycl::queue exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + py::ssize_t mat_offset, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, + // res[i,j] = (mat[i,j] == vec[j]) + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + const argT1 *mat = reinterpret_cast(mat_p) + mat_offset; + const argT2 *vec = reinterpret_cast(vec_p) + vec_offset; + resT *res = reinterpret_cast(res_p) + res_offset; + + const auto &dev = exec_q.get_device(); + const auto &sg_sizes = dev.get_info(); + // Get device-specific kernel info max_sub_group_size + size_t max_sgSize = + *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + size_t n1_padded = n1 + max_sgSize; + argT2 *padded_vec = sycl::malloc_device(n1_padded, exec_q); + + if (padded_vec == nullptr) { + throw std::runtime_error("Could not allocate memory on the device"); + } + sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); // ensure vec contains actual data + cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) { + auto i = id[0]; + padded_vec[i] = vec[i % n1]; + }); + }); + + // sub-group spans work-items [I, I + sgSize) + // base = ndit.get_global_linear_id() - sg.get_local_id()[0] + // Generically, sg.load( &mat[base]) may load arrays from + // different rows of mat. 
The start corresponds to row (base / n0) + // We read sg.load(&padded_vec[(base / n0)]). The vector is padded to + // ensure that reads are accessible + + size_t lws = 64; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(make_padded_vec_ev); + + auto lwsRange = sycl::range<1>(lws); + size_t n_elems = n0 * n1; + size_t n_groups = (n_elems + lws - 1) / lws; + auto gwsRange = sycl::range<1>(n_groups * lws); + + cgh.parallel_for< + class equal_matrix_row_broadcast_sg_krn>( + sycl::nd_range<1>(gwsRange, lwsRange), + EqualContigMatrixContigRowBroadcastingFunctor( + mat, padded_vec, res, n_elems, n1)); + }); + + sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + sycl::context ctx = exec_q.get_context(); + cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); }); + }); + host_tasks.push_back(tmp_cleanup_ev); + + return comp_ev; +} + +template +struct EqualContigMatrixContigRowBroadcastFactory +{ + fnT get() + { + using resT = typename EqualOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + equal_contig_matrix_contig_row_broadcast_impl; + return fn; + } + } + } +}; + +typedef sycl::event (*equal_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event equal_contig_row_contig_matrix_broadcast_impl( + sycl::queue exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix 
+ py::ssize_t mat_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. matrix, + // res[i,j] = (mat[i,j] == vec[j]) + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + return equal_contig_matrix_contig_row_broadcast_impl( + exec_q, host_tasks, n0, n1, mat_p, mat_offset, vec_p, vec_offset, res_p, + res_offset, depends); +}; + +template +struct EqualContigRowContigMatrixBroadcastFactory +{ + fnT get() + { + using resT = typename EqualOutputType::value_type; + if constexpr (std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + equal_contig_row_contig_matrix_broadcast_impl; + return fn; + } + } + } +}; + +} // namespace equal +} // namespace kernels +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 42e5943e6d..0bd751ae1c 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -35,6 +35,7 @@ #include "kernels/elementwise_functions/abs.hpp" #include "kernels/elementwise_functions/add.hpp" #include "kernels/elementwise_functions/cos.hpp" +#include "kernels/elementwise_functions/equal.hpp" #include "kernels/elementwise_functions/isfinite.hpp" #include "kernels/elementwise_functions/isinf.hpp" #include "kernels/elementwise_functions/isnan.hpp" @@ -96,7 +97,7 @@ int _result_typeid(int arg_typeid, const int *fn_output_id) return fn_output_id[arg_typeid]; } -// ABS +// U01: ==== ABS (x) namespace impl { @@ -131,7 +132,370 @@ void populate_abs_dispatch_vectors(void) } // namespace impl -// ISFINITE +// U02: ==== ACOS (x) +namespace impl +{ +// FIXME: add code for U02 +} // namespace impl + +// U03: ===== ACOSH 
(x) +namespace impl +{ +// FIXME: add code for U03 +} // namespace impl + +// B01: ===== ADD (x1, x2) +namespace impl +{ +namespace add_fn_ns = dpctl::tensor::kernels::add; + +using add_fn_ns::add_contig_impl_fn_ptr_t; +using add_fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using add_fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using add_fn_ns::add_strided_impl_fn_ptr_t; + +static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int add_output_id_table[td_ns::num_types][td_ns::num_types]; + +static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +// add(matrix, row) +static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +// add(row, matrix) +static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_add_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = add_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::AddTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(add_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::AddStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(add_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::AddContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(add_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::AddContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table( + 
add_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::AddContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table( + add_contig_row_contig_matrix_broadcast_dispatch_table); +}; + +} // namespace impl + +// U04: ===== ASIN (x) +namespace impl +{ +// FIXME: add code for U04 +} // namespace impl + +// U05: ===== ASINH (x) +namespace impl +{ +// FIXME: add code for U05 +} // namespace impl + +// U06: ===== ATAN (x) +namespace impl +{ +// FIXME: add code for U06 +} // namespace impl + +// B02: ===== ATAN2 (x1, x2) +namespace impl +{ +// FIXME: add code for B02 +} // namespace impl + +// U07: ===== ATANH (x) +namespace impl +{ +// FIXME: add code for U07 +} // namespace impl + +// B03: ===== BITWISE_AND (x1, x2) +namespace impl +{ +// FIXME: add code for B03 +} // namespace impl + +// B04: ===== BITWISE_LEFT_SHIFT (x1, x2) +namespace impl +{ +// FIXME: add code for B04 +} // namespace impl + +// U08: ===== BITWISE_INVERT (x) +namespace impl +{ +// FIXME: add code for U08 +} // namespace impl + +// B05: ===== BITWISE_OR (x1, x2) +namespace impl +{ +// FIXME: add code for B05 +} // namespace impl + +// B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) +namespace impl +{ +// FIXME: add code for B06 +} // namespace impl + +// B07: ===== BITWISE_XOR (x1, x2) +namespace impl +{ +// FIXME: add code for B07 +} // namespace impl + +// U09: ==== CEIL (x) +namespace impl +{ +// FIXME: add code for U09 +} // namespace impl + +// U10: ==== CONJ (x) +namespace impl +{ +// FIXME: add code for U10 +} // namespace impl + +// U11: ==== COS (x) +namespace impl +{ + +namespace cos_fn_ns = dpctl::tensor::kernels::cos; +using cos_fn_ns::cos_contig_impl_fn_ptr_t; +using cos_fn_ns::cos_strided_impl_fn_ptr_t; + +static cos_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; +static int 
cos_output_typeid_vector[td_ns::num_types]; +static cos_strided_impl_fn_ptr_t cos_strided_dispatch_vector[td_ns::num_types]; + +void populate_cos_dispatch_vectors(void) +{ + using namespace td_ns; + namespace fn_ns = cos_fn_ns; + + using fn_ns::CosContigFactory; + DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); + + using fn_ns::CosStridedFactory; + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); + + using fn_ns::CosTypeMapFactory; + DispatchVectorBuilder dvb3; + dvb3.populate_dispatch_vector(cos_output_typeid_vector); +} + +} // namespace impl + +// U12: ==== COSH (x) +namespace impl +{ +// FIXME: add code for U12 +} // namespace impl + +// B08: ==== DIVIDE (x1, x2) +namespace impl +{ +namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; + +using true_divide_fn_ns::true_divide_contig_impl_fn_ptr_t; +using true_divide_fn_ns:: + true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using true_divide_fn_ns:: + true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using true_divide_fn_ns::true_divide_strided_impl_fn_ptr_t; + +static true_divide_contig_impl_fn_ptr_t + true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; + +static true_divide_strided_impl_fn_ptr_t + true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// divide(matrix, row) +static true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + true_divide_contig_matrix_contig_row_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +// divide(row, matrix) +static true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + true_divide_contig_row_contig_matrix_broadcast_dispatch_table + [td_ns::num_types][td_ns::num_types]; + +void populate_true_divide_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = true_divide_fn_ns; + + // which input 
types are supported, and what is the type of the result + using fn_ns::TrueDivideTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(true_divide_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::TrueDivideStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::TrueDivideContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + TrueDivideContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + true_divide_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + TrueDivideContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + true_divide_contig_row_contig_matrix_broadcast_dispatch_table); +}; + +} // namespace impl + +// B09: ==== EQUAL (x1, x2) +namespace impl +{ +namespace equal_fn_ns = dpctl::tensor::kernels::equal; + +using equal_fn_ns::equal_contig_impl_fn_ptr_t; +using equal_fn_ns::equal_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using equal_fn_ns::equal_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using equal_fn_ns::equal_strided_impl_fn_ptr_t; + +static equal_contig_impl_fn_ptr_t equal_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static int 
equal_output_id_table[td_ns::num_types][td_ns::num_types]; + +static equal_strided_impl_fn_ptr_t + equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_equal_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = equal_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::EqualTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(equal_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::EqualStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(equal_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::EqualContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(equal_contig_dispatch_table); +}; +} // namespace impl + +// U13: ==== EXP (x) +namespace impl +{ +// FIXME: add code for U13 +} // namespace impl + +// U14: ==== EXPM1 (x) +namespace impl +{ +// FIXME: add code for U14 +} // namespace impl + +// U15: ==== FLOOR (x) +namespace impl +{ +// FIXME: add code for U15 +} // namespace impl + +// B10: ==== FLOOR_DIVIDE (x1, x2) +namespace impl +{ +// FIXME: add code for B10 +} // namespace impl + +// B11: ==== GREATER (x1, x2) +namespace impl +{ +// FIXME: add code for B11 +} // namespace impl + +// B12: ==== GREATER_EQUAL (x1, x2) +namespace impl +{ +// FIXME: add code for B12 +} // namespace impl + +// U16: ==== IMAG (x) +namespace impl +{ +// FIXME: add code for U16 +} // namespace impl + +// U17: ==== ISFINITE (x) namespace impl { namespace isfinite_fn_ns = dpctl::tensor::kernels::isfinite; @@ -168,7 +532,7 @@ void populate_isfinite_dispatch_vectors(void) } // namespace impl -// ISINF +// U18: ==== ISINF (x) namespace impl { namespace isinf_fn_ns = dpctl::tensor::kernels::isinf; @@ -205,7 +569,7 @@ void populate_isinf_dispatch_vectors(void) } // namespace impl -// ISNAN +// U19: ==== ISNAN (x) namespace 
impl { namespace isnan_fn_ns = dpctl::tensor::kernels::isnan; @@ -242,114 +606,145 @@ void populate_isnan_dispatch_vectors(void) } // namespace impl -// COS +// B13: ==== LESS (x1, x2) namespace impl { +// FIXME: add code for B13 +} // namespace impl -namespace cos_fn_ns = dpctl::tensor::kernels::cos; -using cos_fn_ns::cos_contig_impl_fn_ptr_t; -using cos_fn_ns::cos_strided_impl_fn_ptr_t; +// B14: ==== LESS_EQUAL (x1, x2) +namespace impl +{ +// FIXME: add code for B14 +} // namespace impl -static cos_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; -static int cos_output_typeid_vector[td_ns::num_types]; -static cos_strided_impl_fn_ptr_t cos_strided_dispatch_vector[td_ns::num_types]; +// U20: ==== LOG (x) +namespace impl +{ +// FIXME: add code for U20 +} // namespace impl -void populate_cos_dispatch_vectors(void) +// U21: ==== LOG1P (x) +namespace impl { - using namespace td_ns; - namespace fn_ns = cos_fn_ns; +// FIXME: add code for U21 +} // namespace impl - using fn_ns::CosContigFactory; - DispatchVectorBuilder - dvb1; - dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); +// U22: ==== LOG2 (x) +namespace impl +{ +// FIXME: add code for U22 +} // namespace impl - using fn_ns::CosStridedFactory; - DispatchVectorBuilder - dvb2; - dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); +// U23: ==== LOG10 (x) +namespace impl +{ +// FIXME: add code for U23 +} // namespace impl - using fn_ns::CosTypeMapFactory; - DispatchVectorBuilder dvb3; - dvb3.populate_dispatch_vector(cos_output_typeid_vector); -} +// B15: ==== LOGADDEXP (x1, x2) +namespace impl +{ +// FIXME: add code for B15 +} // namespace impl +// B16: ==== LOGICAL_AND (x1, x2) +namespace impl +{ +// FIXME: add code for B16 } // namespace impl -// ADD +// U24: ==== LOGICAL_NOT (x) +namespace impl +{ +// FIXME: add code for U24 +} // namespace impl +// B17: ==== LOGICAL_OR (x1, x2) namespace impl { -namespace add_fn_ns = dpctl::tensor::kernels::add; +// FIXME: add code for B17 +} // 
namespace impl -using add_fn_ns::add_contig_impl_fn_ptr_t; -using add_fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using add_fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using add_fn_ns::add_strided_impl_fn_ptr_t; +// B18: ==== LOGICAL_XOR (x1, x2) +namespace impl +{ +// FIXME: add code for B18 +} // namespace impl -static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static int add_output_id_table[td_ns::num_types][td_ns::num_types]; +// B19: ==== MULTIPLY (x1, x2) +namespace impl +{ +// FIXME: add code for B19 +} // namespace impl -static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; +// U25: ==== NEGATIVE (x) +namespace impl +{ +// FIXME: add code for U25 +} // namespace impl -// add(matrix, row) -static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] - [td_ns::num_types]; +// B20: ==== NOT_EQUAL (x1, x2) +namespace impl +{ +// FIXME: add code for B20 +} // namespace impl -// add(row, matrix) -static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] - [td_ns::num_types]; +// U26: ==== POSITIVE (x) +namespace impl +{ +// FIXME: add code for U26 +} // namespace impl -void populate_add_dispatch_tables(void) +// B21: ==== POW (x1, x2) +namespace impl { - using namespace td_ns; - namespace fn_ns = add_fn_ns; +// FIXME: add code for B21 +} // namespace impl - // which input types are supported, and what is the type of the result - using fn_ns::AddTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(add_output_id_table); +// U27: ==== REAL (x) +namespace impl +{ +// FIXME: add code for U27 +} // namespace impl - // function pointers for operation on general strided arrays - using fn_ns::AddStridedFactory; - DispatchTableBuilder - dtb2; - 
dtb2.populate_dispatch_table(add_strided_dispatch_table); +// B22: ==== REMAINDER (x1, x2) +namespace impl +{ +// FIXME: add code for B22 +} // namespace impl - // function pointers for operation on contiguous inputs and output - using fn_ns::AddContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(add_contig_dispatch_table); +// U28: ==== ROUND (x) +namespace impl +{ +// FIXME: add code for U28 +} // namespace impl - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::AddContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table( - add_contig_matrix_contig_row_broadcast_dispatch_table); +// U29: ==== SIGN (x) +namespace impl +{ +// FIXME: add code for U29 +} // namespace impl - // function pointers for operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::AddContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder - dtb5; - dtb5.populate_dispatch_table( - add_contig_row_contig_matrix_broadcast_dispatch_table); -}; +// U30: ==== SIN (x) +namespace impl +{ +// FIXME: add code for U30 +} // namespace impl + +// U31: ==== SINH (x) +namespace impl +{ +// FIXME: add code for U31 +} // namespace impl +// U32: ==== SQUARE (x) +namespace impl +{ +// FIXME: add code for U32 } // namespace impl -// SQRT +// U33: ==== SQRT (x) namespace impl { @@ -386,82 +781,33 @@ void populate_sqrt_dispatch_vectors(void) } // namespace impl -// DIVIDE +// B23: ==== SUBTRACT (x1, x2) namespace impl { -namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; - -using true_divide_fn_ns::true_divide_contig_impl_fn_ptr_t; -using true_divide_fn_ns:: - true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using true_divide_fn_ns:: - true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using true_divide_fn_ns::true_divide_strided_impl_fn_ptr_t; - -static true_divide_contig_impl_fn_ptr_t - 
true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; - -static true_divide_strided_impl_fn_ptr_t - true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; - -// divide(matrix, row) -static true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t - true_divide_contig_matrix_contig_row_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; - -// divide(row, matrix) -static true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t - true_divide_contig_row_contig_matrix_broadcast_dispatch_table - [td_ns::num_types][td_ns::num_types]; +// FIXME: add code for B23 +} // namespace impl -void populate_true_divide_dispatch_tables(void) +// U34: ==== TAN (x) +namespace impl { - using namespace td_ns; - namespace fn_ns = true_divide_fn_ns; - - // which input types are supported, and what is the type of the result - using fn_ns::TrueDivideTypeMapFactory; - DispatchTableBuilder dtb1; - dtb1.populate_dispatch_table(true_divide_output_id_table); - - // function pointers for operation on general strided arrays - using fn_ns::TrueDivideStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); - - // function pointers for operation on contiguous inputs and output - using fn_ns::TrueDivideContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); - - // function pointers for operation on contiguous matrix, contiguous row - // with contiguous matrix output - using fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder< - true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, - TrueDivideContigMatrixContigRowBroadcastFactory, num_types> - dtb4; - dtb4.populate_dispatch_table( - true_divide_contig_matrix_contig_row_broadcast_dispatch_table); +// FIXME: add code for U34 +} // namespace impl - // function pointers for 
operation on contiguous row, contiguous matrix - // with contiguous matrix output - using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder< - true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, - TrueDivideContigRowContigMatrixBroadcastFactory, num_types> - dtb5; - dtb5.populate_dispatch_table( - true_divide_contig_row_contig_matrix_broadcast_dispatch_table); -}; +// U35: ==== TANH (x) +namespace impl +{ +// FIXME: add code for U35 +} // namespace impl +// U36: ==== TRUNC (x) +namespace impl +{ +// FIXME: add code for U36 } // namespace impl +// ========================================================================================== +// // + namespace py = pybind11; void init_elementwise_functions(py::module_ m) @@ -642,7 +988,45 @@ void init_elementwise_functions(py::module_ m) } // B09: ==== EQUAL (x1, x2) - // FIXME: + { + impl::populate_equal_dispatch_tables(); + using impl::equal_contig_dispatch_table; + using impl::equal_output_id_table; + using impl::equal_strided_dispatch_table; + + auto equal_pyapi = [&](dpctl::tensor::usm_ndarray src1, + dpctl::tensor::usm_ndarray src2, + dpctl::tensor::usm_ndarray dst, + sycl::queue exec_q, + const std::vector &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, equal_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + equal_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + equal_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + impl:: + equal_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + impl:: + equal_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto 
equal_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + equal_output_id_table); + }; + m.def("_equal", equal_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_equal_result_type", equal_result_type_pyapi, ""); + } // U13: ==== EXP (x) // FIXME: From 5f659feb0eb94666de710dcae359c6210b9a5b70 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 14:24:21 -0500 Subject: [PATCH 36/48] Adding tests for equal --- dpctl/tests/elementwise/test_equal.py | 170 ++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 dpctl/tests/elementwise/test_equal.py diff --git a/dpctl/tests/elementwise/test_equal.py b/dpctl/tests/elementwise/test_equal.py new file mode 100644 index 0000000000..a6821a3d5a --- /dev/null +++ b/dpctl/tests/elementwise/test_equal.py @@ -0,0 +1,170 @@ +import ctypes + +import numpy as np +import pytest + +import dpctl +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported + +from .utils import _all_dtypes, _compare_dtypes, _usm_types + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_equal_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype) + + r = dpt.equal(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_dtype = np.equal( + np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype) + ).dtype + assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) + assert r.shape == ar1.shape + assert (dpt.asnumpy(r) == np.full(r.shape, True, dtype=r.dtype)).all() + assert r.sycl_queue == ar1.sycl_queue + + ar3 = dpt.ones(sz, dtype=op1_dtype) + ar4 = dpt.ones(2 * 
sz, dtype=op2_dtype) + + r = dpt.equal(ar3[::-1], ar4[::2]) + assert isinstance(r, dpt.usm_ndarray) + expected_dtype = np.equal( + np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype) + ).dtype + assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q) + assert r.shape == ar3.shape + assert (dpt.asnumpy(r) == np.full(r.shape, True, dtype=r.dtype)).all() + + +@pytest.mark.parametrize("op1_usm_type", _usm_types) +@pytest.mark.parametrize("op2_usm_type", _usm_types) +def test_equal_usm_type_matrix(op1_usm_type, op2_usm_type): + get_queue_or_skip() + + sz = 128 + ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type) + ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type) + + r = dpt.equal(ar1, ar2) + assert isinstance(r, dpt.usm_ndarray) + expected_usm_type = dpctl.utils.get_coerced_usm_type( + (op1_usm_type, op2_usm_type) + ) + assert r.usm_type == expected_usm_type + + +def test_equal_order(): + get_queue_or_skip() + + ar1 = dpt.ones((20, 20), dtype="i4", order="C") + ar2 = dpt.ones((20, 20), dtype="i4", order="C") + r1 = dpt.equal(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.equal(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.equal(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.equal(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones((20, 20), dtype="i4", order="F") + ar2 = dpt.ones((20, 20), dtype="i4", order="F") + r1 = dpt.equal(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.equal(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.equal(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.equal(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] + r4 = dpt.equal(ar1, ar2, order="K") + assert r4.strides == (20, -1) + + ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT + ar2 = dpt.ones((40, 40), dtype="i4", 
order="C")[:20, ::-2].mT + r4 = dpt.equal(ar1, ar2, order="K") + assert r4.strides == (-1, 20) + + +def test_equal_broadcasting(): + get_queue_or_skip() + + m = dpt.ones((100, 5), dtype="i4") + v = dpt.arange(5, dtype="i4") + + r = dpt.equal(m, v) + expected = np.full((100, 5), [False, True, False, False, False], dtype="?") + + assert (dpt.asnumpy(r) == expected).all() + + r2 = dpt.equal(v, m) + assert (dpt.asnumpy(r2) == expected).all() + + +@pytest.mark.parametrize("arr_dt", _all_dtypes) +def test_equal_python_scalar(arr_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arr_dt, q) + + X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q) + py_zeros = ( + bool(0), + int(0), + float(0), + complex(0), + np.float32(0), + ctypes.c_int(0), + ) + for sc in py_zeros: + R = dpt.equal(X, sc) + assert isinstance(R, dpt.usm_ndarray) + assert dpt.all(R) + R = dpt.equal(sc, X) + assert isinstance(R, dpt.usm_ndarray) + assert dpt.all(R) + + +class MockArray: + def __init__(self, arr): + self.data_ = arr + + @property + def __sycl_usm_array_interface__(self): + return self.data_.__sycl_usm_array_interface__ + + +def test_equal_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + b = dpt.ones(10) + c = MockArray(b) + r = dpt.equal(a, c) + assert isinstance(r, dpt.usm_ndarray) + + +def test_equal_canary_mock_array(): + get_queue_or_skip() + a = dpt.arange(10) + + class Canary: + def __init__(self): + pass + + @property + def __sycl_usm_array_interface__(self): + return None + + c = Canary() + with pytest.raises(ValueError): + dpt.equal(a, c) From 4439bf745639ee6ae1701c2beb8d7cbf55a202a5 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 15:16:50 -0500 Subject: [PATCH 37/48] Renamed struct to better reflect its purpose --- .../kernels/elementwise_functions/abs.hpp | 30 ++--- .../kernels/elementwise_functions/add.hpp | 111 ++++++++++-------- .../kernels/elementwise_functions/cos.hpp | 13 +- .../kernels/elementwise_functions/equal.hpp | 62 
++++++---- .../kernels/elementwise_functions/sqrt.hpp | 13 +- .../elementwise_functions/true_divide.hpp | 72 ++++++------ .../libtensor/include/utils/type_dispatch.hpp | 18 ++- 7 files changed, 180 insertions(+), 139 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index b233674f67..e0a84e5139 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -64,21 +64,21 @@ template struct AbsOutputType { using value_type = typename std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, float>, - td_ns::TypeMapEntry, double>, - td_ns::DefaultEntry>::result_type; + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, float>, + td_ns::TypeMapResultEntry, double>, + td_ns::DefaultResultEntry>::result_type; }; template diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 2884e06f19..56829c4c38 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -80,58 +80,65 @@ template struct AddOutputType { using value_type = typename 
std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns:: - BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapEntry, - T2, - std::complex, - std::complex>, - td_ns::DefaultEntry>::result_type; + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultResultEntry>::result_type; }; template struct CosOutputType { using value_type = typename std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, std::complex>, - td_ns::TypeMapEntry, std::complex>, - td_ns::DefaultEntry>::result_type; + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, std::complex>, + td_ns:: + TypeMapResultEntry, std::complex>, + td_ns::DefaultResultEntry>::result_type; }; typedef sycl::event (*cos_contig_impl_fn_ptr_t)( diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index 01b9393d06..a9f8adc92d 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -83,29 +83,45 @@ template struct EqualOutputType { using value_type = typename std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - T2, - std::complex, - bool>, - td_ns::BinaryTypeMapEntry, - T2, - std::complex, - bool>, - td_ns::DefaultEntry>::result_type; + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns:: + BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + bool>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + bool>, + td_ns::DefaultResultEntry>::result_type; }; template struct SqrtOutputType { using value_type = typename std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, - td_ns::TypeMapEntry, std::complex>, - td_ns::TypeMapEntry, std::complex>, - td_ns::DefaultEntry>::result_type; + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, std::complex>, + td_ns:: + TypeMapResultEntry, std::complex>, + td_ns::DefaultResultEntry>::result_type; }; typedef sycl::event 
(*sqrt_contig_impl_fn_ptr_t)( diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 58abf0cfd5..372a7bb128 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -81,40 +81,44 @@ template struct TrueDivideOutputType { using value_type = typename std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - td_ns::BinaryTypeMapEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapEntry, - T2, - float, - std::complex>, - td_ns::BinaryTypeMapEntry, - std::complex>, - td_ns::BinaryTypeMapEntry, - T2, - std::complex, - std::complex>, - td_ns::BinaryTypeMapEntry, - std::complex>, - td_ns::BinaryTypeMapEntry, - T2, - double, - std::complex>, - td_ns::DefaultEntry>::result_type; + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + float, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + double, + std::complex>, + td_ns::DefaultResultEntry>::result_type; }; template -struct TypeMapEntry : std::bool_constant> +struct TypeMapResultEntry : std::bool_constant> { using result_type = ResTy; }; @@ -264,7 +264,7 @@ template -struct BinaryTypeMapEntry +struct BinaryTypeMapResultEntry : std::bool_constant, std::is_same>> { @@ -272,7 +272,7 @@ struct BinaryTypeMapEntry }; /*! 
@brief fall-through struct with specified result_type, usually void */ -template struct DefaultEntry : std::true_type +template struct DefaultResultEntry : std::true_type { using result_type = Ty; }; @@ -368,6 +368,18 @@ template struct NullPtrTable value_type val; }; +template +struct TypePairDefinedEntry : std::bool_constant && + std::is_same_v> +{ + static constexpr bool is_defined = true; +}; + +struct NotDefinedEntry : std::true_type +{ + static constexpr bool is_defined = false; +}; + } // namespace type_dispatch } // namespace tensor From c6ef0751213ad7eadbf6fdaa0130107cde6f5080 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 15:48:34 -0500 Subject: [PATCH 38/48] Use MemoryOverlap --- .../source/elementwise_functions.hpp | 46 +++++-------------- 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.hpp b/dpctl/tensor/libtensor/source/elementwise_functions.hpp index b76e3da297..b4e8bac10d 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.hpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.hpp @@ -34,6 +34,7 @@ #include #include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" #include "utils/offset_utils.hpp" #include "utils/type_dispatch.hpp" @@ -122,23 +123,14 @@ py_unary_ufunc(dpctl::tensor::usm_ndarray src, } // check memory overlap - const char *src_data = src.get_data(); - char *dst_data = dst.get_data(); - - // check that arrays do not overlap, and concurrent copying is safe. 
- auto src_offsets = src.get_minmax_offsets(); - int src_elem_size = src.get_elemsize(); - int dst_elem_size = dst.get_elemsize(); - - bool memory_overlap = - ((dst_data - src_data > src_offsets.second * src_elem_size - - dst_offsets.first * dst_elem_size) && - (src_data - dst_data > dst_offsets.second * dst_elem_size - - src_offsets.first * src_elem_size)); - if (memory_overlap) { + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { throw py::value_error("Arrays index overlapping segments of memory"); } + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + // handle contiguous inputs bool is_src_c_contig = src.is_c_contiguous(); bool is_src_f_contig = src.is_f_contiguous(); @@ -378,32 +370,16 @@ std::pair py_binary_ufunc( } } + // check memory overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src1, dst) || overlap(src2, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } // check memory overlap const char *src1_data = src1.get_data(); const char *src2_data = src2.get_data(); char *dst_data = dst.get_data(); - // check that arrays do not overlap, and concurrent copying is safe. 
- auto src1_offsets = src1.get_minmax_offsets(); - int src1_elem_size = src1.get_elemsize(); - auto src2_offsets = src2.get_minmax_offsets(); - int src2_elem_size = src2.get_elemsize(); - int dst_elem_size = dst.get_elemsize(); - - bool memory_overlap_src1_dst = - ((dst_data - src1_data > src1_offsets.second * src1_elem_size - - dst_offsets.first * dst_elem_size) && - (src1_data - dst_data > dst_offsets.second * dst_elem_size - - src1_offsets.first * src1_elem_size)); - bool memory_overlap_src2_dst = - ((dst_data - src2_data > src2_offsets.second * src2_elem_size - - dst_offsets.first * dst_elem_size) && - (src2_data - dst_data > dst_offsets.second * dst_elem_size - - src2_offsets.first * src2_elem_size)); - if (memory_overlap_src1_dst || memory_overlap_src2_dst) { - throw py::value_error("Arrays index overlapping segments of memory"); - } - // handle contiguous inputs bool is_src1_c_contig = src1.is_c_contiguous(); bool is_src1_f_contig = src1.is_f_contiguous(); From 6bc4ba8202547ed57b544cc2be54b33e7b95ae9b Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 16:51:45 -0500 Subject: [PATCH 39/48] Added wait to avoid dangling host_task --- dpctl/tensor/_elementwise_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index cc2f1ff679..46837ab1df 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -86,6 +86,7 @@ def __call__(self, x, order="K"): r = dpt.empty_like(buf, dtype=res_dt, order=order) ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev]) + ht_copy_ev.wait() ht.wait() return r From aebaf3a0c2871014e0609d7c79dfb41e42b2e4c0 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 18 May 2023 20:37:37 -0500 Subject: [PATCH 40/48] Fixed type of output multi_ptr per PR review remark by @ndgrigorian --- .../libtensor/include/kernels/elementwise_functions/common.hpp | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index 7c567e00c6..bcf1623b61 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -99,7 +99,7 @@ struct UnaryContigFunctor sycl::multi_ptr; using out_ptrT = - sycl::multi_ptr; auto sg = ndit.get_sub_group(); From aef6de3d36c5594986408626ab04d7811d0b5144 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 18 May 2023 20:37:14 -0500 Subject: [PATCH 41/48] out keyword for elementwise functions --- dpctl/tensor/_elementwise_common.py | 246 ++++++++++++++++------- dpctl/tests/elementwise/test_abs.py | 17 ++ dpctl/tests/elementwise/test_add.py | 77 +++++++ dpctl/tests/elementwise/test_cos.py | 61 ++++++ dpctl/tests/elementwise/test_isfinite.py | 35 ++++ dpctl/tests/elementwise/test_isnan.py | 35 ++++ 6 files changed, 395 insertions(+), 76 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index 46837ab1df..e7d1d2ca41 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs): self.unary_fn_ = unary_dp_impl_fn self.__doc__ = docs - def __call__(self, x, order="K"): + def __call__(self, x, out=None, order="K"): if not isinstance(x, dpt.usm_ndarray): raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if out.shape != x.shape: + raise TypeError( + "The shape of input and output arrays are inconsistent." 
+ f"Expected output shape is {x.shape}, got {out.shape}" + ) + + if ti._array_overlap(x, out): + raise TypeError("Input and output arrays have memory overlap") + + if ( + dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue)) + is None + ): + raise TypeError( + "Input and output allocation queues are not compatible" + ) + if order not in ["C", "F", "K", "A"]: order = "K" buf_dt, res_dt = _find_buf_dtype( @@ -59,17 +83,24 @@ def __call__(self, x, order="K"): raise RuntimeError exec_q = x.sycl_queue if buf_dt is None: - if order == "K": - r = _empty_like_orderK(x, res_dt) + if out is None: + if order == "K": + out = _empty_like_orderK(x, res_dt) + else: + if order == "A": + order = "F" if x.flags.f_contiguous else "C" + out = dpt.empty_like(x, dtype=res_dt, order=order) else: - if order == "A": - order = "F" if x.flags.f_contiguous else "C" - r = dpt.empty_like(x, dtype=res_dt, order=order) + if res_dt != out.dtype: + raise TypeError( + f"Expected output array of type {res_dt} is supported" + f", got {out.dtype}" + ) - ht, _ = self.unary_fn_(x, r, sycl_queue=exec_q) + ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q) ht.wait() - return r + return out if order == "K": buf = _empty_like_orderK(x, buf_dt) else: @@ -80,16 +111,23 @@ def __call__(self, x, order="K"): ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=x, dst=buf, sycl_queue=exec_q ) - if order == "K": - r = _empty_like_orderK(buf, res_dt) + if out is None: + if order == "K": + out = _empty_like_orderK(buf, res_dt) + else: + out = dpt.empty_like(buf, dtype=res_dt, order=order) else: - r = dpt.empty_like(buf, dtype=res_dt, order=order) + if buf_dt != out.dtype: + raise TypeError( + f"Expected output array of type {buf_dt} is supported," + f"got {out.dtype}" + ) - ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev]) + ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev]) ht_copy_ev.wait() ht.wait() - return r + return out def _get_queue_usm_type(o): @@ 
-281,7 +319,7 @@ def __str__(self): def __repr__(self): return f"" - def __call__(self, o1, o2, order="K"): + def __call__(self, o1, o2, out=None, order="K"): if order not in ["K", "C", "F", "A"]: order = "K" q1, o1_usm_type = _get_queue_usm_type(o1) @@ -358,6 +396,31 @@ def __call__(self, o1, o2, order="K"): "supported types according to the casting rule ''safe''." ) + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if out.shape != o1_shape or out.shape != o2_shape: + raise TypeError( + "The shape of input and output arrays are inconsistent." + f"Expected output shape is {o1_shape}, got {out.shape}" + ) + + if ti._array_overlap(o1, out) or ti._array_overlap(o2, out): + raise TypeError("Input and output arrays have memory overlap") + + if ( + dpctl.utils.get_execution_queue( + (o1.sycl_queue, o2.sycl_queue, out.sycl_queue) + ) + is None + ): + raise TypeError( + "Input and output allocation queues are not compatible" + ) + if isinstance(o1, dpt.usm_ndarray): src1 = o1 else: @@ -368,37 +431,45 @@ def __call__(self, o1, o2, order="K"): src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q) if buf1_dt is None and buf2_dt is None: - if order == "K": - r = _empty_like_pair_orderK( - src1, src2, res_dt, res_usm_type, exec_q - ) - else: - if order == "A": - order = ( - "F" - if all( - arr.flags.f_contiguous - for arr in ( - src1, - src2, + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + src1, src2, res_dt, res_usm_type, exec_q + ) + else: + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + src1, + src2, + ) ) + else "C" ) - else "C" + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, ) - r = dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, - order=order, - ) + else: + if res_dt != out.dtype: + raise TypeError( + 
f"Output array of type {res_dt} is needed," + f"got {out.dtype}" + ) + src1 = dpt.broadcast_to(src1, res_shape) src2 = dpt.broadcast_to(src2, res_shape) ht_, _ = self.binary_fn_( - src1=src1, src2=src2, dst=r, sycl_queue=exec_q + src1=src1, src2=src2, dst=out, sycl_queue=exec_q ) ht_.wait() - return r + return out elif buf1_dt is None: if order == "K": buf2 = _empty_like_orderK(src2, buf2_dt) @@ -409,30 +480,38 @@ def __call__(self, o1, o2, order="K"): ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=src2, dst=buf2, sycl_queue=exec_q ) - if order == "K": - r = _empty_like_pair_orderK( - src1, buf2, res_dt, res_usm_type, exec_q - ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + src1, buf2, res_dt, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) else: - r = dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, - order=order, - ) + if res_dt != out.dtype: + raise TypeError( + f"Output array of type {res_dt} is needed," + f"got {out.dtype}" + ) + src1 = dpt.broadcast_to(src1, res_shape) buf2 = dpt.broadcast_to(buf2, res_shape) ht_, _ = self.binary_fn_( src1=src1, src2=buf2, - dst=r, + dst=out, sycl_queue=exec_q, depends=[copy_ev], ) ht_copy_ev.wait() ht_.wait() - return r + return out elif buf2_dt is None: if order == "K": buf1 = _empty_like_orderK(src1, buf1_dt) @@ -443,30 +522,38 @@ def __call__(self, o1, o2, order="K"): ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=src1, dst=buf1, sycl_queue=exec_q ) - if order == "K": - r = _empty_like_pair_orderK( - buf1, src2, res_dt, res_usm_type, exec_q - ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, src2, res_dt, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) else: - r = dpt.empty( - res_shape, - dtype=res_dt, 
- usm_type=res_usm_type, - sycl_queue=exec_q, - order=order, - ) + if res_dt != out.dtype: + raise TypeError( + f"Output array of type {res_dt} is needed," + f"got {out.dtype}" + ) + buf1 = dpt.broadcast_to(buf1, res_shape) src2 = dpt.broadcast_to(src2, res_shape) ht_, _ = self.binary_fn_( src1=buf1, src2=src2, - dst=r, + dst=out, sycl_queue=exec_q, depends=[copy_ev], ) ht_copy_ev.wait() ht_.wait() - return r + return out if order in ["K", "A"]: if src1.flags.f_contiguous and src2.flags.f_contiguous: @@ -489,26 +576,33 @@ def __call__(self, o1, o2, order="K"): ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=src2, dst=buf2, sycl_queue=exec_q ) - if order == "K": - r = _empty_like_pair_orderK( - buf1, buf2, res_dt, res_usm_type, exec_q - ) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, buf2, res_dt, res_usm_type, exec_q + ) + else: + out = dpt.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) else: - r = dpt.empty( - res_shape, - dtype=res_dt, - usm_type=res_usm_type, - sycl_queue=exec_q, - order=order, - ) + if res_dt != out.dtype: + raise TypeError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + buf1 = dpt.broadcast_to(buf1, res_shape) buf2 = dpt.broadcast_to(buf2, res_shape) ht_, _ = self.binary_fn_( src1=buf1, src2=buf2, - dst=r, + dst=out, sycl_queue=exec_q, depends=[copy1_ev, copy2_ev], ) dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_]) - return r + return out diff --git a/dpctl/tests/elementwise/test_abs.py b/dpctl/tests/elementwise/test_abs.py index 275be0d573..94bc05d918 100644 --- a/dpctl/tests/elementwise/test_abs.py +++ b/dpctl/tests/elementwise/test_abs.py @@ -89,3 +89,20 @@ def test_abs_complex(dtype): np.testing.assert_allclose( dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol ) + + +@pytest.mark.parametrize("dtype", _all_dtypes[:-2]) +def test_abs_out_keyword(dtype): + q = get_queue_or_skip() + 
skip_if_dtype_not_supported(dtype, q) + + arg_dt = np.dtype(dtype) + input_shape = (10, 10, 10, 10) + X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) + X[..., 0::2] = 1 + X[..., 1::2] = 0 + Y = dpt.empty_like(X, dtype=arg_dt) + dpt.abs(X, Y) + + expected_Y = dpt.asnumpy(X) + assert np.allclose(dpt.asnumpy(Y), expected_Y) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index 81176f4f45..95104a7441 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from numpy.testing import assert_raises_regex import dpctl import dpctl.tensor as dpt @@ -165,3 +166,79 @@ def __sycl_usm_array_interface__(self): c = Canary() with pytest.raises(ValueError): dpt.add(a, c) + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_add_dtype_out_keyword(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype) + + r = dpt.add(ar1, ar2) + + y = dpt.zeros_like(ar1, dtype=r.dtype) + dpt.add(ar1, ar2, y) + + assert np.array_equal(dpt.asnumpy(r), dpt.asnumpy(y)) + + +def test_add_errors(): + ar1 = dpt.ones(2, dtype="float32", device="gpu") + ar2 = dpt.ones_like(ar1, dtype="int32", device="gpu") + y = dpt.empty_like(ar1, device="cpu") + assert_raises_regex( + TypeError, + "Input and output allocation queues are not compatible", + dpt.add, + ar1, + ar2, + y, + ) + + ar1 = dpt.ones(2, dtype="float32") + ar2 = dpt.ones_like(ar1, dtype="int32") + y = dpt.empty(3) + assert_raises_regex( + TypeError, + "The shape of input and output arrays are inconsistent", + dpt.add, + ar1, + ar2, + y, + ) + + ar1 = dpt.ones(2, dtype="float32") + ar2 = dpt.ones_like(ar1, dtype="int32") + y = ar1 + assert_raises_regex( + TypeError, + "Input and output 
arrays have memory overlap", + dpt.add, + ar1, + ar2, + y, + ) + + ar1 = dpt.ones(2, dtype="float32") + ar2 = dpt.ones_like(ar1, dtype="int32") + y = dpt.empty_like(ar1, dtype="int32") + assert_raises_regex( + TypeError, "Output array of type.*is needed", dpt.add, ar1, ar2, y + ) + + ar1 = dpt.ones(2, dtype="float32") + ar2 = dpt.ones_like(ar1, dtype="int32") + y = np.empty_like(ar1) + assert_raises_regex( + TypeError, + "output array must be of usm_ndarray type", + dpt.add, + ar1, + ar2, + y, + ) diff --git a/dpctl/tests/elementwise/test_cos.py b/dpctl/tests/elementwise/test_cos.py index 22588aea44..9397261596 100644 --- a/dpctl/tests/elementwise/test_cos.py +++ b/dpctl/tests/elementwise/test_cos.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from numpy.testing import assert_raises_regex import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -85,3 +86,63 @@ def test_cos_order(dtype): np.testing.assert_allclose( dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol ) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) +def test_cos_out_keyword(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + n_seq = 100 + n_rep = 137 + + Xnp = np.linspace(-np.pi / 4, np.pi / 4, num=n_seq, dtype=dtype) + X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) + Y = dpt.empty_like(X, dtype=dtype) + + dpt.cos(X, Y) + tol = 8 * dpt.finfo(Y.dtype).resolution + + np.testing.assert_allclose( + dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol + ) + + +def test_cos_errors(): + x = dpt.zeros(2, device="gpu") + y = dpt.empty_like(x, device="cpu") + assert_raises_regex( + TypeError, + "Input and output allocation queues are not compatible", + dpt.cos, + x, + y, + ) + + x = dpt.zeros(2) + y = dpt.empty(3) + assert_raises_regex( + TypeError, + "The shape of input and output arrays are inconsistent", + dpt.cos, + x, + y, + ) + + x = dpt.zeros(2) + y = x + 
assert_raises_regex( + TypeError, "Input and output arrays have memory overlap", dpt.cos, x, y + ) + + x = dpt.zeros(2, dtype="int32") + y = dpt.empty_like(x, dtype="int32") + assert_raises_regex( + TypeError, "Expected output array of type.*is supported", dpt.cos, x, y + ) + + x = dpt.zeros(2, dtype="float32") + y = np.empty_like(x) + assert_raises_regex( + TypeError, "output array must be of usm_ndarray type", dpt.cos, x, y + ) diff --git a/dpctl/tests/elementwise/test_isfinite.py b/dpctl/tests/elementwise/test_isfinite.py index 5cc9699cf8..fa3150fc4f 100644 --- a/dpctl/tests/elementwise/test_isfinite.py +++ b/dpctl/tests/elementwise/test_isfinite.py @@ -72,3 +72,38 @@ def test_isfinite_order(dtype): Y = dpt.isfinite(U, order=ord) expected_Y = np.full(Y.shape, True, dtype=Y.dtype) assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isnan_complex_out_keyword(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.nan, np.nan) + y2 = complex(1, np.nan) + y3 = complex(np.nan, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 12) + Y = dpt.asarray(Ynp, sycl_queue=q) + out = dpt.empty_like(Y, dtype="bool") + dpt.isfinite(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isfinite_floats_out_keyword(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + out = dpt.empty_like(Y, dtype="bool") + dpt.isfinite(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) diff --git a/dpctl/tests/elementwise/test_isnan.py b/dpctl/tests/elementwise/test_isnan.py index 8e983cb2dc..0796f0f0f8 100644 --- 
a/dpctl/tests/elementwise/test_isnan.py +++ b/dpctl/tests/elementwise/test_isnan.py @@ -72,3 +72,38 @@ def test_isnan_order(dtype): Y = dpt.isnan(U, order=ord) expected_Y = np.full(Y.shape, False, dtype=Y.dtype) assert np.allclose(dpt.asnumpy(Y), expected_Y) + + +@pytest.mark.parametrize("dtype", ["c8", "c16"]) +def test_isnan_complex_out_keyword(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = complex(np.nan, np.nan) + y2 = complex(1, np.nan) + y3 = complex(np.nan, 1) + y4 = complex(2, 1) + y5 = complex(np.inf, 1) + + Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 123) + Y = dpt.asarray(Ynp, sycl_queue=q) + out = dpt.empty_like(Y, dtype="bool") + dpt.isnan(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) + + +@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) +def test_isnan_floats_out_keyword(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + y1 = np.nan + y2 = 1 + y3 = np.inf + + for mult in [123, 137, 255, 271, 272]: + Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) + Y = dpt.asarray(Ynp, sycl_queue=q) + out = dpt.empty_like(Y, dtype="bool") + dpt.isnan(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) From 21e14158cfe49c8257fa550224a9bd226926cb3a Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 18 May 2023 22:05:10 -0500 Subject: [PATCH 42/48] update the tests --- dpctl/tensor/_elementwise_common.py | 2 +- dpctl/tests/elementwise/test_add.py | 27 ++++++++++++++++++++++++--- dpctl/tests/elementwise/test_cos.py | 14 ++++++++++++-- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index e7d1d2ca41..d78e5fe2b2 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -402,7 +402,7 @@ def __call__(self, o1, o2, out=None, order="K"): f"output array must be of usm_ndarray type, got {type(out)}" ) - if out.shape 
!= o1_shape or out.shape != o2_shape: + if out.shape != res_shape: raise TypeError( "The shape of input and output arrays are inconsistent." f"Expected output shape is {o1_shape}, got {out.shape}" diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index 95104a7441..ee0a311546 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -111,6 +111,18 @@ def test_add_broadcasting(): r2 = dpt.add(v, m) assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() + out = dpt.empty_like(m) + dpt.add(m, v, out) + assert ( + dpt.asnumpy(out) == np.arange(1, 6, dtype="i4")[np.newaxis, :] + ).all() + + out2 = dpt.empty_like(m) + dpt.add(v, m, out2) + assert ( + dpt.asnumpy(out2) == np.arange(1, 6, dtype="i4")[np.newaxis, :] + ).all() + @pytest.mark.parametrize("arr_dt", _all_dtypes) def test_add_python_scalar(arr_dt): @@ -188,9 +200,18 @@ def test_add_dtype_out_keyword(op1_dtype, op2_dtype): def test_add_errors(): - ar1 = dpt.ones(2, dtype="float32", device="gpu") - ar2 = dpt.ones_like(ar1, dtype="int32", device="gpu") - y = dpt.empty_like(ar1, device="cpu") + try: + gpu_queue = dpctl.SyclQueue("gpu") + except dpctl.SyclQueueCreationError: + pytest.skip("SyclQueue('gpu') failed, skipping") + try: + cpu_queue = dpctl.SyclQueue("cpu") + except dpctl.SyclQueueCreationError: + pytest.skip("SyclQueue('cpu') failed, skipping") + + ar1 = dpt.ones(2, dtype="float32", sycl_queue=gpu_queue) + ar2 = dpt.ones_like(ar1, sycl_queue=gpu_queue) + y = dpt.empty_like(ar1, sycl_queue=cpu_queue) assert_raises_regex( TypeError, "Input and output allocation queues are not compatible", diff --git a/dpctl/tests/elementwise/test_cos.py b/dpctl/tests/elementwise/test_cos.py index 9397261596..b8fff07bd9 100644 --- a/dpctl/tests/elementwise/test_cos.py +++ b/dpctl/tests/elementwise/test_cos.py @@ -4,6 +4,7 @@ import pytest from numpy.testing import assert_raises_regex +import dpctl import dpctl.tensor as dpt from 
dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported @@ -109,8 +110,17 @@ def test_cos_out_keyword(dtype): def test_cos_errors(): - x = dpt.zeros(2, device="gpu") - y = dpt.empty_like(x, device="cpu") + try: + gpu_queue = dpctl.SyclQueue("gpu") + except dpctl.SyclQueueCreationError: + pytest.skip("SyclQueue('gpu') failed, skipping") + try: + cpu_queue = dpctl.SyclQueue("cpu") + except dpctl.SyclQueueCreationError: + pytest.skip("SyclQueue('cpu') failed, skipping") + + x = dpt.zeros(2, sycl_queue=gpu_queue) + y = dpt.empty_like(x, sycl_queue=cpu_queue) assert_raises_regex( TypeError, "Input and output allocation queues are not compatible", From 0820addc9d69b7e8cc6c3037c27cfb0431db4e78 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 19 May 2023 16:40:54 -0500 Subject: [PATCH 43/48] Removed unused code Also consolidated various *impl_fn_ptr_t which were common among unary functions, and among binary functions into common reused types. --- .../kernels/elementwise_functions/abs.hpp | 19 -- .../kernels/elementwise_functions/add.hpp | 51 ----- .../kernels/elementwise_functions/common.hpp | 72 ++++++ .../kernels/elementwise_functions/cos.hpp | 19 -- .../kernels/elementwise_functions/equal.hpp | 209 ------------------ .../elementwise_functions/isfinite.hpp | 19 -- .../kernels/elementwise_functions/isinf.hpp | 19 -- .../kernels/elementwise_functions/isnan.hpp | 19 -- .../kernels/elementwise_functions/sqrt.hpp | 19 -- .../elementwise_functions/true_divide.hpp | 53 ----- .../source/elementwise_functions.cpp | 154 ++++++------- 11 files changed, 141 insertions(+), 512 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index e0a84e5139..83e1f2709a 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -84,13 +84,6 @@ template struct
AbsOutputType template class abs_contig_kernel; -typedef sycl::event (*abs_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - char *, - const std::vector &); - template sycl::event abs_contig_impl(sycl::queue exec_q, size_t nelems, @@ -153,18 +146,6 @@ using AbsStridedFunctor = elementwise_common:: template class abs_strided_kernel; -typedef sycl::event (*abs_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event abs_strided_impl(sycl::queue exec_q, size_t nelems, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 56829c4c38..f045dad1b6 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -148,17 +148,6 @@ template class add_contig_kernel; -typedef sycl::event (*add_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - template sycl::event add_contig_impl(sycl::queue exec_q, size_t nelems, @@ -228,20 +217,6 @@ template struct AddTypeMapFactory template class add_strided_strided_kernel; -typedef sycl::event (*add_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event add_strided_impl(sycl::queue exec_q, size_t nelems, @@ -299,19 +274,6 @@ template struct AddStridedFactory template class add_matrix_row_broadcast_sg_krn; -typedef sycl::event (*add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( - sycl::queue, - std::vector &, - size_t, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - 
char *, - py::ssize_t, - const std::vector &); - template using AddContigMatrixContigRowBroadcastingFunctor = elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor< @@ -420,19 +382,6 @@ struct AddContigMatrixContigRowBroadcastFactory } }; -typedef sycl::event (*add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( - sycl::queue, - std::vector &, - size_t, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - template sycl::event add_contig_row_contig_matrix_broadcast_impl( sycl::queue exec_q, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index bcf1623b61..bb46d4cdca 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -544,6 +544,78 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor } }; +// Typdefs for function pointers + +typedef sycl::event (*unary_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + char *, + const std::vector &); + +typedef sycl::event (*unary_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +typedef sycl::event (*binary_contig_impl_fn_ptr_t)( + sycl::queue, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +typedef sycl::event (*binary_strided_impl_fn_ptr_t)( + sycl::queue, + size_t, + int, + const py::ssize_t *, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &, + const std::vector &); + +typedef sycl::event (*binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const 
char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +typedef sycl::event (*binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( + sycl::queue, + std::vector &, + size_t, + size_t, + const char *, + py::ssize_t, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + } // namespace elementwise_common } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index e0284591da..b69cf2698d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -69,13 +69,6 @@ template struct CosOutputType td_ns::DefaultResultEntry>::result_type; }; -typedef sycl::event (*cos_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - char *, - const std::vector &); - template class cos_contig_kernel; @@ -137,18 +130,6 @@ template struct CosTypeMapFactory template class cos_strided_kernel; -typedef sycl::event (*cos_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event cos_strided_impl(sycl::queue exec_q, size_t nelems, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index a9f8adc92d..e7c118f216 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -131,17 +131,6 @@ template class equal_contig_kernel; -typedef sycl::event (*equal_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - template sycl::event 
equal_contig_impl(sycl::queue exec_q, size_t nelems, @@ -211,20 +200,6 @@ template struct EqualTypeMapFactory template class equal_strided_strided_kernel; -typedef sycl::event (*equal_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event equal_strided_impl(sycl::queue exec_q, @@ -281,190 +256,6 @@ template struct EqualStridedFactory } }; -template -class equal_matrix_row_broadcast_sg_krn; - -typedef sycl::event (*equal_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( - sycl::queue, - std::vector &, - size_t, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - -template -using EqualContigMatrixContigRowBroadcastingFunctor = - elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor< - argT1, - argT2, - resT, - EqualFunctor>; - -template -sycl::event equal_contig_matrix_contig_row_broadcast_impl( - sycl::queue exec_q, - std::vector &host_tasks, - size_t n0, - size_t n1, - const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, - const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, - char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, - // res[i,j] = (mat[i,j] == vec[j]) - py::ssize_t res_offset, - const std::vector &depends = {}) -{ - const argT1 *mat = reinterpret_cast(mat_p) + mat_offset; - const argT2 *vec = reinterpret_cast(vec_p) + vec_offset; - resT *res = reinterpret_cast(res_p) + res_offset; - - const auto &dev = exec_q.get_device(); - const auto &sg_sizes = dev.get_info(); - // Get device-specific kernel info max_sub_group_size - size_t max_sgSize = - *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); - - size_t n1_padded = n1 + max_sgSize; - argT2 *padded_vec = sycl::malloc_device(n1_padded, exec_q); - - if (padded_vec == nullptr) { - throw std::runtime_error("Could not allocate memory on the device"); - } - sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(depends); // ensure vec contains actual data - cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) { - auto i = id[0]; - padded_vec[i] = vec[i % n1]; - }); - }); - - // sub-group spans work-items [I, I + sgSize) - // base = ndit.get_global_linear_id() - sg.get_local_id()[0] - // Generically, sg.load( &mat[base]) may load arrays from - // different rows of mat. The start corresponds to row (base / n0) - // We read sg.load(&padded_vec[(base / n0)]). 
The vector is padded to - // ensure that reads are accessible - - size_t lws = 64; - - sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(make_padded_vec_ev); - - auto lwsRange = sycl::range<1>(lws); - size_t n_elems = n0 * n1; - size_t n_groups = (n_elems + lws - 1) / lws; - auto gwsRange = sycl::range<1>(n_groups * lws); - - cgh.parallel_for< - class equal_matrix_row_broadcast_sg_krn>( - sycl::nd_range<1>(gwsRange, lwsRange), - EqualContigMatrixContigRowBroadcastingFunctor( - mat, padded_vec, res, n_elems, n1)); - }); - - sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(comp_ev); - sycl::context ctx = exec_q.get_context(); - cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); }); - }); - host_tasks.push_back(tmp_cleanup_ev); - - return comp_ev; -} - -template -struct EqualContigMatrixContigRowBroadcastFactory -{ - fnT get() - { - using resT = typename EqualOutputType::value_type; - if constexpr (std::is_same_v) { - fnT fn = nullptr; - return fn; - } - else { - if constexpr (dpctl::tensor::type_utils::is_complex::value || - dpctl::tensor::type_utils::is_complex::value || - dpctl::tensor::type_utils::is_complex::value) - { - fnT fn = nullptr; - return fn; - } - else { - fnT fn = - equal_contig_matrix_contig_row_broadcast_impl; - return fn; - } - } - } -}; - -typedef sycl::event (*equal_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( - sycl::queue, - std::vector &, - size_t, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - -template -sycl::event equal_contig_row_contig_matrix_broadcast_impl( - sycl::queue exec_q, - std::vector &host_tasks, - size_t n0, - size_t n1, - const char *vec_p, // typeless pointer to (n1,) contiguous row - py::ssize_t vec_offset, - const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix - py::ssize_t mat_offset, - char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, - // res[i,j] = (mat[i,j] == vec[j]) - py::ssize_t res_offset, - const std::vector &depends = {}) -{ - return equal_contig_matrix_contig_row_broadcast_impl( - exec_q, host_tasks, n0, n1, mat_p, mat_offset, vec_p, vec_offset, res_p, - res_offset, depends); -}; - -template -struct EqualContigRowContigMatrixBroadcastFactory -{ - fnT get() - { - using resT = typename EqualOutputType::value_type; - if constexpr (std::is_same_v) { - fnT fn = nullptr; - return fn; - } - else { - if constexpr (dpctl::tensor::type_utils::is_complex::value || - dpctl::tensor::type_utils::is_complex::value || - dpctl::tensor::type_utils::is_complex::value) - { - fnT fn = nullptr; - return fn; - } - else { - fnT fn = - equal_contig_row_contig_matrix_broadcast_impl; - return fn; - } - } - } -}; - } // namespace equal } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index a157f42376..e1e88cda02 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -88,13 +88,6 @@ template struct IsFiniteOutputType using value_type = bool; }; -typedef sycl::event (*isfinite_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - char *, - const std::vector &); - template class isfinite_contig_kernel; @@ -151,18 +144,6 @@ template struct IsFiniteTypeMapFactory template class isfinite_strided_kernel; -typedef sycl::event (*isfinite_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event isfinite_strided_impl(sycl::queue exec_q, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp 
index 74f10fdc4e..05e7ce6f6d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -86,13 +86,6 @@ template struct IsInfOutputType using value_type = bool; }; -typedef sycl::event (*isinf_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - char *, - const std::vector &); - template class isinf_contig_kernel; @@ -148,18 +141,6 @@ template struct IsInfTypeMapFactory template class isinf_strided_kernel; -typedef sycl::event (*isinf_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event isinf_strided_impl(sycl::queue exec_q, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index d592e884eb..edc62a5a1d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -191,13 +191,6 @@ template struct IsNanOutputType using value_type = bool; }; -typedef sycl::event (*isnan_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - char *, - const std::vector &); - template class isnan_contig_kernel; @@ -295,18 +288,6 @@ struct IsNanStridedFunctorOld template class isnan_strided_kernel; -typedef sycl::event (*isnan_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event isnan_strided_impl(sycl::queue exec_q, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index 0d92a97176..7eaf7e2e93 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -69,13 +69,6 @@ template struct SqrtOutputType td_ns::DefaultResultEntry>::result_type; }; -typedef sycl::event (*sqrt_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - char *, - const std::vector &); - template class sqrt_contig_kernel; @@ -138,18 +131,6 @@ template struct SqrtTypeMapFactory template class sqrt_strided_kernel; -typedef sycl::event (*sqrt_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event sqrt_strided_impl(sycl::queue exec_q, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 372a7bb128..4d72893e58 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -128,17 +128,6 @@ template class true_divide_contig_kernel; -typedef sycl::event (*true_divide_contig_impl_fn_ptr_t)( - sycl::queue, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - template sycl::event true_divide_contig_impl(sycl::queue exec_q, @@ -211,20 +200,6 @@ struct TrueDivideTypeMapFactory template class true_divide_strided_strided_kernel; -typedef sycl::event (*true_divide_strided_impl_fn_ptr_t)( - sycl::queue, - size_t, - int, - const py::ssize_t *, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &, - const std::vector &); - template sycl::event true_divide_strided_impl(sycl::queue exec_q, @@ -307,20 +282,6 @@ class true_divide_matrix_row_broadcast_sg_krn; template class 
true_divide_row_matrix_broadcast_sg_krn; -typedef sycl::event ( - *true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t)( - sycl::queue, - std::vector &, - size_t, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - template sycl::event true_divide_contig_matrix_contig_row_broadcast_impl( sycl::queue exec_q, @@ -426,20 +387,6 @@ struct TrueDivideContigMatrixContigRowBroadcastFactory } }; -typedef sycl::event ( - *true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t)( - sycl::queue, - std::vector &, - size_t, - size_t, - const char *, - py::ssize_t, - const char *, - py::ssize_t, - char *, - py::ssize_t, - const std::vector &); - template sycl::event true_divide_contig_row_contig_matrix_broadcast_impl( sycl::queue exec_q, diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 0bd751ae1c..681c76905b 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -97,17 +97,24 @@ int _result_typeid(int arg_typeid, const int *fn_output_id) return fn_output_id[arg_typeid]; } +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; +using ew_cmn_ns::unary_contig_impl_fn_ptr_t; +using ew_cmn_ns::unary_strided_impl_fn_ptr_t; + // U01: ==== ABS (x) namespace impl { namespace abs_fn_ns = dpctl::tensor::kernels::abs; -using abs_fn_ns::abs_contig_impl_fn_ptr_t; -using abs_fn_ns::abs_strided_impl_fn_ptr_t; -static abs_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; +static unary_contig_impl_fn_ptr_t abs_contig_dispatch_vector[td_ns::num_types]; static int 
abs_output_typeid_vector[td_ns::num_types]; -static abs_strided_impl_fn_ptr_t abs_strided_dispatch_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + abs_strided_dispatch_vector[td_ns::num_types]; void populate_abs_dispatch_vectors(void) { @@ -115,12 +122,13 @@ void populate_abs_dispatch_vectors(void) namespace fn_ns = abs_fn_ns; using fn_ns::AbsContigFactory; - DispatchVectorBuilder + DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(abs_contig_dispatch_vector); using fn_ns::AbsStridedFactory; - DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(abs_strided_dispatch_vector); @@ -149,25 +157,20 @@ namespace impl { namespace add_fn_ns = dpctl::tensor::kernels::add; -using add_fn_ns::add_contig_impl_fn_ptr_t; -using add_fn_ns::add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using add_fn_ns::add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using add_fn_ns::add_strided_impl_fn_ptr_t; - -static add_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; +static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; static int add_output_id_table[td_ns::num_types][td_ns::num_types]; -static add_strided_impl_fn_ptr_t add_strided_dispatch_table[td_ns::num_types] - [td_ns::num_types]; +static binary_strided_impl_fn_ptr_t + add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; // add(matrix, row) -static add_contig_matrix_contig_row_broadcast_impl_fn_ptr_t +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] [td_ns::num_types]; // add(row, matrix) -static add_contig_row_contig_matrix_broadcast_impl_fn_ptr_t +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] [td_ns::num_types]; @@ -183,22 +186,24 @@ void populate_add_dispatch_tables(void) // function pointers for operation on general strided 
arrays using fn_ns::AddStridedFactory; - DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(add_strided_dispatch_table); // function pointers for operation on contiguous inputs and output using fn_ns::AddContigFactory; - DispatchTableBuilder + DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(add_contig_dispatch_table); // function pointers for operation on contiguous matrix, contiguous row // with contiguous matrix output using fn_ns::AddContigMatrixContigRowBroadcastFactory; - DispatchTableBuilder + DispatchTableBuilder< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + AddContigMatrixContigRowBroadcastFactory, num_types> dtb4; dtb4.populate_dispatch_table( add_contig_matrix_contig_row_broadcast_dispatch_table); @@ -206,8 +211,9 @@ void populate_add_dispatch_tables(void) // function pointers for operation on contiguous row, contiguous matrix // with contiguous matrix output using fn_ns::AddContigRowContigMatrixBroadcastFactory; - DispatchTableBuilder + DispatchTableBuilder< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + AddContigRowContigMatrixBroadcastFactory, num_types> dtb5; dtb5.populate_dispatch_table( add_contig_row_contig_matrix_broadcast_dispatch_table); @@ -298,12 +304,11 @@ namespace impl { namespace cos_fn_ns = dpctl::tensor::kernels::cos; -using cos_fn_ns::cos_contig_impl_fn_ptr_t; -using cos_fn_ns::cos_strided_impl_fn_ptr_t; -static cos_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; +static unary_contig_impl_fn_ptr_t cos_contig_dispatch_vector[td_ns::num_types]; static int cos_output_typeid_vector[td_ns::num_types]; -static cos_strided_impl_fn_ptr_t cos_strided_dispatch_vector[td_ns::num_types]; +static unary_strided_impl_fn_ptr_t + cos_strided_dispatch_vector[td_ns::num_types]; void populate_cos_dispatch_vectors(void) { @@ -311,12 +316,13 @@ void populate_cos_dispatch_vectors(void) namespace fn_ns = cos_fn_ns; using fn_ns::CosContigFactory; - DispatchVectorBuilder + DispatchVectorBuilder dvb1; 
dvb1.populate_dispatch_vector(cos_contig_dispatch_vector); using fn_ns::CosStridedFactory; - DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(cos_strided_dispatch_vector); @@ -339,27 +345,20 @@ namespace impl { namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; -using true_divide_fn_ns::true_divide_contig_impl_fn_ptr_t; -using true_divide_fn_ns:: - true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using true_divide_fn_ns:: - true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using true_divide_fn_ns::true_divide_strided_impl_fn_ptr_t; - -static true_divide_contig_impl_fn_ptr_t +static binary_contig_impl_fn_ptr_t true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; -static true_divide_strided_impl_fn_ptr_t +static binary_strided_impl_fn_ptr_t true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; // divide(matrix, row) -static true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t true_divide_contig_matrix_contig_row_broadcast_dispatch_table [td_ns::num_types][td_ns::num_types]; // divide(row, matrix) -static true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t true_divide_contig_row_contig_matrix_broadcast_dispatch_table [td_ns::num_types][td_ns::num_types]; @@ -375,15 +374,15 @@ void populate_true_divide_dispatch_tables(void) // function pointers for operation on general strided arrays using fn_ns::TrueDivideStridedFactory; - DispatchTableBuilder + DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(true_divide_strided_dispatch_table); // function pointers for operation on contiguous inputs and output using fn_ns::TrueDivideContigFactory; - DispatchTableBuilder + DispatchTableBuilder dtb3; dtb3.populate_dispatch_table(true_divide_contig_dispatch_table); @@ -391,7 
+390,7 @@ void populate_true_divide_dispatch_tables(void) // with contiguous matrix output using fn_ns::TrueDivideContigMatrixContigRowBroadcastFactory; DispatchTableBuilder< - true_divide_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, TrueDivideContigMatrixContigRowBroadcastFactory, num_types> dtb4; dtb4.populate_dispatch_table( @@ -401,7 +400,7 @@ void populate_true_divide_dispatch_tables(void) // with contiguous matrix output using fn_ns::TrueDivideContigRowContigMatrixBroadcastFactory; DispatchTableBuilder< - true_divide_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, TrueDivideContigRowContigMatrixBroadcastFactory, num_types> dtb5; dtb5.populate_dispatch_table( @@ -415,16 +414,11 @@ namespace impl { namespace equal_fn_ns = dpctl::tensor::kernels::equal; -using equal_fn_ns::equal_contig_impl_fn_ptr_t; -using equal_fn_ns::equal_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; -using equal_fn_ns::equal_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; -using equal_fn_ns::equal_strided_impl_fn_ptr_t; - -static equal_contig_impl_fn_ptr_t equal_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; +static binary_contig_impl_fn_ptr_t + equal_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; static int equal_output_id_table[td_ns::num_types][td_ns::num_types]; -static equal_strided_impl_fn_ptr_t +static binary_strided_impl_fn_ptr_t equal_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; void populate_equal_dispatch_tables(void) @@ -439,14 +433,14 @@ void populate_equal_dispatch_tables(void) // function pointers for operation on general strided arrays using fn_ns::EqualStridedFactory; - DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(equal_strided_dispatch_table); // function pointers for operation on contiguous inputs and output using fn_ns::EqualContigFactory; - DispatchTableBuilder dtb3; 
dtb3.populate_dispatch_table(equal_contig_dispatch_table); @@ -499,13 +493,11 @@ namespace impl namespace impl { namespace isfinite_fn_ns = dpctl::tensor::kernels::isfinite; -using isfinite_fn_ns::isfinite_contig_impl_fn_ptr_t; -using isfinite_fn_ns::isfinite_strided_impl_fn_ptr_t; -static isfinite_contig_impl_fn_ptr_t +static unary_contig_impl_fn_ptr_t isfinite_contig_dispatch_vector[td_ns::num_types]; static int isfinite_output_typeid_vector[td_ns::num_types]; -static isfinite_strided_impl_fn_ptr_t +static unary_strided_impl_fn_ptr_t isfinite_strided_dispatch_vector[td_ns::num_types]; void populate_isfinite_dispatch_vectors(void) @@ -514,14 +506,14 @@ void populate_isfinite_dispatch_vectors(void) namespace fn_ns = isfinite_fn_ns; using fn_ns::IsFiniteContigFactory; - DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(isfinite_contig_dispatch_vector); using fn_ns::IsFiniteStridedFactory; - DispatchVectorBuilder + DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(isfinite_strided_dispatch_vector); @@ -536,13 +528,11 @@ void populate_isfinite_dispatch_vectors(void) namespace impl { namespace isinf_fn_ns = dpctl::tensor::kernels::isinf; -using isinf_fn_ns::isinf_contig_impl_fn_ptr_t; -using isinf_fn_ns::isinf_strided_impl_fn_ptr_t; -static isinf_contig_impl_fn_ptr_t +static unary_contig_impl_fn_ptr_t isinf_contig_dispatch_vector[td_ns::num_types]; static int isinf_output_typeid_vector[td_ns::num_types]; -static isinf_strided_impl_fn_ptr_t +static unary_strided_impl_fn_ptr_t isinf_strided_dispatch_vector[td_ns::num_types]; void populate_isinf_dispatch_vectors(void) @@ -551,13 +541,13 @@ void populate_isinf_dispatch_vectors(void) namespace fn_ns = isinf_fn_ns; using fn_ns::IsInfContigFactory; - DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(isinf_contig_dispatch_vector); using fn_ns::IsInfStridedFactory; - DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(isinf_strided_dispatch_vector); @@ -573,13 +563,11 @@ void 
populate_isinf_dispatch_vectors(void) namespace impl { namespace isnan_fn_ns = dpctl::tensor::kernels::isnan; -using isnan_fn_ns::isnan_contig_impl_fn_ptr_t; -using isnan_fn_ns::isnan_strided_impl_fn_ptr_t; -static isnan_contig_impl_fn_ptr_t +static unary_contig_impl_fn_ptr_t isnan_contig_dispatch_vector[td_ns::num_types]; static int isnan_output_typeid_vector[td_ns::num_types]; -static isnan_strided_impl_fn_ptr_t +static unary_strided_impl_fn_ptr_t isnan_strided_dispatch_vector[td_ns::num_types]; void populate_isnan_dispatch_vectors(void) @@ -588,13 +576,13 @@ void populate_isnan_dispatch_vectors(void) namespace fn_ns = isnan_fn_ns; using fn_ns::IsNanContigFactory; - DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(isnan_contig_dispatch_vector); using fn_ns::IsNanStridedFactory; - DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(isnan_strided_dispatch_vector); @@ -749,12 +737,10 @@ namespace impl { namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt; -using sqrt_fn_ns::sqrt_contig_impl_fn_ptr_t; -using sqrt_fn_ns::sqrt_strided_impl_fn_ptr_t; -static sqrt_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types]; +static unary_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types]; static int sqrt_output_typeid_vector[td_ns::num_types]; -static sqrt_strided_impl_fn_ptr_t +static unary_strided_impl_fn_ptr_t sqrt_strided_dispatch_vector[td_ns::num_types]; void populate_sqrt_dispatch_vectors(void) @@ -763,13 +749,13 @@ void populate_sqrt_dispatch_vectors(void) namespace fn_ns = sqrt_fn_ns; using fn_ns::SqrtContigFactory; - DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector); using fn_ns::SqrtStridedFactory; - DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector); @@ -1010,13 +996,11 @@ void init_elementwise_functions(py::module_ m) // function pointers to handle operation of c-contig matrix and // c-contig row with broadcasting (may be nullptr) 
td_ns::NullPtrTable< - impl:: - equal_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, // function pointers to handle operation of c-contig matrix and // c-contig row with broadcasting (may be nullptr) td_ns::NullPtrTable< - impl:: - equal_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); }; auto equal_result_type_pyapi = [&](py::dtype dtype1, py::dtype dtype2) { return py_binary_ufunc_result_type(dtype1, dtype2, From 121f819be41909f2a634b2dbe8307e6bf9de0408 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 19 May 2023 19:54:55 -0500 Subject: [PATCH 44/48] add new tests to improve coverage --- dpctl/tensor/_elementwise_common.py | 7 ++- dpctl/tests/elementwise/test_abs.py | 25 ++++------- dpctl/tests/elementwise/test_add.py | 57 ++++++++++++++---------- dpctl/tests/elementwise/test_cos.py | 53 +++++++++++----------- dpctl/tests/elementwise/test_isfinite.py | 43 ++++-------------- dpctl/tests/elementwise/test_isnan.py | 43 ++++-------------- 6 files changed, 87 insertions(+), 141 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index d78e5fe2b2..6677670e73 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -93,8 +93,8 @@ def __call__(self, x, out=None, order="K"): else: if res_dt != out.dtype: raise TypeError( - f"Expected output array of type {res_dt} is supported" - f", got {out.dtype}" + f"Output array of type {res_dt} is needed," + f" got {out.dtype}" ) ht, _ = self.unary_fn_(x, out, sycl_queue=exec_q) @@ -119,8 +119,7 @@ def __call__(self, x, out=None, order="K"): else: if buf_dt != out.dtype: raise TypeError( - f"Expected output array of type {buf_dt} is supported," - f"got {out.dtype}" + f"Output array of type {buf_dt} is needed, got {out.dtype}" ) ht, _ = self.unary_fn_(buf, out, sycl_queue=exec_q, depends=[copy_ev]) diff 
--git a/dpctl/tests/elementwise/test_abs.py b/dpctl/tests/elementwise/test_abs.py index 94bc05d918..2e2ff69ea4 100644 --- a/dpctl/tests/elementwise/test_abs.py +++ b/dpctl/tests/elementwise/test_abs.py @@ -22,9 +22,17 @@ def test_abs_out_type(dtype): np.dtype("c16"): np.dtype("f8"), } assert dpt.abs(X).dtype == type_map[arg_dt] + + out = dpt.empty_like(X, dtype=type_map[arg_dt]) + dpt.abs(X, out) + assert np.allclose(dpt.asnumpy(out), dpt.asnumpy(dpt.abs(X))) else: assert dpt.abs(X).dtype == arg_dt + out = dpt.empty_like(X, dtype=arg_dt) + dpt.abs(X, out) + assert np.allclose(dpt.asnumpy(out), dpt.asnumpy(dpt.abs(X))) + @pytest.mark.parametrize("usm_type", _usm_types) def test_abs_usm_type(usm_type): @@ -89,20 +97,3 @@ def test_abs_complex(dtype): np.testing.assert_allclose( dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol ) - - -@pytest.mark.parametrize("dtype", _all_dtypes[:-2]) -def test_abs_out_keyword(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - arg_dt = np.dtype(dtype) - input_shape = (10, 10, 10, 10) - X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q) - X[..., 0::2] = 1 - X[..., 1::2] = 0 - Y = dpt.empty_like(X, dtype=arg_dt) - dpt.abs(X, Y) - - expected_Y = dpt.asnumpy(X) - assert np.allclose(dpt.asnumpy(Y), expected_Y) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index ee0a311546..8a17b4f761 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -7,6 +7,7 @@ import dpctl import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import ExecutionPlacementError from .utils import _all_dtypes, _compare_dtypes, _usm_types @@ -32,6 +33,10 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype): assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() assert r.sycl_queue == ar1.sycl_queue + out = dpt.empty_like(ar1, dtype=expected_dtype) + dpt.add(ar1, ar2, out) + assert 
(dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all() + ar3 = dpt.ones(sz, dtype=op1_dtype) ar4 = dpt.ones(2 * sz, dtype=op2_dtype) @@ -44,6 +49,10 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype): assert r.shape == ar3.shape assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() + out = dpt.empty_like(ar1, dtype=expected_dtype) + dpt.add(ar3[::-1], ar4[::2], out) + assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all() + @pytest.mark.parametrize("op1_usm_type", _usm_types) @pytest.mark.parametrize("op2_usm_type", _usm_types) @@ -105,7 +114,6 @@ def test_add_broadcasting(): v = dpt.arange(5, dtype="i4") r = dpt.add(m, v) - assert (dpt.asnumpy(r) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() r2 = dpt.add(v, m) @@ -180,26 +188,8 @@ def __sycl_usm_array_interface__(self): dpt.add(a, c) -@pytest.mark.parametrize("op1_dtype", _all_dtypes) -@pytest.mark.parametrize("op2_dtype", _all_dtypes) -def test_add_dtype_out_keyword(op1_dtype, op2_dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(op1_dtype, q) - skip_if_dtype_not_supported(op2_dtype, q) - - sz = 127 - ar1 = dpt.ones(sz, dtype=op1_dtype) - ar2 = dpt.ones_like(ar1, dtype=op2_dtype) - - r = dpt.add(ar1, ar2) - - y = dpt.zeros_like(ar1, dtype=r.dtype) - dpt.add(ar1, ar2, y) - - assert np.array_equal(dpt.asnumpy(r), dpt.asnumpy(y)) - - def test_add_errors(): + get_queue_or_skip() try: gpu_queue = dpctl.SyclQueue("gpu") except dpctl.SyclQueueCreationError: @@ -245,11 +235,14 @@ def test_add_errors(): y, ) - ar1 = dpt.ones(2, dtype="float32") - ar2 = dpt.ones_like(ar1, dtype="int32") - y = dpt.empty_like(ar1, dtype="int32") + ar1 = np.ones(2, dtype="float32") + ar2 = np.ones_like(ar1, dtype="int32") assert_raises_regex( - TypeError, "Output array of type.*is needed", dpt.add, ar1, ar2, y + ExecutionPlacementError, + "Execution placement can not be unambiguously inferred.*", + dpt.add, + ar1, + ar2, ) ar1 = dpt.ones(2, dtype="float32") @@ -263,3 +256,19 @@ 
def test_add_errors(): ar2, y, ) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_add_dtype_error( + dtype, +): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + ar1 = dpt.ones(5, dtype=dtype) + ar2 = dpt.ones_like(ar1, dtype="f8") + + y = dpt.zeros_like(ar1, dtype="int8") + assert_raises_regex( + TypeError, "Output array of type.*is needed", dpt.add, ar1, ar2, y + ) diff --git a/dpctl/tests/elementwise/test_cos.py b/dpctl/tests/elementwise/test_cos.py index b8fff07bd9..0d1e641b5f 100644 --- a/dpctl/tests/elementwise/test_cos.py +++ b/dpctl/tests/elementwise/test_cos.py @@ -21,6 +21,13 @@ def test_cos_out_type(dtype): expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) assert dpt.cos(X).dtype == expected_dtype + X = dpt.asarray(0, dtype=dtype, sycl_queue=q) + expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype + expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) + Y = dpt.empty_like(X, dtype=expected_dtype) + dpt.cos(X, Y) + np.testing.assert_allclose(dpt.asnumpy(dpt.cos(X)), dpt.asnumpy(Y)) + @pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) def test_cos_output(dtype): @@ -40,6 +47,13 @@ def test_cos_output(dtype): dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol ) + Z = dpt.empty_like(X, dtype=dtype) + dpt.cos(X, Z) + + np.testing.assert_allclose( + dpt.asnumpy(Z), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol + ) + @pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) def test_cos_usm_type(usm_type): @@ -89,27 +103,8 @@ def test_cos_order(dtype): ) -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"]) -def test_cos_out_keyword(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - n_seq = 100 - n_rep = 137 - - Xnp = np.linspace(-np.pi / 4, np.pi / 4, num=n_seq, dtype=dtype) - X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q) - Y = dpt.empty_like(X, dtype=dtype) - - dpt.cos(X, Y) - tol = 8 
* dpt.finfo(Y.dtype).resolution - - np.testing.assert_allclose( - dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol - ) - - def test_cos_errors(): + get_queue_or_skip() try: gpu_queue = dpctl.SyclQueue("gpu") except dpctl.SyclQueueCreationError: @@ -145,14 +140,20 @@ def test_cos_errors(): TypeError, "Input and output arrays have memory overlap", dpt.cos, x, y ) - x = dpt.zeros(2, dtype="int32") - y = dpt.empty_like(x, dtype="int32") - assert_raises_regex( - TypeError, "Expected output array of type.*is supported", dpt.cos, x, y - ) - x = dpt.zeros(2, dtype="float32") y = np.empty_like(x) assert_raises_regex( TypeError, "output array must be of usm_ndarray type", dpt.cos, x, y ) + + +@pytest.mark.parametrize("dtype", _all_dtypes) +def test_cos_error_dtype(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + + x = dpt.zeros(5, dtype=dtype) + y = dpt.empty_like(x, dtype="int16") + assert_raises_regex( + TypeError, "Output array of type.*is needed", dpt.cos, x, y + ) diff --git a/dpctl/tests/elementwise/test_isfinite.py b/dpctl/tests/elementwise/test_isfinite.py index fa3150fc4f..7f86277c07 100644 --- a/dpctl/tests/elementwise/test_isfinite.py +++ b/dpctl/tests/elementwise/test_isfinite.py @@ -41,6 +41,10 @@ def test_isfinite_complex(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) + out = dpt.empty_like(Y, dtype="bool") + dpt.isfinite(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) + @pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) def test_isfinite_floats(dtype): @@ -56,6 +60,10 @@ def test_isfinite_floats(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) + out = dpt.empty_like(Y, dtype="bool") + dpt.isfinite(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) + @pytest.mark.parametrize("dtype", _all_dtypes) def test_isfinite_order(dtype): @@ -72,38 +80,3 
@@ def test_isfinite_order(dtype): Y = dpt.isfinite(U, order=ord) expected_Y = np.full(Y.shape, True, dtype=Y.dtype) assert np.allclose(dpt.asnumpy(Y), expected_Y) - - -@pytest.mark.parametrize("dtype", ["c8", "c16"]) -def test_isnan_complex_out_keyword(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = complex(np.nan, np.nan) - y2 = complex(1, np.nan) - y3 = complex(np.nan, 1) - y4 = complex(2, 1) - y5 = complex(np.inf, 1) - - Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 12) - Y = dpt.asarray(Ynp, sycl_queue=q) - out = dpt.empty_like(Y, dtype="bool") - dpt.isfinite(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) - - -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) -def test_isfinite_floats_out_keyword(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = np.nan - y2 = 1 - y3 = np.inf - - for mult in [123, 137, 255, 271, 272]: - Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) - Y = dpt.asarray(Ynp, sycl_queue=q) - out = dpt.empty_like(Y, dtype="bool") - dpt.isfinite(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) diff --git a/dpctl/tests/elementwise/test_isnan.py b/dpctl/tests/elementwise/test_isnan.py index 0796f0f0f8..74d922b6db 100644 --- a/dpctl/tests/elementwise/test_isnan.py +++ b/dpctl/tests/elementwise/test_isnan.py @@ -41,6 +41,10 @@ def test_isnan_complex(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) + out = dpt.empty_like(Y, dtype="bool") + dpt.isnan(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) + @pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) def test_isnan_floats(dtype): @@ -56,6 +60,10 @@ def test_isnan_floats(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) + out = dpt.empty_like(Y, dtype="bool") + dpt.isnan(Y, out) + assert np.array_equal(dpt.asnumpy(out)[()], 
np.isnan(Ynp)) + @pytest.mark.parametrize("dtype", _all_dtypes) def test_isnan_order(dtype): @@ -72,38 +80,3 @@ def test_isnan_order(dtype): Y = dpt.isnan(U, order=ord) expected_Y = np.full(Y.shape, False, dtype=Y.dtype) assert np.allclose(dpt.asnumpy(Y), expected_Y) - - -@pytest.mark.parametrize("dtype", ["c8", "c16"]) -def test_isnan_complex_out_keyword(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = complex(np.nan, np.nan) - y2 = complex(1, np.nan) - y3 = complex(np.nan, 1) - y4 = complex(2, 1) - y5 = complex(np.inf, 1) - - Ynp = np.repeat(np.array([y1, y2, y3, y4, y5], dtype=dtype), 123) - Y = dpt.asarray(Ynp, sycl_queue=q) - out = dpt.empty_like(Y, dtype="bool") - dpt.isnan(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) - - -@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) -def test_isnan_floats_out_keyword(dtype): - q = get_queue_or_skip() - skip_if_dtype_not_supported(dtype, q) - - y1 = np.nan - y2 = 1 - y3 = np.inf - - for mult in [123, 137, 255, 271, 272]: - Ynp = np.repeat(np.array([y1, y2, y3], dtype=dtype), mult) - Y = dpt.asarray(Ynp, sycl_queue=q) - out = dpt.empty_like(Y, dtype="bool") - dpt.isnan(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) From 05ae945382ca1515ea40ff1c655629f4afbf57d6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Sat, 20 May 2023 11:11:07 -0500 Subject: [PATCH 45/48] Fixed tests to run on Iris Xe Also ensure that test_add_order exercises non-same dtypes to improve coverage. 
--- dpctl/tests/elementwise/test_add.py | 80 +++++++++++++++++------------ 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index 8a17b4f761..586c78eab2 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -33,7 +33,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype): assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() assert r.sycl_queue == ar1.sycl_queue - out = dpt.empty_like(ar1, dtype=expected_dtype) + out = dpt.empty_like(ar1, dtype=r.dtype) dpt.add(ar1, ar2, out) assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all() @@ -49,7 +49,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype): assert r.shape == ar3.shape assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() - out = dpt.empty_like(ar1, dtype=expected_dtype) + out = dpt.empty_like(ar1, dtype=r.dtype) dpt.add(ar3[::-1], ar4[::2], out) assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all() @@ -74,37 +74,49 @@ def test_add_usm_type_matrix(op1_usm_type, op2_usm_type): def test_add_order(): get_queue_or_skip() - ar1 = dpt.ones((20, 20), dtype="i4", order="C") - ar2 = dpt.ones((20, 20), dtype="i4", order="C") - r1 = dpt.add(ar1, ar2, order="C") - assert r1.flags.c_contiguous - r2 = dpt.add(ar1, ar2, order="F") - assert r2.flags.f_contiguous - r3 = dpt.add(ar1, ar2, order="A") - assert r3.flags.c_contiguous - r4 = dpt.add(ar1, ar2, order="K") - assert r4.flags.c_contiguous - - ar1 = dpt.ones((20, 20), dtype="i4", order="F") - ar2 = dpt.ones((20, 20), dtype="i4", order="F") - r1 = dpt.add(ar1, ar2, order="C") - assert r1.flags.c_contiguous - r2 = dpt.add(ar1, ar2, order="F") - assert r2.flags.f_contiguous - r3 = dpt.add(ar1, ar2, order="A") - assert r3.flags.f_contiguous - r4 = dpt.add(ar1, ar2, order="K") - assert r4.flags.f_contiguous - - ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] - ar2 = 
dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2] - r4 = dpt.add(ar1, ar2, order="K") - assert r4.strides == (20, -1) - - ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT - ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT - r4 = dpt.add(ar1, ar2, order="K") - assert r4.strides == (-1, 20) + test_shape = ( + 20, + 20, + ) + test_shape2 = tuple(2 * dim for dim in test_shape) + n = test_shape[-1] + + for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]): + ar1 = dpt.ones(test_shape, dtype=dt1, order="C") + ar2 = dpt.ones(test_shape, dtype=dt2, order="C") + r1 = dpt.add(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.add(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.add(ar1, ar2, order="A") + assert r3.flags.c_contiguous + r4 = dpt.add(ar1, ar2, order="K") + assert r4.flags.c_contiguous + + ar1 = dpt.ones(test_shape, dtype=dt1, order="F") + ar2 = dpt.ones(test_shape, dtype=dt2, order="F") + r1 = dpt.add(ar1, ar2, order="C") + assert r1.flags.c_contiguous + r2 = dpt.add(ar1, ar2, order="F") + assert r2.flags.f_contiguous + r3 = dpt.add(ar1, ar2, order="A") + assert r3.flags.f_contiguous + r4 = dpt.add(ar1, ar2, order="K") + assert r4.flags.f_contiguous + + ar1 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2] + ar2 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2] + r4 = dpt.add(ar1, ar2, order="K") + assert r4.strides == (n, -1) + r5 = dpt.add(ar1, ar2, order="C") + assert r5.strides == (n, 1) + + ar1 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2].mT + ar2 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2].mT + r4 = dpt.add(ar1, ar2, order="K") + assert r4.strides == (-1, n) + r5 = dpt.add(ar1, ar2, order="C") + assert r5.strides == (n, 1) def test_add_broadcasting(): @@ -266,7 +278,7 @@ def test_add_dtype_error( skip_if_dtype_not_supported(dtype, q) ar1 = dpt.ones(5, dtype=dtype) - ar2 = dpt.ones_like(ar1, dtype="f8") + ar2 = dpt.ones_like(ar1, dtype="f4") y = 
dpt.zeros_like(ar1, dtype="int8") assert_raises_regex( From d691846f34cd153ecaea6d7423ef70719e7fe64f Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Sat, 20 May 2023 11:21:47 -0500 Subject: [PATCH 46/48] Added a test for broadcasting error --- dpctl/tests/elementwise/test_add.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index 586c78eab2..b8d202c54f 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -144,6 +144,14 @@ def test_add_broadcasting(): ).all() +def test_add_broadcasting_error(): + get_queue_or_skip() + m = dpt.ones((10, 10), dtype="i4") + v = dpt.ones((3,), dtype="i4") + with pytest.raises(ValueError): + dpt.add(m, v) + + @pytest.mark.parametrize("arr_dt", _all_dtypes) def test_add_python_scalar(arr_dt): q = get_queue_or_skip() From f3d5519d55a3c8dbb79dfa6620c84346bf445f9d Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Sat, 20 May 2023 13:26:10 -0500 Subject: [PATCH 47/48] use keyword argument --- dpctl/tests/elementwise/test_abs.py | 12 +++++----- dpctl/tests/elementwise/test_add.py | 30 ++++++++++-------------- dpctl/tests/elementwise/test_cos.py | 4 ++-- dpctl/tests/elementwise/test_isfinite.py | 12 +++++----- dpctl/tests/elementwise/test_isnan.py | 12 +++++----- 5 files changed, 33 insertions(+), 37 deletions(-) diff --git a/dpctl/tests/elementwise/test_abs.py b/dpctl/tests/elementwise/test_abs.py index 2e2ff69ea4..d4aefaf76d 100644 --- a/dpctl/tests/elementwise/test_abs.py +++ b/dpctl/tests/elementwise/test_abs.py @@ -23,15 +23,15 @@ def test_abs_out_type(dtype): } assert dpt.abs(X).dtype == type_map[arg_dt] - out = dpt.empty_like(X, dtype=type_map[arg_dt]) - dpt.abs(X, out) - assert np.allclose(dpt.asnumpy(out), dpt.asnumpy(dpt.abs(X))) + r = dpt.empty_like(X, dtype=type_map[arg_dt]) + dpt.abs(X, out=r) + assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X))) else: assert dpt.abs(X).dtype == arg_dt - 
out = dpt.empty_like(X, dtype=arg_dt) - dpt.abs(X, out) - assert np.allclose(dpt.asnumpy(out), dpt.asnumpy(dpt.abs(X))) + r = dpt.empty_like(X, dtype=arg_dt) + dpt.abs(X, out=r) + assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.abs(X))) @pytest.mark.parametrize("usm_type", _usm_types) diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index b8d202c54f..2b4ab8e3cb 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -33,9 +33,9 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype): assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() assert r.sycl_queue == ar1.sycl_queue - out = dpt.empty_like(ar1, dtype=r.dtype) - dpt.add(ar1, ar2, out) - assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all() + r2 = dpt.empty_like(ar1, dtype=r.dtype) + dpt.add(ar1, ar2, out=r2) + assert (dpt.asnumpy(r2) == np.full(r2.shape, 2, dtype=r2.dtype)).all() ar3 = dpt.ones(sz, dtype=op1_dtype) ar4 = dpt.ones(2 * sz, dtype=op2_dtype) @@ -49,9 +49,9 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype): assert r.shape == ar3.shape assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all() - out = dpt.empty_like(ar1, dtype=r.dtype) - dpt.add(ar3[::-1], ar4[::2], out) - assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all() + r2 = dpt.empty_like(ar1, dtype=r.dtype) + dpt.add(ar3[::-1], ar4[::2], out=r2) + assert (dpt.asnumpy(r2) == np.full(r2.shape, 2, dtype=r2.dtype)).all() @pytest.mark.parametrize("op1_usm_type", _usm_types) @@ -131,17 +131,13 @@ def test_add_broadcasting(): r2 = dpt.add(v, m) assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() - out = dpt.empty_like(m) - dpt.add(m, v, out) - assert ( - dpt.asnumpy(out) == np.arange(1, 6, dtype="i4")[np.newaxis, :] - ).all() - - out2 = dpt.empty_like(m) - dpt.add(v, m, out2) - assert ( - dpt.asnumpy(out2) == np.arange(1, 6, dtype="i4")[np.newaxis, :] - ).all() + r3 = dpt.empty_like(m) 
+ dpt.add(m, v, out=r3) + assert (dpt.asnumpy(r3) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() + + r4 = dpt.empty_like(m) + dpt.add(v, m, out=r4) + assert (dpt.asnumpy(r4) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all() def test_add_broadcasting_error(): diff --git a/dpctl/tests/elementwise/test_cos.py b/dpctl/tests/elementwise/test_cos.py index 0d1e641b5f..395fa60435 100644 --- a/dpctl/tests/elementwise/test_cos.py +++ b/dpctl/tests/elementwise/test_cos.py @@ -25,7 +25,7 @@ def test_cos_out_type(dtype): expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device) Y = dpt.empty_like(X, dtype=expected_dtype) - dpt.cos(X, Y) + dpt.cos(X, out=Y) np.testing.assert_allclose(dpt.asnumpy(dpt.cos(X)), dpt.asnumpy(Y)) @@ -48,7 +48,7 @@ def test_cos_output(dtype): ) Z = dpt.empty_like(X, dtype=dtype) - dpt.cos(X, Z) + dpt.cos(X, out=Z) np.testing.assert_allclose( dpt.asnumpy(Z), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol diff --git a/dpctl/tests/elementwise/test_isfinite.py b/dpctl/tests/elementwise/test_isfinite.py index 7f86277c07..92e585d217 100644 --- a/dpctl/tests/elementwise/test_isfinite.py +++ b/dpctl/tests/elementwise/test_isfinite.py @@ -41,9 +41,9 @@ def test_isfinite_complex(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) - out = dpt.empty_like(Y, dtype="bool") - dpt.isfinite(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) + r = dpt.empty_like(Y, dtype="bool") + dpt.isfinite(Y, out=r) + assert np.array_equal(dpt.asnumpy(r)[()], np.isfinite(Ynp)) @pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) @@ -60,9 +60,9 @@ def test_isfinite_floats(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isfinite(Y)), np.isfinite(Ynp)) - out = dpt.empty_like(Y, dtype="bool") - dpt.isfinite(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isfinite(Ynp)) + r = 
dpt.empty_like(Y, dtype="bool") + dpt.isfinite(Y, out=r) + assert np.array_equal(dpt.asnumpy(r)[()], np.isfinite(Ynp)) @pytest.mark.parametrize("dtype", _all_dtypes) diff --git a/dpctl/tests/elementwise/test_isnan.py b/dpctl/tests/elementwise/test_isnan.py index 74d922b6db..7545251bf2 100644 --- a/dpctl/tests/elementwise/test_isnan.py +++ b/dpctl/tests/elementwise/test_isnan.py @@ -41,9 +41,9 @@ def test_isnan_complex(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) - out = dpt.empty_like(Y, dtype="bool") - dpt.isnan(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) + r = dpt.empty_like(Y, dtype="bool") + dpt.isnan(Y, out=r) + assert np.array_equal(dpt.asnumpy(r)[()], np.isnan(Ynp)) @pytest.mark.parametrize("dtype", ["f2", "f4", "f8"]) @@ -60,9 +60,9 @@ def test_isnan_floats(dtype): Y = dpt.asarray(Ynp, sycl_queue=q) assert np.array_equal(dpt.asnumpy(dpt.isnan(Y)), np.isnan(Ynp)) - out = dpt.empty_like(Y, dtype="bool") - dpt.isnan(Y, out) - assert np.array_equal(dpt.asnumpy(out)[()], np.isnan(Ynp)) + r = dpt.empty_like(Y, dtype="bool") + dpt.isnan(Y, out=r) + assert np.array_equal(dpt.asnumpy(r)[()], np.isnan(Ynp)) @pytest.mark.parametrize("dtype", _all_dtypes) From 62f2d464c14b2398969f35ddb4605df306480c2c Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 22 May 2023 09:37:42 -0500 Subject: [PATCH 48/48] Added missing license headers, updated license year to 2023 --- dpctl/tensor/_elementwise_funcs.py | 16 ++++++++++++ .../include/kernels/boolean_reductions.hpp | 4 +-- .../include/kernels/constructors.hpp | 2 +- .../include/kernels/copy_and_cast.hpp | 2 +- .../kernels/elementwise_functions/abs.hpp | 24 ++++++++++++++++++ .../kernels/elementwise_functions/add.hpp | 25 +++++++++++++++++++ .../kernels/elementwise_functions/common.hpp | 11 ++++---- .../kernels/elementwise_functions/cos.hpp | 24 ++++++++++++++++++ .../kernels/elementwise_functions/equal.hpp | 25 +++++++++++++++++++ 
.../elementwise_functions/isfinite.hpp | 25 +++++++++++++++++++ .../kernels/elementwise_functions/isinf.hpp | 25 +++++++++++++++++++ .../kernels/elementwise_functions/isnan.hpp | 25 +++++++++++++++++++ .../kernels/elementwise_functions/sqrt.hpp | 25 +++++++++++++++++++ .../elementwise_functions/true_divide.hpp | 25 +++++++++++++++++++ .../kernels/integer_advanced_indexing.hpp | 2 +- .../source/boolean_advanced_indexing.cpp | 2 +- .../source/boolean_advanced_indexing.hpp | 2 +- .../source/copy_and_cast_usm_to_usm.cpp | 2 +- .../source/copy_and_cast_usm_to_usm.hpp | 2 +- .../libtensor/source/copy_for_reshape.cpp | 2 +- .../libtensor/source/copy_for_reshape.hpp | 2 +- .../copy_numpy_ndarray_into_usm_ndarray.cpp | 2 +- .../copy_numpy_ndarray_into_usm_ndarray.hpp | 2 +- .../source/device_support_queries.cpp | 2 +- .../source/device_support_queries.hpp | 2 +- .../source/elementwise_functions.cpp | 2 +- .../source/elementwise_functions.hpp | 2 +- dpctl/tensor/libtensor/source/eye_ctor.cpp | 2 +- dpctl/tensor/libtensor/source/eye_ctor.hpp | 2 +- dpctl/tensor/libtensor/source/full_ctor.cpp | 2 +- dpctl/tensor/libtensor/source/full_ctor.hpp | 2 +- .../source/integer_advanced_indexing.cpp | 2 +- .../source/integer_advanced_indexing.hpp | 2 +- .../libtensor/source/linear_sequences.cpp | 2 +- .../libtensor/source/linear_sequences.hpp | 2 +- .../source/simplify_iteration_space.cpp | 2 +- .../source/simplify_iteration_space.hpp | 2 +- dpctl/tensor/libtensor/source/tensor_py.cpp | 2 +- dpctl/tensor/libtensor/source/triul_ctor.cpp | 2 +- dpctl/tensor/libtensor/source/triul_ctor.hpp | 2 +- dpctl/tests/elementwise/__init__.py | 20 +++++++++++++++ dpctl/tests/elementwise/test_abs.py | 16 ++++++++++++ dpctl/tests/elementwise/test_add.py | 16 ++++++++++++ dpctl/tests/elementwise/test_cos.py | 16 ++++++++++++ dpctl/tests/elementwise/test_divide.py | 16 ++++++++++++ dpctl/tests/elementwise/test_equal.py | 16 ++++++++++++ dpctl/tests/elementwise/test_isfinite.py | 16 ++++++++++++ 
dpctl/tests/elementwise/test_isinf.py | 16 ++++++++++++ dpctl/tests/elementwise/test_isnan.py | 16 ++++++++++++ dpctl/tests/elementwise/test_sqrt.py | 16 ++++++++++++ dpctl/tests/elementwise/test_type_utils.py | 16 ++++++++++++ dpctl/tests/elementwise/utils.py | 16 ++++++++++++ 52 files changed, 470 insertions(+), 36 deletions(-) diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 3068486827..27c549c034 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import dpctl.tensor._tensor_impl as ti from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp index dec96fab2a..8418fca83c 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp @@ -1,5 +1,5 @@ -//=== boolean_reductions.hpp - Implementation of boolean reduction kernels -//---*-C++-*--/===// +//=== boolean_reductions.hpp - Implementation of boolean reduction kernels // +// ---*-C++-*--/===// // // Data Parallel Control (dpctl) // diff --git a/dpctl/tensor/libtensor/include/kernels/constructors.hpp b/dpctl/tensor/libtensor/include/kernels/constructors.hpp index 6449d992cd..49111cbb61 100644 --- a/dpctl/tensor/libtensor/include/kernels/constructors.hpp +++ b/dpctl/tensor/libtensor/include/kernels/constructors.hpp @@ -3,7 +3,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp index 02b7ac3c2d..f1e63ccc60 100644 --- a/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp +++ b/dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp index 83e1f2709a..09f2995874 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp @@ -1,3 +1,27 @@ +//=== abs.hpp - Unary function ABS ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ABS(x) function. +//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index f045dad1b6..d0ab25d270 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -1,3 +1,28 @@ +//=== add.hpp - Binary function ADD ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ADD(x1, x2) +/// function. +//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp index bb46d4cdca..a1108c541c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp @@ -1,9 +1,8 @@ -//=== common.hpp - -----------------------------------*-C++-*--/===// -//= Implementation of tensor elementwise operation kernels ------===// +//=== common.hpp - Common code for elementwise operations ----- *-C++-*--/===// // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,11 +16,11 @@ // See the License for the specific language governing permissions and // limitations under the License. // -//===----------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// /// /// \file -/// This file defines kernels for elementwise operations over tensor . 
-//===----------------------------------------------------------------------===// +/// This file defines common code for elementwise tensor operations. +//===---------------------------------------------------------------------===// #pragma once #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index b69cf2698d..b6859910e3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -1,3 +1,27 @@ +//=== cos.hpp - Unary function COS ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of COS(x) function. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index e7c118f216..edbbc393c5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -1,3 +1,28 @@ +//=== equal.hpp - Binary function EQUAL ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of equality of +/// tensor elements. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index e1e88cda02..86340329da 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -1,3 +1,28 @@ +//=== isfinite.hpp - Unary function ISFINITE ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ISFINITE(x) +/// function that tests whether a tensor element is finite. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 05e7ce6f6d..8f7dcca6c7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -1,3 +1,28 @@ +//=== isinf.hpp - Unary function ISINF ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ISINF(x) +/// function that tests whether a tensor element is an infinity. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index edc62a5a1d..3e4f68ed57 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -1,3 +1,28 @@ +//=== isnan.hpp - Unary function ISNAN ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ISNAN(x) +/// function that tests whether a tensor element is a NaN. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index 7eaf7e2e93..7e576c8746 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -1,3 +1,28 @@ +//=== sqrt.hpp - Unary function SQRT ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of SQRT(x) +/// function that computes a square root. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 4d72893e58..f34da1b415 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -1,3 +1,28 @@ +//=== true_divide.hpp - Binary function DIVIDE ------ *-C++-*--/===// +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of DIVIDE(x1, x2) +/// function. 
+//===---------------------------------------------------------------------===// + #pragma once #include #include diff --git a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp index ce24dc799a..0f60c7a4b2 100644 --- a/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp index 0a706146e2..59f62af5f1 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp index 96c264f563..e6e8a54ed6 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp index d8692c1098..57386e736f 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp index 192d70c0f2..109062516a 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp index d4f7f50437..7f4b8d718c 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.cpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/copy_for_reshape.hpp b/dpctl/tensor/libtensor/source/copy_for_reshape.hpp index 51c3719b97..09caddf824 100644 --- a/dpctl/tensor/libtensor/source/copy_for_reshape.hpp +++ b/dpctl/tensor/libtensor/source/copy_for_reshape.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp index f4e29411b0..0464a14cd3 100644 --- a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp +++ b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp index 16adb921ee..e5bf513921 100644 --- a/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp +++ b/dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/device_support_queries.cpp b/dpctl/tensor/libtensor/source/device_support_queries.cpp index 74ae3464fc..16ae43ba97 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.cpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/device_support_queries.hpp b/dpctl/tensor/libtensor/source/device_support_queries.hpp index 905ba4b535..a54835fc75 100644 --- a/dpctl/tensor/libtensor/source/device_support_queries.hpp +++ b/dpctl/tensor/libtensor/source/device_support_queries.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index 681c76905b..de209fdb1c 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.hpp b/dpctl/tensor/libtensor/source/elementwise_functions.hpp index b4e8bac10d..58127756ca 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.hpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/eye_ctor.cpp b/dpctl/tensor/libtensor/source/eye_ctor.cpp index f04518bc48..c4a8f0cd08 100644 --- a/dpctl/tensor/libtensor/source/eye_ctor.cpp +++ b/dpctl/tensor/libtensor/source/eye_ctor.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/eye_ctor.hpp b/dpctl/tensor/libtensor/source/eye_ctor.hpp index 1067ed8d8b..3436c23bd8 100644 --- a/dpctl/tensor/libtensor/source/eye_ctor.hpp +++ b/dpctl/tensor/libtensor/source/eye_ctor.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp index f4b8ae5f42..2f7182807c 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.cpp +++ b/dpctl/tensor/libtensor/source/full_ctor.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/full_ctor.hpp b/dpctl/tensor/libtensor/source/full_ctor.hpp index 4a620a03db..3870573fa4 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.hpp +++ b/dpctl/tensor/libtensor/source/full_ctor.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp index 1039820014..80b12314e3 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp index c6d5ed74b8..438bf613e2 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/linear_sequences.cpp b/dpctl/tensor/libtensor/source/linear_sequences.cpp index 9b17581b8e..306add5f54 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.cpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/linear_sequences.hpp b/dpctl/tensor/libtensor/source/linear_sequences.hpp index b463fdf533..8da56ecd10 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.hpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp b/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp index 4eecef2d3f..2fb2d6078e 100644 --- a/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp +++ b/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp b/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp index 356afca08d..1bd8ff5aa0 100644 --- a/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp +++ b/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp index bf72c73200..3a72c205fb 100644 --- a/dpctl/tensor/libtensor/source/tensor_py.cpp +++ b/dpctl/tensor/libtensor/source/tensor_py.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/dpctl/tensor/libtensor/source/triul_ctor.cpp b/dpctl/tensor/libtensor/source/triul_ctor.cpp index b9cf4543f9..b40b50d030 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.cpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.cpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tensor/libtensor/source/triul_ctor.hpp b/dpctl/tensor/libtensor/source/triul_ctor.hpp index 3789df80c5..2f277bb416 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.hpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.hpp @@ -2,7 +2,7 @@ // // Data Parallel Control (dpctl) // -// Copyright 2020-2022 Intel Corporation +// Copyright 2020-2023 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/dpctl/tests/elementwise/__init__.py b/dpctl/tests/elementwise/__init__.py index e69de29bb2..ac810ba127 100644 --- a/dpctl/tests/elementwise/__init__.py +++ b/dpctl/tests/elementwise/__init__.py @@ -0,0 +1,20 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Collection of test and utility files for testing elementwise operations +over :class:`dpctl.tensor.usm_ndarray`. +""" diff --git a/dpctl/tests/elementwise/test_abs.py b/dpctl/tests/elementwise/test_abs.py index d4aefaf76d..ee7fa0cb6c 100644 --- a/dpctl/tests/elementwise/test_abs.py +++ b/dpctl/tests/elementwise/test_abs.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools import numpy as np diff --git a/dpctl/tests/elementwise/test_add.py b/dpctl/tests/elementwise/test_add.py index 2b4ab8e3cb..fa97b1c1c7 100644 --- a/dpctl/tests/elementwise/test_add.py +++ b/dpctl/tests/elementwise/test_add.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import ctypes import numpy as np diff --git a/dpctl/tests/elementwise/test_cos.py b/dpctl/tests/elementwise/test_cos.py index 395fa60435..3bf441a8dc 100644 --- a/dpctl/tests/elementwise/test_cos.py +++ b/dpctl/tests/elementwise/test_cos.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools import numpy as np diff --git a/dpctl/tests/elementwise/test_divide.py b/dpctl/tests/elementwise/test_divide.py index 168803f945..41aac736d7 100644 --- a/dpctl/tests/elementwise/test_divide.py +++ b/dpctl/tests/elementwise/test_divide.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import ctypes import numpy as np diff --git a/dpctl/tests/elementwise/test_equal.py b/dpctl/tests/elementwise/test_equal.py index a6821a3d5a..cdd26a32d0 100644 --- a/dpctl/tests/elementwise/test_equal.py +++ b/dpctl/tests/elementwise/test_equal.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import ctypes import numpy as np diff --git a/dpctl/tests/elementwise/test_isfinite.py b/dpctl/tests/elementwise/test_isfinite.py index 92e585d217..f25005542d 100644 --- a/dpctl/tests/elementwise/test_isfinite.py +++ b/dpctl/tests/elementwise/test_isfinite.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import itertools import numpy as np diff --git a/dpctl/tests/elementwise/test_isinf.py b/dpctl/tests/elementwise/test_isinf.py index 3ce1c74f36..37f9a68a2b 100644 --- a/dpctl/tests/elementwise/test_isinf.py +++ b/dpctl/tests/elementwise/test_isinf.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools import numpy as np diff --git a/dpctl/tests/elementwise/test_isnan.py b/dpctl/tests/elementwise/test_isnan.py index 7545251bf2..5a2ac4e582 100644 --- a/dpctl/tests/elementwise/test_isnan.py +++ b/dpctl/tests/elementwise/test_isnan.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import itertools import numpy as np diff --git a/dpctl/tests/elementwise/test_sqrt.py b/dpctl/tests/elementwise/test_sqrt.py index e957807c5c..ce168a5ccb 100644 --- a/dpctl/tests/elementwise/test_sqrt.py +++ b/dpctl/tests/elementwise/test_sqrt.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools import numpy as np diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index 415b70e1f8..c040713925 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import pytest import dpctl diff --git a/dpctl/tests/elementwise/utils.py b/dpctl/tests/elementwise/utils.py index b4e71f14ad..0d9396dcb4 100644 --- a/dpctl/tests/elementwise/utils.py +++ b/dpctl/tests/elementwise/utils.py @@ -1,3 +1,19 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import dpctl import dpctl.tensor._type_utils as tu