diff --git a/numba_dpex/core/datamodel/models.py b/numba_dpex/core/datamodel/models.py index 2085be7eb8..7cf8a4eaa1 100644 --- a/numba_dpex/core/datamodel/models.py +++ b/numba_dpex/core/datamodel/models.py @@ -22,8 +22,9 @@ ) -def _get_flattened_member_count(ty): - """Return the number of fields in an instance of a given StructModel.""" +def get_flattened_member_count(ty): + """Returns the number of fields in an instance of a given StructModel.""" + flattened_member_count = 0 members = ty._members for member in members: @@ -109,7 +110,7 @@ def flattened_field_count(self): """ Return the number of fields in an instance of a USMArrayDeviceModel. """ - return _get_flattened_member_count(self) + return get_flattened_member_count(self) class USMArrayHostModel(StructModel): @@ -143,7 +144,7 @@ def __init__(self, dmm, fe_type): @property def flattened_field_count(self): """Return the number of fields in an instance of a USMArrayHostModel.""" - return _get_flattened_member_count(self) + return get_flattened_member_count(self) class SyclQueueModel(StructModel): @@ -223,7 +224,7 @@ def __init__(self, dmm, fe_type): @property def flattened_field_count(self): """Return the number of fields in an instance of a RangeModel.""" - return _get_flattened_member_count(self) + return get_flattened_member_count(self) class NdRangeModel(StructModel): @@ -246,7 +247,7 @@ def __init__(self, dmm, fe_type): @property def flattened_field_count(self): """Return the number of fields in an instance of a NdRangeModel.""" - return _get_flattened_member_count(self) + return get_flattened_member_count(self) def _init_data_model_manager() -> datamodel.DataModelManager: diff --git a/numba_dpex/core/types/kernel_api/local_accessor.py b/numba_dpex/core/types/kernel_api/local_accessor.py new file mode 100644 index 0000000000..a34ce73b8f --- /dev/null +++ b/numba_dpex/core/types/kernel_api/local_accessor.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from numba.core import cgutils +from numba.core.types import Type, UniTuple, intp +from numba.extending import NativeValue, unbox +from numba.np import numpy_support + +from numba_dpex.core.types import USMNdArray +from numba_dpex.utils import address_space as AddressSpace + + +class DpctlMDLocalAccessorType(Type): + """numba-dpex internal type to represent a dpctl SyclInterface type + `MDLocalAccessorTy`. + """ + + def __init__(self): + super().__init__(name="DpctlMDLocalAccessor") + + +class LocalAccessorType(USMNdArray): + """numba-dpex internal type to represent a Python object of + :class:`numba_dpex.experimental.kernel_iface.LocalAccessor`. + """ + + def __init__(self, ndim, dtype): + try: + if isinstance(dtype, Type): + parsed_dtype = dtype + else: + parsed_dtype = numpy_support.from_dtype(dtype) + except NotImplementedError as exc: + raise ValueError(f"Unsupported array dtype: {dtype}") from exc + + type_name = ( + f"LocalAccessor(dtype={parsed_dtype}, ndim={ndim}, " + f"address_space={AddressSpace.LOCAL})" + ) + + super().__init__( + ndim=ndim, + layout="C", + dtype=parsed_dtype, + addrspace=AddressSpace.LOCAL, + name=type_name, + ) + + def cast_python_value(self, args): + """The helper function is not overloaded and using it on the + LocalAccessorType throws a NotImplementedError. + """ + raise NotImplementedError + + +@unbox(LocalAccessorType) +def unbox_local_accessor(typ, obj, c): # pylint: disable=unused-argument + """Unboxes a Python LocalAccessor PyObject* into a numba-dpex internal + representation. + + A LocalAccessor object is represented internally in numba-dpex with the + same data model as a numpy.ndarray. It is done as a LocalAccessor object + serves only as a placeholder type when passed to ``call_kernel`` and the + data buffer should never be accessed inside a host-side compiled function + such as ``call_kernel``. + + When a LocalAccessor object is passed as an argument to a kernel function + it uses the USMArrayDeviceModel. Doing so allows numba-dpex to correctly + generate the kernel signature passing in a pointer in the local address + space. + """ + shape = c.pyapi.object_getattr_string(obj, "_shape") + local_accessor = cgutils.create_struct_proxy(typ)(c.context, c.builder) + + ty_unituple = UniTuple(intp, typ.ndim) + ll_shape = c.unbox(ty_unituple, shape) + local_accessor.shape = ll_shape.value + + return NativeValue( + c.builder.load(local_accessor._getpointer()), + is_error=ll_shape.is_error, + cleanup=ll_shape.cleanup, + ) diff --git a/numba_dpex/core/utils/kernel_flattened_args_builder.py b/numba_dpex/core/utils/kernel_flattened_args_builder.py index 1df93722db..c00ca17d44 100644 --- a/numba_dpex/core/utils/kernel_flattened_args_builder.py +++ b/numba_dpex/core/utils/kernel_flattened_args_builder.py @@ -7,14 +7,21 @@ object. """ +from functools import reduce +from math import ceil from typing import NamedTuple +import dpctl from llvmlite import ir as llvmir -from numba.core import types +from numba.core import cgutils, types from numba.core.cpu import CPUContext from numba_dpex import utils from numba_dpex.core.types import USMNdArray +from numba_dpex.core.types.kernel_api.local_accessor import ( + DpctlMDLocalAccessorType, + LocalAccessorType, +) from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum @@ -70,8 +77,14 @@ def add_argument( arg_type, arg_packed_llvm_val, ): - """Add kernel argument that need to be flatten.""" - if isinstance(arg_type, USMNdArray): + """Add flattened representation of a kernel argument.""" + if isinstance(arg_type, LocalAccessorType): + self._kernel_arg_list.extend( + self._build_local_accessor_arg( + arg_type, llvm_val=arg_packed_llvm_val + ) + ) + elif isinstance(arg_type, USMNdArray): self._kernel_arg_list.extend( self._build_array_arg( arg_type, llvm_array_val=arg_packed_llvm_val @@ -213,6 +226,121 @@ def _store_val_into_struct(self, struct_ref, index, val): ), ) + def _build_local_accessor_metadata_arg( + self, llvm_val, arg_type: LocalAccessorType, data_attr_ty + ): + """Handles the special case of building the kernel argument for the data + attribute of a kernel_api.LocalAccessor object. + + A kernel_api.LocalAccessor conceptually represents a device-only memory + allocation. The mock kernel_api.LocalAccessor uses a numpy.ndarray to + represent the data allocation. The numpy.ndarray cannot be passed to the + kernel and is ignored when building the kernel argument. Instead, a + struct is allocated to store the metadata about the size of the device + memory allocation and a reference to the struct is passed to the + DPCTLQueue_Submit call. The DPCTLQueue_Submit then constructs a + sycl::local_accessor object using the metadata and passes the + sycl::local_accessor as the kernel argument, letting the DPC++ runtime + handle proper device memory allocation. + """ + + ndim = arg_type.ndim + + md_proxy = cgutils.create_struct_proxy(DpctlMDLocalAccessorType())( + self._context, + self._builder, + ) + la_proxy = cgutils.create_struct_proxy(arg_type)( + self._context, self._builder, value=self._builder.load(llvm_val) + ) + + md_proxy.ndim = self._context.get_constant(types.int64, ndim) + md_proxy.dpctl_type_id = numba_type_to_dpctl_typenum( + self._context, data_attr_ty.dtype + ) + for i, val in enumerate( + cgutils.unpack_tuple(self._builder, la_proxy.shape) + ): + setattr(md_proxy, f"dim{i}", val) + + return self._build_arg( + llvm_val=md_proxy._getpointer(), + numba_type=LocalAccessorType( + ndim, dpctl.tensor.dtype(data_attr_ty.dtype.name) + ), + ) + + def _build_local_accessor_arg(self, arg_type: LocalAccessorType, llvm_val): + """Creates a list of kernel LLVM Values for an unpacked USMNdArray + kernel argument from the local accessor. + + Method generates UsmNdArray fields from local accessor type and value. + """ + # TODO: move extra values build on device side of codegen. + ndim = arg_type.ndim + la_proxy = cgutils.create_struct_proxy(arg_type)( + self._context, self._builder, value=self._builder.load(llvm_val) + ) + shape = cgutils.unpack_tuple(self._builder, la_proxy.shape) + ll_size = reduce(self._builder.mul, shape) + + size_ptr = cgutils.alloca_once_value(self._builder, ll_size) + itemsize = self._context.get_constant( + types.intp, ceil(arg_type.dtype.bitwidth / types.byte.bitwidth) + ) + itemsize_ptr = cgutils.alloca_once_value(self._builder, itemsize) + + kernel_arg_list = [] + + kernel_dm = self._kernel_dmm.lookup(arg_type) + + kernel_arg_list.extend( + self._build_arg( + llvm_val=size_ptr, + numba_type=kernel_dm.get_member_fe_type("nitems"), + ) + ) + + # Argument itemsize + kernel_arg_list.extend( + self._build_arg( + llvm_val=itemsize_ptr, + numba_type=kernel_dm.get_member_fe_type("itemsize"), + ) + ) + + # Argument data + data_attr_ty = kernel_dm.get_member_fe_type("data") + + kernel_arg_list.extend( + self._build_local_accessor_metadata_arg( + llvm_val=llvm_val, + arg_type=arg_type, + data_attr_ty=data_attr_ty, + ) + ) + + # Arguments for shape + for val in shape: + shape_ptr = cgutils.alloca_once_value(self._builder, val) + kernel_arg_list.extend( + self._build_arg( + llvm_val=shape_ptr, + numba_type=types.int64, + ) + ) + + # Arguments for strides + for i in range(ndim): + kernel_arg_list.extend( + self._build_arg( + llvm_val=itemsize_ptr, + numba_type=types.int64, + ) + ) + + return kernel_arg_list + def _build_array_arg(self, arg_type, llvm_array_val): """Creates a list of LLVM Values for an unpacked USMNdArray kernel argument. @@ -240,6 +368,7 @@ def _build_array_arg(self, arg_type, llvm_array_val): # Argument data data_attr_pos = host_data_model.get_field_position("data") data_attr_ty = kernel_data_model.get_member_fe_type("data") + kernel_arg_list.extend( self._build_collections_attr_arg( llvm_val=llvm_array_val, diff --git a/numba_dpex/core/utils/kernel_launcher.py b/numba_dpex/core/utils/kernel_launcher.py index d0f9426d88..f20a58efc6 100644 --- a/numba_dpex/core/utils/kernel_launcher.py +++ b/numba_dpex/core/utils/kernel_launcher.py @@ -21,6 +21,7 @@ from numba_dpex.core.exceptions import UnreachableError from numba_dpex.core.runtime.context import DpexRTContext from numba_dpex.core.types import USMNdArray +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType from numba_dpex.core.types.kernel_api.ranges import NdRangeType, RangeType from numba_dpex.core.utils.kernel_flattened_args_builder import ( KernelFlattenedArgsBuilder, @@ -675,7 +676,9 @@ def get_queue_from_llvm_values( the queue from the first USMNdArray argument can be extracted. """ for arg_num, argty in enumerate(ty_kernel_args): - if isinstance(argty, USMNdArray): + if isinstance(argty, USMNdArray) and not isinstance( + argty, LocalAccessorType + ): llvm_val = ll_kernel_args[arg_num] datamodel = ctx.data_model_manager.lookup(argty) sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue") diff --git a/numba_dpex/dpctl_iface/_helpers.py b/numba_dpex/dpctl_iface/_helpers.py index f46915eaf0..cd72014d84 100644 --- a/numba_dpex/dpctl_iface/_helpers.py +++ b/numba_dpex/dpctl_iface/_helpers.py @@ -5,6 +5,7 @@ from numba.core import types from numba_dpex import dpctl_sem_version +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType def numba_type_to_dpctl_typenum(context, ty): @@ -34,6 +35,10 @@ def numba_type_to_dpctl_typenum(context, ty): return context.get_constant( types.int32, kargty.dpctl_void_ptr.value ) + elif isinstance(ty, LocalAccessorType): + return context.get_constant( + types.int32, kargty.dpctl_local_accessor.value + ) else: raise NotImplementedError else: @@ -61,5 +66,9 @@ def numba_type_to_dpctl_typenum(context, ty): elif ty == types.voidptr or isinstance(ty, types.CPointer): # DPCTL_VOID_PTR return context.get_constant(types.int32, 15) + elif isinstance(ty, LocalAccessorType): + raise NotImplementedError( + "LocalAccessor args for kernels requires dpctl 0.17 or greater." + ) else: raise NotImplementedError diff --git a/numba_dpex/experimental/__init__.py b/numba_dpex/experimental/__init__.py index f764e40843..97134ee042 100644 --- a/numba_dpex/experimental/__init__.py +++ b/numba_dpex/experimental/__init__.py @@ -12,6 +12,7 @@ from numba_dpex.core.boxing import * from numba_dpex.kernel_api_impl.spirv.dispatcher import SPIRVKernelDispatcher +from . import typeof from ._kernel_dpcpp_spirv_overloads import ( _atomic_fence_overloads, _atomic_ref_overloads, diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py index 82809a4c9e..44827835e5 100644 --- a/numba_dpex/experimental/launcher.py +++ b/numba_dpex/experimental/launcher.py @@ -25,6 +25,7 @@ ItemType, NdItemType, ) +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType from numba_dpex.core.utils import kernel_launcher as kl from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl from numba_dpex.dpctl_iface.wrappers import wrap_event_reference @@ -42,6 +43,23 @@ class _LLRange(NamedTuple): local_range_extents: list +def _has_a_local_accessor_argument(args): + """Checks if there exists at least one LocalAccessorType object in the + input tuple. + + Args: + args (_type_): A tuple of numba.core.Type objects + + Returns: + bool : True if at least one LocalAccessorType object was found, + otherwise False. + """ + for arg in args: + if isinstance(arg, LocalAccessorType): + return True + return False + + def _wrap_event_reference_tuple(ctx, builder, event1, event2): """Creates tuple data model from two event data models, so it can be boxed to Python.""" @@ -153,6 +171,18 @@ def _submit_kernel( # pylint: disable=too-many-arguments DeprecationWarning, ) + # Validate local accessor arguments are passed only to a kernel that is + # launched with an NdRange index space. Reference section 4.7.6.11. of the + # SYCL 2020 specification: A local_accessor must not be used in a SYCL + # kernel function that is invoked via single_task or via the simple form of + # parallel_for that takes a range parameter. + if _has_a_local_accessor_argument(ty_kernel_args_tuple) and isinstance( + ty_index_space, RangeType + ): + raise TypeError( + "A RangeType kernel cannot have a LocalAccessor argument" + ) + # ty_kernel_fn is type specific to exact function, so we can get function # directly from type and compile it. Thats why we don't need to get it in # codegen diff --git a/numba_dpex/experimental/models.py b/numba_dpex/experimental/models.py index f9e7a2f53d..b0c92e3083 100644 --- a/numba_dpex/experimental/models.py +++ b/numba_dpex/experimental/models.py @@ -12,6 +12,7 @@ from numba.core.extending import register_model import numba_dpex.core.datamodel.models as dpex_core_models +from numba_dpex.core.datamodel.models import USMArrayDeviceModel from numba_dpex.core.types.kernel_api.index_space_ids import ( GroupType, ItemType, @@ -19,6 +20,10 @@ ) from ..core.types.kernel_api.atomic_ref import AtomicRefType +from ..core.types.kernel_api.local_accessor import ( + DpctlMDLocalAccessorType, + LocalAccessorType, +) from .types import KernelDispatcherType @@ -44,6 +49,37 @@ def __init__(self, dmm, fe_type): super().__init__(dmm, fe_type, members) +class DpctlMDLocalAccessorModel(StructModel): + """Data model to represent DpctlMDLocalAccessorType. + + Must be the same structure as + dpctl/syclinterface/dpctl_sycl_queue_interface.h::MDLocalAccessor. + + Structure intended to be used only on host side of the kernel call. + """ + + def __init__(self, dmm, fe_type): + members = [ + ("ndim", types.size_t), + ("dpctl_type_id", types.int32), + ("dim0", types.size_t), + ("dim1", types.size_t), + ("dim2", types.size_t), + ] + super().__init__(dmm, fe_type, members) + + +class LocalAccessorModel(StructModel): + """Data model for the LocalAccessor type when used in a host-only function.""" + + def __init__(self, dmm, fe_type): + ndim = fe_type.ndim + members = [ + ("shape", types.UniTuple(types.intp, ndim)), + ] + super().__init__(dmm, fe_type, members) + + def _init_exp_data_model_manager() -> DataModelManager: """Initializes a DpexExpKernelTarget-specific data model manager. @@ -60,6 +96,9 @@ def _init_exp_data_model_manager() -> DataModelManager: # Register the types and data model in the DpexExpTargetContext dmm.register(AtomicRefType, AtomicRefModel) + # Register the LocalAccessorType type + dmm.register(LocalAccessorType, USMArrayDeviceModel) + # Register the GroupType type dmm.register(GroupType, EmptyStructModel) @@ -85,3 +124,9 @@ def _init_exp_data_model_manager() -> DataModelManager: # Register the NdItemType type register_model(NdItemType)(EmptyStructModel) + +# Register the MDLocalAccessorType type +register_model(DpctlMDLocalAccessorType)(DpctlMDLocalAccessorModel) + +# Register the LocalAccessorType type +register_model(LocalAccessorType)(LocalAccessorModel) diff --git a/numba_dpex/experimental/typeof.py b/numba_dpex/experimental/typeof.py index e72c951a0f..745a861ce9 100644 --- a/numba_dpex/experimental/typeof.py +++ b/numba_dpex/experimental/typeof.py @@ -14,7 +14,8 @@ ItemType, NdItemType, ) -from numba_dpex.kernel_api import AtomicRef, Group, Item, NdItem +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType +from numba_dpex.kernel_api import AtomicRef, Group, Item, LocalAccessor, NdItem from ..core.types.kernel_api.atomic_ref import AtomicRefType @@ -84,3 +85,16 @@ def typeof_nditem(val: NdItem, c): instance. """ return NdItemType(val.dimensions) + + +@typeof_impl.register(LocalAccessor) +def typeof_local_accessor(val: LocalAccessor, c) -> LocalAccessorType: + """Returns a ``numba_dpex.experimental.dpctpp_types.LocalAccessorType`` + instance for a Python LocalAccessor object. + Args: + val (LocalAccessor): Instance of the LocalAccessor type. + c : Numba typing context used for type inference. + Returns: LocalAccessorType object corresponding to the LocalAccessor object. + """ + # pylint: disable=protected-access + return LocalAccessorType(ndim=len(val._shape), dtype=val._dtype) diff --git a/numba_dpex/kernel_api/__init__.py b/numba_dpex/kernel_api/__init__.py index 4ff9ec742a..a6da7b009c 100644 --- a/numba_dpex/kernel_api/__init__.py +++ b/numba_dpex/kernel_api/__init__.py @@ -14,21 +14,25 @@ from .barrier import group_barrier from .index_space_ids import Group, Item, NdItem from .launcher import call_kernel +from .local_accessor import LocalAccessor from .memory_enums import AddressSpace, MemoryOrder, MemoryScope from .private_array import PrivateArray from .ranges import NdRange, Range __all__ = [ + "call_kernel", + "group_barrier", "AddressSpace", "atomic_fence", "AtomicRef", + "Group", + "Item", + "LocalAccessor", "MemoryOrder", "MemoryScope", + "NdItem", "NdRange", "Range", - "Group", - "NdItem", - "Item", "PrivateArray", "group_barrier", "call_kernel", diff --git a/numba_dpex/kernel_api/launcher.py b/numba_dpex/kernel_api/launcher.py index 98a293e52b..6c746a46bb 100644 --- a/numba_dpex/kernel_api/launcher.py +++ b/numba_dpex/kernel_api/launcher.py @@ -9,6 +9,7 @@ from itertools import product from .index_space_ids import Group, Item, NdItem +from .local_accessor import LocalAccessor, _LocalAccessorMock from .ranges import NdRange, Range @@ -33,6 +34,12 @@ def _range_kernel_launcher(kernel_fn, index_range, *kernel_args): range_sets = [range(ir) for ir in index_range] index_tuples = list(product(*range_sets)) + for karg in kernel_args: + if isinstance(karg, LocalAccessor): + raise TypeError( + "LocalAccessor arguments are only supported for NdRange kernels" + ) + for idx in index_tuples: it = Item(extent=index_range, index=idx) @@ -66,6 +73,12 @@ def _ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args): local_index_tuples = list(product(*local_range_sets)) group_index_tuples = list(product(*group_range_sets)) + modified_kernel_args = [] + for karg in kernel_args: + if isinstance(karg, LocalAccessor): + karg = _LocalAccessorMock(karg) + modified_kernel_args.append(karg) + # Loop over the groups (parallel loop) for gidx in group_index_tuples: # loop over work items in the group (parallel loop) @@ -76,27 +89,27 @@ def _ndrange_kernel_launcher(kernel_fn, index_range, *kernel_args): global_id.append( gidx_val * index_range.local_range[dim] + lidx[dim] ) - # Every NdItem has its own global Item, local Item and Group - nditem = NdItem( - global_item=Item( - extent=index_range.global_range, index=global_id - ), - local_item=Item(extent=index_range.local_range, index=lidx), - group=Group( - index_range.global_range, - index_range.local_range, - group_range, - gidx, - ), - ) - if len(signature(kernel_fn).parameters) - len(kernel_args) != 1: raise ValueError( "Required number of kernel function arguments do not " "match provided number of kernel args" ) - kernel_fn(nditem, *kernel_args) + kernel_fn( + NdItem( + global_item=Item( + extent=index_range.global_range, index=global_id + ), + local_item=Item(extent=index_range.local_range, index=lidx), + group=Group( + index_range.global_range, + index_range.local_range, + group_range, + gidx, + ), + ), + *modified_kernel_args + ) def call_kernel(kernel_fn, index_range, *kernel_args): diff --git a/numba_dpex/kernel_api/local_accessor.py b/numba_dpex/kernel_api/local_accessor.py new file mode 100644 index 0000000000..220ef884d7 --- /dev/null +++ b/numba_dpex/kernel_api/local_accessor.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Implements a Python analogue to SYCL's local_accessor class. The class is +intended to be used in pure Python code when prototyping a kernel function +and to be passed to an actual kernel function for local memory allocation. +""" +import numpy + + +class LocalAccessor: + """ + The ``LocalAccessor`` class is analogous to SYCL's ``local_accessor`` + class. The class acts a s proxy to allocating device local memory and + accessing that memory from within a :func:`numba_dpex.kernel` decorated + function. + """ + + def _verify_positive_integral_list(self, ls): + """Checks if all members of a list are positive integers.""" + + ret = False + try: + ret = all(int(val) > 0 for val in ls) + except ValueError: + pass + + return ret + + def __init__(self, shape, dtype) -> None: + """Creates a new LocalAccessor instance of the given shape and dtype.""" + + if not isinstance(shape, (list, tuple)): + if hasattr(shape, "tolist"): + fn = getattr(shape, "tolist") + if callable(fn): + self._shape = tuple(shape.tolist()) + else: + try: + self._shape = (shape,) + except Exception as e: + raise TypeError( + "Argument shape must a non-negative integer, " + "or a list/tuple of such integers." + ) from e + else: + self._shape = tuple(shape) + + # Make sure shape is made up a supported types + if not self._verify_positive_integral_list(self._shape): + raise TypeError( + "Argument shape must a non-negative integer, " + "or a list/tuple of such integers." + ) + + # Make sure shape has a rank between (1..3) + if len(self._shape) < 1 or len(self._shape) > 3: + raise TypeError("LocalAccessor can only have up to 3 dimensions.") + + self._dtype = dtype + + if self._dtype not in [ + numpy.float32, + numpy.float64, + numpy.int32, + numpy.int64, + numpy.int16, + numpy.int8, + numpy.uint32, + numpy.uint64, + numpy.uint16, + numpy.uint8, + ]: + raise TypeError( + f"Argument dtype {dtype} is not supported. numpy.float32, " + "numpy.float64, numpy.[u]int8, numpy.[u]int16, numpy.[u]int32, " + "numpy.[u]int64 are the currently supported dtypes." + ) + + self._data = numpy.empty(self._shape, dtype=self._dtype) + + def __getitem__(self, idx_obj): + """Returns the value stored at the position represented by idx_obj in + the self._data ndarray. + """ + + raise NotImplementedError( + "The data of a LocalAccessor object can only be accessed " + "inside a kernel." + ) + + def __setitem__(self, idx_obj, val): + """Assigns a new value to the position represented by idx_obj in + the self._data ndarray. + """ + + raise NotImplementedError( + "The data of a LocalAccessor object can only be accessed " + "inside a kernel." + ) + + +class _LocalAccessorMock: + """Mock class that is used to represent a local accessor inside a "kernel". + + A LocalAccessor represents a device-only memory allocation and the + class is designed in a way to not have any data container backing up the + actual memory storage. Instead, the _LocalAccessorMock class is used to + represent a local_accessor that has an actual numpy ndarray backing it up. + Whenever, a LocalAccessor object is passed to `func`:kernel_api.call_kernel` + it is converted to a _LocalAccessor internally. That way the data and + access function on the data only works inside a kernel to simulate + device-only memory allocation and outside the kernel the data for a + LocalAccessor is not accessible. + """ + + def __init__(self, local_accessor: LocalAccessor): + self._data = numpy.empty( + local_accessor._shape, dtype=local_accessor._dtype + ) + + def __getitem__(self, idx_obj): + """Returns the value stored at the position represented by idx_obj in + the self._data ndarray. + """ + + return self._data[idx_obj] + + def __setitem__(self, idx_obj, val): + """Assigns a new value to the position represented by idx_obj in + the self._data ndarray. + """ + + self._data[idx_obj] = val diff --git a/numba_dpex/tests/experimental/codegen/test_local_accessor_kernel_arg.py b/numba_dpex/tests/experimental/codegen/test_local_accessor_kernel_arg.py new file mode 100644 index 0000000000..26905bf19b --- /dev/null +++ b/numba_dpex/tests/experimental/codegen/test_local_accessor_kernel_arg.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpctl +from llvmlite import ir as llvmir +from numba.core import types + +from numba_dpex import DpctlSyclQueue, DpnpNdArray +from numba_dpex import experimental as dpex_exp +from numba_dpex import int64 +from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType +from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType +from numba_dpex.kernel_api import ( + AddressSpace, + MemoryScope, + NdItem, + group_barrier, +) + + +def kernel_func(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + j = nd_item.get_local_linear_id() + + slm[j] = 100 + group_barrier(nd_item.get_group(), MemoryScope.WORK_GROUP) + + a[i] += slm[j] + + +def test_codegen_local_accessor_kernel_arg(): + """Tests if a kernel with a local accessor argument is generated with + expected local address space pointer argument. + """ + + queue_ty = DpctlSyclQueue(dpctl.SyclQueue()) + i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty) + slm_ty = LocalAccessorType(ndim=1, dtype=int64) + disp = dpex_exp.kernel(inline_threshold=3)(kernel_func) + dmm = disp.targetctx.data_model_manager + + i64arr_ty_flattened_arg_count = dmm.lookup(i64arr_ty).flattened_field_count + slm_ty_model = dmm.lookup(slm_ty) + slm_ty_flattened_arg_count = slm_ty_model.flattened_field_count + slm_ptr_pos = slm_ty_model.get_field_position("data") + + llargtys = disp.targetctx.get_arg_packer([i64arr_ty, slm_ty]).argument_types + + # Go over all the arguments to the spir_kernel_func and assert two things: + # a) Number of arguments == i64arr_ty_flattened_arg_count + # + slm_ty_flattened_arg_count + # b) The argument corresponding to the data attribute of the local accessor + # argument is a pointer in address space local address space + + num_kernel_args = 0 + slm_data_ptr_arg = None + for kernel_arg in llargtys: + if num_kernel_args == i64arr_ty_flattened_arg_count + slm_ptr_pos: + slm_data_ptr_arg = kernel_arg + num_kernel_args += 1 + assert ( + num_kernel_args + == i64arr_ty_flattened_arg_count + slm_ty_flattened_arg_count + ) + assert isinstance(slm_data_ptr_arg, llvmir.PointerType) + assert slm_data_ptr_arg.addrspace == AddressSpace.LOCAL diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py new file mode 100644 index 0000000000..d8ae378908 --- /dev/null +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_local_accessors.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + + +import dpnp +import pytest +from numba.core.errors import TypingError + +import numba_dpex as dpex +import numba_dpex.experimental as dpex_exp +from numba_dpex.kernel_api import LocalAccessor, NdItem +from numba_dpex.kernel_api import call_kernel as kapi_call_kernel +from numba_dpex.tests._helper import get_all_dtypes + +list_of_supported_dtypes = get_all_dtypes( + no_bool=True, no_float16=True, no_none=True, no_complex=True +) + + +def _kernel1(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + + # TODO: overload nd_item.get_local_id() + j = (nd_item.get_local_id(0),) + + slm[j] = 0 + + for m in range(100): + slm[j] += i * m + + a[i] = slm[j] + + +def _kernel2(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + + # TODO: overload nd_item.get_local_id() + j = (nd_item.get_local_id(0), nd_item.get_local_id(1)) + + slm[j] = 0 + + for m in range(100): + slm[j] += i * m + + a[i] = slm[j] + + +def _kernel3(nd_item: NdItem, a, slm): + i = nd_item.get_global_linear_id() + + # TODO: overload nd_item.get_local_id() + j = ( + nd_item.get_local_id(0), + nd_item.get_local_id(1), + nd_item.get_local_id(2), + ) + + slm[j] = 0 + + for m in range(100): + slm[j] += i * m + + a[i] = slm[j] + + +def device_func_kernel(func): + _df = dpex_exp.device_func(func) + + @dpex_exp.kernel + def _kernel(item, a, slm): + _df(item, a, slm) + + return _kernel + + +@pytest.mark.parametrize("supported_dtype", list_of_supported_dtypes) +@pytest.mark.parametrize( + "nd_range, _kernel", + [ + (dpex.NdRange((32,), (32,)), _kernel1), + (dpex.NdRange((32, 1), (32, 1)), _kernel2), + (dpex.NdRange((1, 32, 1), (1, 32, 1)), _kernel3), + ], +) +@pytest.mark.parametrize( + "call_kernel, kernel", + [ + (dpex_exp.call_kernel, dpex_exp.kernel), + (dpex_exp.call_kernel, device_func_kernel), + (kapi_call_kernel, lambda f: f), + ], +) +def test_local_accessor( + supported_dtype, nd_range: dpex.NdRange, _kernel, call_kernel, kernel +): + """A test for passing a LocalAccessor object as a kernel argument.""" + + N = 32 + a = dpnp.empty(N, dtype=supported_dtype) + slm = LocalAccessor(nd_range.local_range, dtype=a.dtype) + + # A single work group with 32 work items is launched. Each work item + # computes the sum of (0..99) * its get_global_linear_id i.e., + # `4950 * get_global_linear_id` and stores it into the work groups local + # memory. The local memory is of size 32*64 elements of the requested dtype. + # The result is then stored into `a` in global memory + call_kernel(kernel(_kernel), nd_range, a, slm) + + for idx in range(N): + assert a[idx] == 4950 * idx + + +def test_local_accessor_argument_to_range_kernel(): + """Checks if an exception is raised when passing a local accessor to a + RangeType kernel. + """ + N = 32 + a = dpnp.empty(N) + slm = LocalAccessor((32 * 64), dtype=a.dtype) + + # Passing a local_accessor to a RangeType kernel should raise an exception. + # A TypeError is raised if NUMBA_CAPTURED_ERROR=new_style and a + # numba.TypingError is raised if NUMBA_CAPTURED_ERROR=old_style + with pytest.raises((TypeError, TypingError)): + dpex_exp.call_kernel(_kernel1, dpex.Range(N), a, slm) diff --git a/numba_dpex/tests/kernel_api/test_local_accessor.py b/numba_dpex/tests/kernel_api/test_local_accessor.py new file mode 100644 index 0000000000..0f7a151f9c --- /dev/null +++ b/numba_dpex/tests/kernel_api/test_local_accessor.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: 2024 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import numpy +import pytest + +from numba_dpex import kernel_api as kapi + + +def _slm_kernel(nd_item: kapi.NdItem, a, slm): + i = nd_item.get_global_linear_id() + j = nd_item.get_local_linear_id() + + slm[j] = 100 + a[i] = slm[i] + + +def test_local_accessor_data_inaccessible_outside_kernel(): + la = kapi.LocalAccessor((100,), dtype=numpy.float32) + + with pytest.raises(NotImplementedError): + print(la[0]) + + with pytest.raises(NotImplementedError): + la[0] = 10 + + +def test_local_accessor_use_inside_kernel(): + + a = numpy.empty(32) + slm = kapi.LocalAccessor(32, dtype=a.dtype) + + # launches one work group with 32 work item. Each work item initializes its + # position in the SLM to 100 and then writes it to the global array `a`. + kapi.call_kernel(_slm_kernel, kapi.NdRange((32,), (32,)), a, slm) + + assert numpy.all(a == 100) + + +def test_local_accessor_usage_not_allowed_with_range_kernel(): + + a = numpy.empty(32) + slm = kapi.LocalAccessor(32, dtype=a.dtype) + + with pytest.raises(TypeError): + kapi.call_kernel(_slm_kernel, kapi.Range((32,)), a, slm)