diff --git a/README.md b/README.md index 3641849417..ce5429a967 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ https://intelpython.github.io/dpnp/ ## Dependencies * numba >=0.51 (IntelPython/numba) -* dpCtl 0.5.* +* dpCtl >=0.5.1 * dpNP 0.4.* (optional) * llvm-spirv (SPIRV generation from LLVM IR) * llvmdev (LLVM IR generation) diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 5e5b61a25c..01e0f8b406 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -18,12 +18,12 @@ requirements: - setuptools - cython - numba - - dpctl 0.5.* + - dpctl >=0.5.1rc1 - dpnp 0.4.* # [linux] run: - python - numba >=0.51 - - dpctl 0.5.* + - dpctl >=0.5.1rc1 - spirv-tools - llvm-spirv - llvmdev diff --git a/numba_dppy/dppy_rt.c b/numba_dppy/dppy_rt.c new file mode 100644 index 0000000000..6589a369df --- /dev/null +++ b/numba_dppy/dppy_rt.c @@ -0,0 +1,171 @@ +#include "numba/_pymodule.h" +#include "numba/core/runtime/nrt_external.h" +#include "assert.h" +#include +#if !defined _WIN32 + #include +#else + #include +#endif + +NRT_ExternalAllocator usmarray_allocator; +NRT_external_malloc_func internal_allocator = NULL; +NRT_external_free_func internal_free = NULL; +void *(*get_queue_internal)(void) = NULL; +void (*free_queue_internal)(void*) = NULL; + +void * save_queue_allocator(size_t size, void *opaque) { + // Allocate a pointer-size more space than neded. + int new_size = size + sizeof(void*); + // Get the current queue + void *cur_queue = get_queue_internal(); // this makes a copy + // Use that queue to allocate. + void *data = internal_allocator(new_size, cur_queue); + // Set first pointer-sized data in allocated space to be the current queue. + *(void**)data = cur_queue; + // Return the pointer after this queue in memory. + return (char*)data + sizeof(void*); +} + +void save_queue_deallocator(void *data, void *opaque) { + // Compute original allocation location by subtracting the length + // of the queue pointer from the data location that Numba thinks + // starts the object. + void *orig_data = (char*)data - sizeof(void*); + // Get the queue from the original data by derefencing the first qword. + void *obj_queue = *(void**)orig_data; + // Free the space using the correct queue. + internal_free(orig_data, obj_queue); + // Free the queue itself. + free_queue_internal(obj_queue); +} + +void usmarray_memsys_init(void) { + #if !defined _WIN32 + char *lib_name = "libDPCTLSyclInterface.so"; + char *malloc_name = "DPCTLmalloc_shared"; + char *free_name = "DPCTLfree_with_queue"; + char *get_queue_name = "DPCTLQueueMgr_GetCurrentQueue"; + char *free_queue_name = "DPCTLQueue_Delete"; + + void *sycldl = dlopen(lib_name, RTLD_NOW); + assert(sycldl != NULL); + internal_allocator = (NRT_external_malloc_func)dlsym(sycldl, malloc_name); + usmarray_allocator.malloc = save_queue_allocator; + if (internal_allocator == NULL) { + printf("Did not find %s in %s\n", malloc_name, lib_name); + exit(-1); + } + + usmarray_allocator.realloc = NULL; + + internal_free = (NRT_external_free_func)dlsym(sycldl, free_name); + usmarray_allocator.free = save_queue_deallocator; + if (internal_free == NULL) { + printf("Did not find %s in %s\n", free_name, lib_name); + exit(-1); + } + + get_queue_internal = (void *(*)(void))dlsym(sycldl, get_queue_name); + if (get_queue_internal == NULL) { + printf("Did not find %s in %s\n", get_queue_name, lib_name); + exit(-1); + } + usmarray_allocator.opaque_data = NULL; + + free_queue_internal = (void (*)(void*))dlsym(sycldl, free_queue_name); + if (free_queue_internal == NULL) { + printf("Did not find %s in %s\n", free_queue_name, lib_name); + exit(-1); + } + #else + char *lib_name = "DPCTLSyclInterface.dll"; + char *malloc_name = "DPCTLmalloc_shared"; + char *free_name = "DPCTLfree_with_queue"; + char *get_queue_name = "DPCTLQueueMgr_GetCurrentQueue"; + char *free_queue_name = "DPCTLQueue_Delete"; + + HMODULE sycldl = LoadLibrary(lib_name); + assert(sycldl != NULL); + internal_allocator = (NRT_external_malloc_func)GetProcAddress(sycldl, malloc_name); + usmarray_allocator.malloc = save_queue_allocator; + if (internal_allocator == NULL) { + printf("Did not find %s in %s\n", malloc_name, lib_name); + exit(-1); + } + + usmarray_allocator.realloc = NULL; + + internal_free = (NRT_external_free_func)GetProcAddress(sycldl, free_name); + usmarray_allocator.free = save_queue_deallocator; + if (internal_free == NULL) { + printf("Did not find %s in %s\n", free_name, lib_name); + exit(-1); + } + + get_queue_internal = (void *(*)(void))GetProcAddress(sycldl, get_queue_name); + if (get_queue_internal == NULL) { + printf("Did not find %s in %s\n", get_queue_name, lib_name); + exit(-1); + } + usmarray_allocator.opaque_data = NULL; + + free_queue_internal = (void (*)(void*))GetProcAddress(sycldl, free_queue_name); + if (free_queue_internal == NULL) { + printf("Did not find %s in %s\n", free_queue_name, lib_name); + exit(-1); + } + #endif +} + +void * usmarray_get_ext_allocator(void) { + return (void*)&usmarray_allocator; +} + +static PyObject * +get_external_allocator(PyObject *self, PyObject *args) { + return PyLong_FromVoidPtr(usmarray_get_ext_allocator()); +} + +static PyMethodDef ext_methods[] = { +#define declmethod_noargs(func) { #func , ( PyCFunction )func , METH_NOARGS, NULL } + declmethod_noargs(get_external_allocator), + {NULL}, +#undef declmethod_noargs +}; + +static PyObject * +build_c_helpers_dict(void) +{ + PyObject *dct = PyDict_New(); + if (dct == NULL) + goto error; + +#define _declpointer(name, value) do { \ + PyObject *o = PyLong_FromVoidPtr(value); \ + if (o == NULL) goto error; \ + if (PyDict_SetItemString(dct, name, o)) { \ + Py_DECREF(o); \ + goto error; \ + } \ + Py_DECREF(o); \ +} while (0) + + _declpointer("usmarray_get_ext_allocator", &usmarray_get_ext_allocator); + +#undef _declpointer + return dct; +error: + Py_XDECREF(dct); + return NULL; +} + +MOD_INIT(_dppy_rt) { + PyObject *m; + MOD_DEF(m, "numba_dppy._dppy_rt", "No docs", ext_methods) + if (m == NULL) + return MOD_ERROR_VAL; + usmarray_memsys_init(); + PyModule_AddObject(m, "c_helpers", build_c_helpers_dict()); + return MOD_SUCCESS_VAL(m); +} diff --git a/numba_dppy/numpy_usm_shared.py b/numba_dppy/numpy_usm_shared.py new file mode 100644 index 0000000000..150ab0d3b3 --- /dev/null +++ b/numba_dppy/numpy_usm_shared.py @@ -0,0 +1,728 @@ +import numpy as np +from inspect import getmembers, isfunction, isclass, isbuiltin +from numbers import Number +import numba +from types import FunctionType as ftype, BuiltinFunctionType as bftype +from numba import types +from numba.extending import typeof_impl, register_model, type_callable, lower_builtin +from numba.np import numpy_support +from numba.core.pythonapi import box, allocator +from llvmlite import ir +import llvmlite.llvmpy.core as lc +import llvmlite.binding as llb +from numba.core import types, cgutils, config +import builtins +import sys +from ctypes.util import find_library +from numba.core.typing.templates import builtin_registry as templates_registry +from numba.core.typing.npydecl import registry as typing_registry +from numba.core.imputils import builtin_registry as lower_registry +import importlib +import functools +import inspect +from numba.core.typing.templates import ( + CallableTemplate, + AttributeTemplate, + signature, + bound_function, +) +from numba.core.typing.arraydecl import normalize_shape +from numba.np.arrayobj import _array_copy + +import dpctl.dptensor.numpy_usm_shared as nus +from dpctl.dptensor.numpy_usm_shared import ndarray, functions_list, class_list + + +debug = config.DEBUG + + +def dprint(*args): + if debug: + print(*args) + sys.stdout.flush() + + +# # This code makes it so that Numba can contain calls into the DPPLSyclInterface library. +# sycl_mem_lib = find_library('DPCTLSyclInterface') +# dprint("sycl_mem_lib:", sycl_mem_lib) +# # Load the symbols from the DPPL Sycl library. +# llb.load_library_permanently(sycl_mem_lib) + +import dpctl +from dpctl.memory import MemoryUSMShared +import numba_dppy._dppy_rt + +# Register the helper function in dppl_rt so that we can insert calls to them via llvmlite. +for py_name, c_address in numba_dppy._dppy_rt.c_helpers.items(): + llb.add_symbol(py_name, c_address) + + +class UsmSharedArrayType(types.Array): + """Creates a Numba type for Numpy arrays that are stored in USM shared + memory. We inherit from Numba's existing Numpy array type but overload + how this type is printed during dumping of typing information and we + implement the special __array_ufunc__ function to determine who this + type gets combined with scalars and regular Numpy types. + We re-use Numpy functions as well but those are going to return Numpy + arrays allocated in USM and we use the overloaded copy function to + convert such USM-backed Numpy arrays into typed USM arrays.""" + + def __init__( + self, + dtype, + ndim, + layout, + readonly=False, + name=None, + aligned=True, + addrspace=None, + ): + # This name defines how this type will be shown in Numba's type dumps. + name = "UsmArray:ndarray(%s, %sd, %s)" % (dtype, ndim, layout) + super(UsmSharedArrayType, self).__init__( + dtype, + ndim, + layout, + py_type=ndarray, + readonly=readonly, + name=name, + addrspace=addrspace, + ) + + def copy(self, *args, **kwargs): + retty = super(UsmSharedArrayType, self).copy(*args, **kwargs) + if isinstance(retty, types.Array): + return UsmSharedArrayType( + dtype=retty.dtype, ndim=retty.ndim, layout=retty.layout + ) + else: + return retty + + # Tell Numba typing how to combine UsmSharedArrayType with other ndarray types. + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method == "__call__": + for inp in inputs: + if not isinstance(inp, (UsmSharedArrayType, types.Array, types.Number)): + return None + + return UsmSharedArrayType + else: + return None + + +# This tells Numba how to create a UsmSharedArrayType when a usmarray is passed +# into a njit function. +@typeof_impl.register(ndarray) +def typeof_ta_ndarray(val, c): + try: + dtype = numpy_support.from_dtype(val.dtype) + except NotImplementedError: + raise ValueError("Unsupported array dtype: %s" % (val.dtype,)) + layout = numpy_support.map_layout(val) + readonly = not val.flags.writeable + return UsmSharedArrayType(dtype, val.ndim, layout, readonly=readonly) + + +# This tells Numba to use the default Numpy ndarray data layout for +# object of type UsmArray. +register_model(UsmSharedArrayType)(numba.core.datamodel.models.ArrayModel) + +# This tells Numba how to convert from its native representation +# of a UsmArray in a njit function back to a Python UsmArray. +@box(UsmSharedArrayType) +def box_array(typ, val, c): + nativearycls = c.context.make_array(typ) + nativeary = nativearycls(c.context, c.builder, value=val) + if c.context.enable_nrt: + np_dtype = numpy_support.as_dtype(typ.dtype) + dtypeptr = c.env_manager.read_const(c.env_manager.add_const(np_dtype)) + # Steals NRT ref + newary = c.pyapi.nrt_adapt_ndarray_to_python(typ, val, dtypeptr) + return newary + else: + parent = nativeary.parent + c.pyapi.incref(parent) + return parent + + +# This tells Numba to use this function when it needs to allocate a +# UsmArray in a njit function. +@allocator(UsmSharedArrayType) +def allocator_UsmArray(context, builder, size, align): + context.nrt._require_nrt() + + mod = builder.module + u32 = ir.IntType(32) + + # Get the Numba external allocator for USM memory. + ext_allocator_fnty = ir.FunctionType(cgutils.voidptr_t, []) + ext_allocator_fn = mod.get_or_insert_function( + ext_allocator_fnty, name="usmarray_get_ext_allocator" + ) + ext_allocator = builder.call(ext_allocator_fn, []) + # Get the Numba function to allocate an aligned array with an external allocator. + fnty = ir.FunctionType(cgutils.voidptr_t, [cgutils.intp_t, u32, cgutils.voidptr_t]) + fn = mod.get_or_insert_function( + fnty, name="NRT_MemInfo_alloc_safe_aligned_external" + ) + fn.return_value.add_attribute("noalias") + if isinstance(align, builtins.int): + align = context.get_constant(types.uint32, align) + else: + assert align.type == u32, "align must be a uint32" + return builder.call(fn, [size, align, ext_allocator]) + + +_registered = False + + +def is_usm_callback(obj): + dprint("is_usm_callback:", obj, type(obj)) + if isinstance(obj, numba.core.runtime._nrt_python._MemInfo): + mobj = obj + while isinstance(mobj, numba.core.runtime._nrt_python._MemInfo): + ea = mobj.external_allocator + dppl_rt_allocator = numba_dppy._dppy_rt.get_external_allocator() + dprint("Checking MemInfo:", ea) + if ea == dppl_rt_allocator: + return True + mobj = mobj.parent + if isinstance(mobj, ndarray): + mobj = mobj.base + return False + + +def numba_register(): + global _registered + if not _registered: + _registered = True + ndarray.add_external_usm_checker(is_usm_callback) + numba_register_typing() + numba_register_lower_builtin() + + +# Copy a function registered as a lowerer in Numba but change the +# "np" import in Numba to point to usmarray instead of NumPy. +def copy_func_for_usmarray(f, usmarray_mod): + import copy as cc + + # Make a copy so our change below doesn't affect anything else. + gglobals = cc.copy(f.__globals__) + # Make the "np"'s in the code use usmarray instead of Numba's default NumPy. + gglobals["np"] = usmarray_mod + # Create a new function using the original code but the new globals. + g = ftype(f.__code__, gglobals, None, f.__defaults__, f.__closure__) + # Some other tricks to make sure the function copy works. + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ + return g + + +def types_replace_array(x): + return tuple([z if z != types.Array else UsmSharedArrayType for z in x]) + + +def numba_register_lower_builtin(): + todo = [] + todo_builtin = [] + todo_getattr = [] + todo_array_member_func = [] + + # For all Numpy identifiers that have been registered for typing in Numba... + # this registry contains functions, getattrs, setattrs, casts and constants... + for ig in lower_registry.functions: + impl, func, types = ig + dprint("Numpy lowered registry functions:", impl, func, type(func), types) + # If it is a Numpy function... + if isinstance(func, ftype): + dprint("is ftype") + if func.__module__ == np.__name__: + dprint("is Numpy module") + # If we have overloaded that function in the usmarray module (always True right now)... + if func.__name__ in functions_list: + todo.append(ig) + if isinstance(func, bftype): + dprint("is bftype") + if func.__module__ == np.__name__: + dprint("is Numpy module") + # If we have overloaded that function in the usmarray module (always True right now)... + if func.__name__ in functions_list: + todo.append(ig) + if isinstance(func, str) and func.startswith("array."): + todo_array_member_func.append(ig) + + for lg in lower_registry.getattrs: + func, attr, types = lg + dprint("Numpy lowered registry getattrs:", func, attr, types) + types_with_usmarray = types_replace_array(types) + if UsmSharedArrayType in types_with_usmarray: + dprint( + "lower_getattr:", func, type(func), attr, type(attr), types, type(types) + ) + todo_getattr.append((func, attr, types_with_usmarray)) + + for lg in todo_getattr: + lower_registry.getattrs.append(lg) + + for impl, func, types in todo + todo_builtin: + try: + usmarray_func = eval("dpctl.dptensor.numpy_usm_shared." + func.__name__) + except: + dprint("failed to eval", func.__name__) + continue + dprint( + "need to re-register lowerer for usmarray", impl, func, types, usmarray_func + ) + new_impl = copy_func_for_usmarray(impl, nus) + lower_registry.functions.append((new_impl, usmarray_func, types)) + + for impl, func, types in todo_array_member_func: + types_with_usmarray = types_replace_array(types) + usmarray_func = "usm" + func + dprint("Registering lowerer for", impl, usmarray_func, types_with_usmarray) + new_impl = copy_func_for_usmarray(impl, nus) + lower_registry.functions.append((new_impl, usmarray_func, types_with_usmarray)) + + +def argspec_to_string(argspec): + first_default_arg = len(argspec.args) - len(argspec.defaults) + non_def = argspec.args[:first_default_arg] + arg_zip = list(zip(argspec.args[first_default_arg:], argspec.defaults)) + combined = [a + "=" + str(b) for a, b in arg_zip] + return ",".join(non_def + combined) + + +def numba_register_typing(): + todo = [] + todo_classes = [] + todo_getattr = [] + + # For all Numpy identifiers that have been registered for typing in Numba... + for ig in typing_registry.globals: + val, typ = ig + dprint("Numpy registered:", val, type(val), typ, type(typ)) + # If it is a Numpy function... + if isinstance(val, (ftype, bftype)): + # If we have overloaded that function in the usmarray module (always True right now)... + if val.__name__ in functions_list: + todo.append(ig) + if isinstance(val, type): + if isinstance(typ, numba.core.types.functions.Function): + todo.append(ig) + elif isinstance(typ, numba.core.types.functions.NumberClass): + pass + + for tgetattr in templates_registry.attributes: + dprint("Numpy getattr:", tgetattr, type(tgetattr), tgetattr.key) + if tgetattr.key == types.Array: + todo_getattr.append(tgetattr) + + for val, typ in todo_classes: + dprint("todo_classes:", val, typ, type(typ)) + + try: + dptype = eval("dpctl.dptensor.numpy_usm_shared." + val.__name__) + except: + dprint("failed to eval", val.__name__) + continue + + typing_registry.register_global( + dptype, numba.core.types.NumberClass(typ.instance_type) + ) + + for val, typ in todo: + assert len(typ.templates) == 1 + # template is the typing class to invoke generic() upon. + template = typ.templates[0] + dprint("need to re-register for usmarray", val, typ, typ.typing_key) + try: + dpval = eval("dpctl.dptensor.numpy_usm_shared." + val.__name__) + except: + dprint("failed to eval", val.__name__) + continue + dprint("--------------------------------------------------------------") + dprint("need to re-register for usmarray", val, typ, typ.typing_key) + dprint("val:", val, type(val), "dir val", dir(val)) + dprint("typ:", typ, type(typ), "dir typ", dir(typ)) + dprint("typing key:", typ.typing_key) + dprint("name:", typ.name) + dprint("key:", typ.key) + dprint("templates:", typ.templates) + dprint("template:", template, type(template)) + dprint("dpval:", dpval, type(dpval)) + dprint("--------------------------------------------------------------") + + class_name = "DparrayTemplate_" + val.__name__ + + @classmethod + def set_key_original(cls, key, original): + cls.key = key + cls.original = original + + def generic_impl(self): + original_typer = self.__class__.original.generic(self.__class__.original) + ot_argspec = inspect.getfullargspec(original_typer) + astr = argspec_to_string(ot_argspec) + + typer_func = """def typer({}): + original_res = original_typer({}) + if isinstance(original_res, types.Array): + return UsmSharedArrayType( + dtype=original_res.dtype, + ndim=original_res.ndim, + layout=original_res.layout + ) + return original_res""".format( + astr, ",".join(ot_argspec.args) + ) + + try: + gs = globals() + ls = locals() + gs["original_typer"] = ls["original_typer"] + exec(typer_func, globals(), locals()) + except NameError as ne: + print("NameError in exec:", ne) + sys.exit(0) + except: + print("exec failed!", sys.exc_info()[0]) + sys.exit(0) + + try: + exec_res = eval("typer") + except NameError as ne: + print("NameError in eval:", ne) + sys.exit(0) + except: + print("eval failed!", sys.exc_info()[0]) + sys.exit(0) + + return exec_res + + new_usmarray_template = type( + class_name, + (template,), + {"set_class_vars": set_key_original, "generic": generic_impl}, + ) + + new_usmarray_template.set_class_vars(dpval, template) + + assert callable(dpval) + type_handler = types.Function(new_usmarray_template) + typing_registry.register_global(dpval, type_handler) + + # Handle usmarray attribute typing. + # This explicit register_attr of a copied/modified UsmArrayAttribute + # may be removed in the future in favor of the below commented out code + # once we get this registration code to run after everything is registered + # in Numba. Right now, the attribute registrations we need are happening + # after the registration callback that gets us here so we would miss the + # attribute registrations we need. + typing_registry.register_attr(UsmArrayAttribute) + + +class UsmArrayAttribute(AttributeTemplate): + key = UsmSharedArrayType + + def resolve_dtype(self, ary): + return types.DType(ary.dtype) + + def resolve_itemsize(self, ary): + return types.intp + + def resolve_shape(self, ary): + return types.UniTuple(types.intp, ary.ndim) + + def resolve_strides(self, ary): + return types.UniTuple(types.intp, ary.ndim) + + def resolve_ndim(self, ary): + return types.intp + + def resolve_size(self, ary): + return types.intp + + def resolve_flat(self, ary): + return types.NumpyFlatType(ary) + + def resolve_ctypes(self, ary): + return types.ArrayCTypes(ary) + + def resolve_flags(self, ary): + return types.ArrayFlags(ary) + + def convert_array_to_usmarray(self, retty): + if isinstance(retty, types.Array): + return UsmSharedArrayType( + dtype=retty.dtype, ndim=retty.ndim, layout=retty.layout + ) + else: + return retty + + def resolve_T(self, ary): + if ary.ndim <= 1: + retty = ary + else: + layout = {"C": "F", "F": "C"}.get(ary.layout, "A") + retty = ary.copy(layout=layout) + return retty + + def resolve_real(self, ary): + return self._resolve_real_imag(ary, attr="real") + + def resolve_imag(self, ary): + return self._resolve_real_imag(ary, attr="imag") + + def _resolve_real_imag(self, ary, attr): + if ary.dtype in types.complex_domain: + return ary.copy(dtype=ary.dtype.underlying_float, layout="A") + elif ary.dtype in types.number_domain: + res = ary.copy(dtype=ary.dtype) + if attr == "imag": + res = res.copy(readonly=True) + return res + else: + msg = "cannot access .{} of array of {}" + raise TypingError(msg.format(attr, ary.dtype)) + + @bound_function("usmarray.copy") + def resolve_copy(self, ary, args, kws): + assert not args + assert not kws + retty = ary.copy(layout="C", readonly=False) + return signature(retty) + + @bound_function("usmarray.transpose") + def resolve_transpose(self, ary, args, kws): + def sentry_shape_scalar(ty): + if ty in types.number_domain: + # Guard against non integer type + if not isinstance(ty, types.Integer): + raise TypeError("transpose() arg cannot be {0}".format(ty)) + return True + else: + return False + + assert not kws + if len(args) == 0: + return signature(self.resolve_T(ary)) + + if len(args) == 1: + (shape,) = args + + if sentry_shape_scalar(shape): + assert ary.ndim == 1 + return signature(ary, *args) + + if isinstance(shape, types.NoneType): + return signature(self.resolve_T(ary)) + + shape = normalize_shape(shape) + if shape is None: + return + + assert ary.ndim == shape.count + return signature(self.resolve_T(ary).copy(layout="A"), shape) + + else: + if any(not sentry_shape_scalar(a) for a in args): + raise TypeError( + "transpose({0}) is not supported".format(", ".join(args)) + ) + assert ary.ndim == len(args) + return signature(self.resolve_T(ary).copy(layout="A"), *args) + + @bound_function("usmarray.item") + def resolve_item(self, ary, args, kws): + assert not kws + # We don't support explicit arguments as that's exactly equivalent + # to regular indexing. The no-argument form is interesting to + # allow some degree of genericity when writing functions. + if not args: + return signature(ary.dtype) + + @bound_function("usmarray.itemset") + def resolve_itemset(self, ary, args, kws): + assert not kws + # We don't support explicit arguments as that's exactly equivalent + # to regular indexing. The no-argument form is interesting to + # allow some degree of genericity when writing functions. + if len(args) == 1: + return signature(types.none, ary.dtype) + + @bound_function("usmarray.nonzero") + def resolve_nonzero(self, ary, args, kws): + assert not args + assert not kws + # 0-dim arrays return one result array + ndim = max(ary.ndim, 1) + retty = types.UniTuple(UsmSharedArrayType(types.intp, 1, "C"), ndim) + return signature(retty) + + @bound_function("usmarray.reshape") + def resolve_reshape(self, ary, args, kws): + def sentry_shape_scalar(ty): + if ty in types.number_domain: + # Guard against non integer type + if not isinstance(ty, types.Integer): + raise TypeError("reshape() arg cannot be {0}".format(ty)) + return True + else: + return False + + assert not kws + if ary.layout not in "CF": + # only work for contiguous array + raise TypeError("reshape() supports contiguous array only") + + if len(args) == 1: + # single arg + (shape,) = args + + if sentry_shape_scalar(shape): + ndim = 1 + else: + shape = normalize_shape(shape) + if shape is None: + return + ndim = shape.count + retty = ary.copy(ndim=ndim) + return signature(retty, shape) + + elif len(args) == 0: + # no arg + raise TypeError("reshape() take at least one arg") + + else: + # vararg case + if any(not sentry_shape_scalar(a) for a in args): + raise TypeError( + "reshape({0}) is not supported".format(", ".join(map(str, args))) + ) + + retty = ary.copy(ndim=len(args)) + return signature(retty, *args) + + @bound_function("usmarray.sort") + def resolve_sort(self, ary, args, kws): + assert not args + assert not kws + if ary.ndim == 1: + return signature(types.none) + + @bound_function("usmarray.argsort") + def resolve_argsort(self, ary, args, kws): + assert not args + kwargs = dict(kws) + kind = kwargs.pop("kind", types.StringLiteral("quicksort")) + if not isinstance(kind, types.StringLiteral): + raise errors.TypingError('"kind" must be a string literal') + if kwargs: + msg = "Unsupported keywords: {!r}" + raise TypingError(msg.format([k for k in kwargs.keys()])) + if ary.ndim == 1: + + def argsort_stub(kind="quicksort"): + pass + + pysig = utils.pysignature(argsort_stub) + sig = signature(UsmSharedArrayType(types.intp, 1, "C"), kind).replace( + pysig=pysig + ) + return sig + + @bound_function("usmarray.view") + def resolve_view(self, ary, args, kws): + from .npydecl import parse_dtype + + assert not kws + (dtype,) = args + dtype = parse_dtype(dtype) + if dtype is None: + return + retty = ary.copy(dtype=dtype) + return signature(retty, *args) + + @bound_function("usmarray.astype") + def resolve_astype(self, ary, args, kws): + from .npydecl import parse_dtype + + assert not kws + (dtype,) = args + dtype = parse_dtype(dtype) + if dtype is None: + return + if not self.context.can_convert(ary.dtype, dtype): + raise TypeError( + "astype(%s) not supported on %s: " + "cannot convert from %s to %s" % (dtype, ary, ary.dtype, dtype) + ) + layout = ary.layout if ary.layout in "CF" else "C" + # reset the write bit irrespective of whether the cast type is the same + # as the current dtype, this replicates numpy + retty = ary.copy(dtype=dtype, layout=layout, readonly=False) + return signature(retty, *args) + + @bound_function("usmarray.ravel") + def resolve_ravel(self, ary, args, kws): + # Only support no argument version (default order='C') + assert not kws + assert not args + return signature(ary.copy(ndim=1, layout="C")) + + @bound_function("usmarray.flatten") + def resolve_flatten(self, ary, args, kws): + # Only support no argument version (default order='C') + assert not kws + assert not args + return signature(ary.copy(ndim=1, layout="C")) + + @bound_function("usmarray.take") + def resolve_take(self, ary, args, kws): + assert not kws + (argty,) = args + if isinstance(argty, types.Integer): + sig = signature(ary.dtype, *args) + elif isinstance(argty, UsmSharedArrayType): + sig = signature(argty.copy(layout="C", dtype=ary.dtype), *args) + elif isinstance(argty, types.List): # 1d lists only + sig = signature(UsmSharedArrayType(ary.dtype, 1, "C"), *args) + elif isinstance(argty, types.BaseTuple): + sig = signature(UsmSharedArrayType(ary.dtype, np.ndim(argty), "C"), *args) + else: + raise TypeError("take(%s) not supported for %s" % argty) + return sig + + def generic_resolve(self, ary, attr): + # Resolution of other attributes, for record arrays + if isinstance(ary.dtype, types.Record): + if attr in ary.dtype.fields: + return ary.copy(dtype=ary.dtype.typeof(attr), layout="A") + + +@typing_registry.register_global(nus.as_ndarray) +class DparrayAsNdarray(CallableTemplate): + def generic(self): + def typer(arg): + return types.Array(dtype=arg.dtype, ndim=arg.ndim, layout=arg.layout) + + return typer + + +@typing_registry.register_global(nus.from_ndarray) +class DparrayFromNdarray(CallableTemplate): + def generic(self): + def typer(arg): + return UsmSharedArrayType(dtype=arg.dtype, ndim=arg.ndim, layout=arg.layout) + + return typer + + +@lower_registry.lower(nus.as_ndarray, UsmSharedArrayType) +def usmarray_conversion_as(context, builder, sig, args): + return _array_copy(context, builder, sig, args) + + +@lower_registry.lower(nus.from_ndarray, types.Array) +def usmarray_conversion_from(context, builder, sig, args): + return _array_copy(context, builder, sig, args) diff --git a/numba_dppy/tests/test_usmarray.py b/numba_dppy/tests/test_usmarray.py new file mode 100644 index 0000000000..b86f9476d7 --- /dev/null +++ b/numba_dppy/tests/test_usmarray.py @@ -0,0 +1,203 @@ +import numba +import numpy +import unittest + +import dpctl.dptensor.numpy_usm_shared as usmarray + + +@numba.njit() +def numba_mul_add(a): + return a * 2.0 + 13 + + +@numba.njit() +def numba_add_const(a): + return a + 13 + + +@numba.njit() +def numba_mul(a, b): # a is usmarray, b is numpy + return a * b + + +@numba.njit() +def numba_mul_usmarray_asarray(a, b): # a is usmarray, b is numpy + return a * usmarray.asarray(b) + + +@numba.njit +def numba_usmarray_as_ndarray(a): + return usmarray.as_ndarray(a) + + +@numba.njit +def numba_usmarray_from_ndarray(a): + return usmarray.from_ndarray(a) + + +@numba.njit() +def numba_usmarray_ones(): + return usmarray.ones(10) + + +@numba.njit +def numba_usmarray_empty(): + return usmarray.empty((10, 10)) + + +@numba.njit() +def numba_identity(a): + return a + + +@numba.njit +def numba_shape(x): + return x.shape + + +@numba.njit +def numba_T(x): + return x.T + + +@numba.njit +def numba_reshape(x): + return x.reshape((4, 3)) + + +class TestUsmArray(unittest.TestCase): + def ndarray(self): + """Create NumPy array""" + return numpy.ones(10) + + def usmarray(self): + """Create dpCtl USM array""" + return usmarray.ones(10) + + def test_python_numpy(self): + """Testing Python Numpy""" + z2 = numba_mul_add.py_func(self.ndarray()) + self.assertEqual(type(z2), numpy.ndarray, z2) + + def test_numba_numpy(self): + """Testing Numba Numpy""" + z2 = numba_mul_add(self.ndarray()) + self.assertEqual(type(z2), numpy.ndarray, z2) + + def test_usmarray_ones(self): + """Testing usmarray ones""" + a = usmarray.ones(10) + self.assertIsInstance(a, usmarray.ndarray, type(a)) + self.assertTrue(usmarray.has_array_interface(a)) + + def test_usmarray_usmarray_as_ndarray(self): + """Testing usmarray.usmarray.as_ndarray""" + nd1 = self.usmarray().as_ndarray() + self.assertEqual(type(nd1), numpy.ndarray, nd1) + + def test_usmarray_as_ndarray(self): + """Testing usmarray.as_ndarray""" + nd2 = usmarray.as_ndarray(self.usmarray()) + self.assertEqual(type(nd2), numpy.ndarray, nd2) + + def test_usmarray_from_ndarray(self): + """Testing usmarray.from_ndarray""" + nd2 = usmarray.as_ndarray(self.usmarray()) + dp1 = usmarray.from_ndarray(nd2) + self.assertIsInstance(dp1, usmarray.ndarray, type(dp1)) + self.assertTrue(usmarray.has_array_interface(dp1)) + + def test_usmarray_multiplication(self): + """Testing usmarray multiplication""" + c = self.usmarray() * 5 + self.assertIsInstance(c, usmarray.ndarray, type(c)) + self.assertTrue(usmarray.has_array_interface(c)) + + def test_python_usmarray_mul_add(self): + """Testing Python usmarray""" + c = self.usmarray() * 5 + b = numba_mul_add.py_func(c) + self.assertIsInstance(b, usmarray.ndarray, type(b)) + self.assertTrue(usmarray.has_array_interface(b)) + + def test_numba_usmarray_mul_add(self): + """Testing Numba usmarray""" + # fails if run tests in bunch + c = self.usmarray() * 5 + b = numba_mul_add(c) + self.assertIsInstance(b, usmarray.ndarray, type(b)) + self.assertTrue(usmarray.has_array_interface(b)) + + def test_python_mixing_usmarray_and_numpy_ndarray(self): + """Testing Python mixing usmarray and numpy.ndarray""" + h = numba_mul.py_func(self.usmarray(), self.ndarray()) + self.assertIsInstance(h, usmarray.ndarray, type(h)) + self.assertTrue(usmarray.has_array_interface(h)) + + def test_numba_usmarray_2(self): + """Testing Numba usmarray 2""" + d = numba_identity(self.usmarray()) + self.assertIsInstance(d, usmarray.ndarray, type(d)) + self.assertTrue(usmarray.has_array_interface(d)) + + @unittest.expectedFailure + def test_numba_usmarray_constructor_from_numpy_ndarray(self): + """Testing Numba usmarray constructor from numpy.ndarray""" + e = numba_mul_usmarray_asarray(self.usmarray(), self.ndarray()) + self.assertIsInstance(e, usmarray.ndarray, type(e)) + + def test_numba_mixing_usmarray_and_constant(self): + """Testing Numba mixing usmarray and constant""" + g = numba_add_const(self.usmarray()) + self.assertIsInstance(g, usmarray.ndarray, type(g)) + self.assertTrue(usmarray.has_array_interface(g)) + + def test_numba_mixing_usmarray_and_numpy_ndarray(self): + """Testing Numba mixing usmarray and numpy.ndarray""" + h = numba_mul(self.usmarray(), self.ndarray()) + self.assertIsInstance(h, usmarray.ndarray, type(h)) + self.assertTrue(usmarray.has_array_interface(h)) + + def test_numba_usmarray_functions(self): + """Testing Numba usmarray functions""" + f = numba_usmarray_ones() + self.assertIsInstance(f, usmarray.ndarray, type(f)) + self.assertTrue(usmarray.has_array_interface(f)) + + def test_numba_usmarray_as_ndarray(self): + """Testing Numba usmarray.as_ndarray""" + nd3 = numba_usmarray_as_ndarray(self.usmarray()) + self.assertEqual(type(nd3), numpy.ndarray, nd3) + + def test_numba_usmarray_from_ndarray(self): + """Testing Numba usmarray.from_ndarray""" + nd3 = numba_usmarray_as_ndarray(self.usmarray()) + dp2 = numba_usmarray_from_ndarray(nd3) + self.assertIsInstance(dp2, usmarray.ndarray, type(dp2)) + self.assertTrue(usmarray.has_array_interface(dp2)) + + def test_numba_usmarray_empty(self): + """Testing Numba usmarray.empty""" + dp3 = numba_usmarray_empty() + self.assertIsInstance(dp3, usmarray.ndarray, type(dp3)) + self.assertTrue(usmarray.has_array_interface(dp3)) + + def test_numba_usmarray_shape(self): + """Testing Numba usmarray.shape""" + s1 = numba_shape(numba_usmarray_empty()) + self.assertIsInstance(s1, tuple, type(s1)) + self.assertEqual(s1, (10, 10)) + + def test_numba_usmarray_T(self): + """Testing Numba usmarray.T""" + dp4 = numba_T(numba_usmarray_empty()) + self.assertIsInstance(dp4, usmarray.ndarray, type(dp4)) + self.assertTrue(usmarray.has_array_interface(dp4)) + + @unittest.expectedFailure + def test_numba_usmarray_reshape(self): + """Testing Numba usmarray.reshape()""" + a = usmarray.ones(12) + s1 = numba_reshape(a) + self.assertIsInstance(s1, usmarray.ndarray, type(s1)) + self.assertEqual(s1.shape, (4, 3)) diff --git a/setup.py b/setup.py index 8c892f6fd2..5ce0234bb8 100644 --- a/setup.py +++ b/setup.py @@ -3,11 +3,21 @@ from Cython.Build import cythonize import versioneer +import sys def get_ext_modules(): ext_modules = [] + import numba + + ext_dppy = Extension( + name="numba_dppy._dppy_rt", + sources=["numba_dppy/dppy_rt.c"], + include_dirs=[numba.core.extending.include_path()], + ) + ext_modules += [ext_dppy] + dpnp_present = False try: import dpnp @@ -66,6 +76,11 @@ def get_ext_modules(): "Topic :: Software Development :: Compilers", ], cmdclass=versioneer.get_cmdclass(), + entry_points={ + "numba_extensions": [ + "init = numba_dppy.numpy_usm_shared:numba_register", + ] + }, ) setup(**metadata)