diff --git a/numba/_typeof.c b/numba/_typeof.c index ffe0e3a3c58..9b259164800 100644 --- a/numba/_typeof.c +++ b/numba/_typeof.c @@ -768,6 +768,7 @@ int typeof_typecode(PyObject *dispatcher, PyObject *val) { PyTypeObject *tyobj = Py_TYPE(val); + int no_subtype_attr; /* This needs to be kept in sync with Dispatcher.typeof_pyval(), * otherwise funny things may happen. */ @@ -794,9 +795,19 @@ typeof_typecode(PyObject *dispatcher, PyObject *val) return typecode_arrayscalar(dispatcher, val); } /* Array handling */ - else if (PyType_IsSubtype(tyobj, &PyArray_Type)) { + else if (tyobj == &PyArray_Type) { return typecode_ndarray(dispatcher, (PyArrayObject*)val); } + /* Subtypes of Array handling */ + else if (PyType_IsSubtype(tyobj, &PyArray_Type)) { + /* If the class has an attribute named __numba_no_subtype_ndarray__ then + don't treat it as a normal variant of a Numpy ndarray but as its own + separate type. */ + no_subtype_attr = PyObject_HasAttrString(val, "__numba_no_subtype_ndarray__"); + if (!no_subtype_attr) { + return typecode_ndarray(dispatcher, (PyArrayObject*)val); + } + } return typecode_using_fingerprint(dispatcher, val); } diff --git a/numba/core/extending.py b/numba/core/extending.py index 8d8d8525e21..09373708b48 100644 --- a/numba/core/extending.py +++ b/numba/core/extending.py @@ -14,7 +14,7 @@ lower_setattr, lower_setattr_generic, lower_cast) # noqa: F401 from numba.core.datamodel import models # noqa: F401 from numba.core.datamodel import register_default as register_model # noqa: F401, E501 -from numba.core.pythonapi import box, unbox, reflect, NativeValue # noqa: F401 +from numba.core.pythonapi import box, unbox, reflect, NativeValue, allocator # noqa: F401 from numba._helperlib import _import_cython_function # noqa: F401 from numba.core.serialize import ReduceMixin diff --git a/numba/core/ir_utils.py b/numba/core/ir_utils.py index 1d58c5c8b5b..9ffdfb16b07 100644 --- a/numba/core/ir_utils.py +++ b/numba/core/ir_utils.py @@ -64,6 +64,8 @@ def 
mk_alloc(typemap, calltypes, lhs, size_var, dtype, scope, loc): out = [] ndims = 1 size_typ = types.intp + # Get the type of the array being allocated. + arr_typ = typemap[lhs.name] if isinstance(size_var, tuple): if len(size_var) == 1: size_var = size_var[0] @@ -108,11 +110,13 @@ def mk_alloc(typemap, calltypes, lhs, size_var, dtype, scope, loc): typ_var_assign = ir.Assign(np_typ_getattr, typ_var, loc) alloc_call = ir.Expr.call(attr_var, [size_var, typ_var], (), loc) if calltypes: - calltypes[alloc_call] = typemap[attr_var.name].get_call_type( + cac = typemap[attr_var.name].get_call_type( typing.Context(), [size_typ, types.functions.NumberClass(dtype)], {}) - # signature( - # types.npytypes.Array(dtype, ndims, 'C'), size_typ, - # types.functions.NumberClass(dtype)) + # By default, all calls to "empty" are typed as returning a standard + # Numpy ndarray. If we are allocating a ndarray subclass here then + # just change the return type to be that of the subclass. + cac._return_type = arr_typ + calltypes[alloc_call] = cac alloc_assign = ir.Assign(alloc_call, lhs, loc) out.extend([g_np_assign, attr_assign, typ_var_assign, alloc_assign]) diff --git a/numba/core/pythonapi.py b/numba/core/pythonapi.py index 7901e761d9f..f84ad7b2ce1 100644 --- a/numba/core/pythonapi.py +++ b/numba/core/pythonapi.py @@ -45,10 +45,13 @@ def lookup(self, typeclass, default=None): _boxers = _Registry() _unboxers = _Registry() _reflectors = _Registry() +# Registry of special allocators for types. +_allocators = _Registry() box = _boxers.register unbox = _unboxers.register reflect = _reflectors.register +allocator = _allocators.register class _BoxContext(namedtuple("_BoxContext", ("context", "builder", "pyapi", "env_manager"))): @@ -1186,8 +1189,11 @@ def nrt_adapt_ndarray_to_python(self, aryty, ary, dtypeptr): assert self.context.enable_nrt, "NRT required" intty = ir.IntType(32) + # Embed the Python type of the array (maybe subclass) in the LLVM. 
+ serial_aryty_pytype = self.unserialize(self.serialize_object(aryty.py_type)) + fnty = Type.function(self.pyobj, - [self.voidptr, intty, intty, self.pyobj]) + [self.voidptr, self.pyobj, intty, intty, self.pyobj]) fn = self._get_function(fnty, name="NRT_adapt_ndarray_to_python") fn.args[0].add_attribute(lc.ATTR_NO_CAPTURE) @@ -1197,6 +1203,7 @@ def nrt_adapt_ndarray_to_python(self, aryty, ary, dtypeptr): aryptr = cgutils.alloca_once_value(self.builder, ary) return self.builder.call(fn, [self.builder.bitcast(aryptr, self.voidptr), + serial_aryty_pytype, ndim, writable, dtypeptr]) def nrt_meminfo_new_from_pyobject(self, data, pyobj): diff --git a/numba/core/runtime/_nrt_python.c b/numba/core/runtime/_nrt_python.c index 33620fd4f1a..efe4467df70 100644 --- a/numba/core/runtime/_nrt_python.c +++ b/numba/core/runtime/_nrt_python.c @@ -55,6 +55,8 @@ int MemInfo_init(MemInfoObject *self, PyObject *args, PyObject *kwds) { return -1; } raw_ptr = PyLong_AsVoidPtr(raw_ptr_obj); + NRT_Debug(nrt_debug_print("MemInfo_init self=%p raw_ptr=%p\n", self, raw_ptr)); + if(PyErr_Occurred()) return -1; self->meminfo = (NRT_MemInfo *)raw_ptr; assert (NRT_MemInfo_refcount(self->meminfo) > 0 && "0 refcount"); @@ -109,6 +111,27 @@ MemInfo_get_refcount(MemInfoObject *self, void *closure) { return PyLong_FromSize_t(refct); } +static +PyObject* +MemInfo_get_external_allocator(MemInfoObject *self, void *closure) { + void *p = NRT_MemInfo_external_allocator(self->meminfo); + printf("MemInfo_get_external_allocator %p\n", p); + return PyLong_FromVoidPtr(p); +} + +static +PyObject* +MemInfo_get_parent(MemInfoObject *self, void *closure) { + void *p = NRT_MemInfo_parent(self->meminfo); + if (p) { + Py_INCREF(p); + return (PyObject*)p; + } else { + Py_INCREF(Py_None); + return Py_None; + } +} + static void MemInfo_dealloc(MemInfoObject *self) { @@ -136,6 +159,13 @@ static PyGetSetDef MemInfo_getsets[] = { (getter)MemInfo_get_refcount, NULL, "Get the refcount", NULL}, + {"external_allocator", + 
(getter)MemInfo_get_external_allocator, NULL, + "Get the external allocator", + NULL}, + {"parent", + (getter)MemInfo_get_parent, NULL, + NULL}, {NULL} /* Sentinel */ }; @@ -286,7 +316,7 @@ PyObject* try_to_return_parent(arystruct_t *arystruct, int ndim, } NUMBA_EXPORT_FUNC(PyObject *) -NRT_adapt_ndarray_to_python(arystruct_t* arystruct, int ndim, +NRT_adapt_ndarray_to_python(arystruct_t* arystruct, PyTypeObject *retty, int ndim, int writeable, PyArray_Descr *descr) { PyArrayObject *array; @@ -324,10 +354,13 @@ NRT_adapt_ndarray_to_python(arystruct_t* arystruct, int ndim, args = PyTuple_New(1); /* SETITEM steals reference */ PyTuple_SET_ITEM(args, 0, PyLong_FromVoidPtr(arystruct->meminfo)); + NRT_Debug(nrt_debug_print("NRT_adapt_ndarray_to_python arystruct->meminfo=%p\n", arystruct->meminfo)); /* Note: MemInfo_init() does not incref. This function steals the * NRT reference. */ + NRT_Debug(nrt_debug_print("NRT_adapt_ndarray_to_python created MemInfo=%p\n", miobj)); if (MemInfo_init(miobj, args, NULL)) { + NRT_Debug(nrt_debug_print("MemInfo_init returned 0.\n")); return NULL; } Py_DECREF(args); @@ -336,7 +369,7 @@ NRT_adapt_ndarray_to_python(arystruct_t* arystruct, int ndim, shape = arystruct->shape_and_strides; strides = shape + ndim; Py_INCREF((PyObject *) descr); - array = (PyArrayObject *) PyArray_NewFromDescr(&PyArray_Type, descr, ndim, + array = (PyArrayObject *) PyArray_NewFromDescr(retty, descr, ndim, shape, strides, arystruct->data, flags, (PyObject *) miobj); diff --git a/numba/core/runtime/_nrt_pythonmod.c b/numba/core/runtime/_nrt_pythonmod.c index 31e1155fd9f..d1300ee8e9a 100644 --- a/numba/core/runtime/_nrt_pythonmod.c +++ b/numba/core/runtime/_nrt_pythonmod.c @@ -163,6 +163,7 @@ declmethod(MemInfo_alloc); declmethod(MemInfo_alloc_safe); declmethod(MemInfo_alloc_aligned); declmethod(MemInfo_alloc_safe_aligned); +declmethod(MemInfo_alloc_safe_aligned_external); declmethod(MemInfo_alloc_dtor_safe); declmethod(MemInfo_call_dtor); 
declmethod(MemInfo_new_varsize); diff --git a/numba/core/runtime/nrt.c b/numba/core/runtime/nrt.c index 534681d5417..fe63a691537 100644 --- a/numba/core/runtime/nrt.c +++ b/numba/core/runtime/nrt.c @@ -19,6 +19,7 @@ struct MemInfo { void *dtor_info; void *data; size_t size; /* only used for NRT allocated memory */ + NRT_ExternalAllocator *external_allocator; }; @@ -170,13 +171,16 @@ void NRT_MemSys_set_atomic_cas_stub(void) { */ void NRT_MemInfo_init(NRT_MemInfo *mi,void *data, size_t size, - NRT_dtor_function dtor, void *dtor_info) + NRT_dtor_function dtor, void *dtor_info, + NRT_ExternalAllocator *external_allocator) { mi->refct = 1; /* starts with 1 refct */ mi->dtor = dtor; mi->dtor_info = dtor_info; mi->data = data; mi->size = size; + mi->external_allocator = external_allocator; + NRT_Debug(nrt_debug_print("NRT_MemInfo_init mi=%p external_allocator=%p\n", mi, external_allocator)); /* Update stats */ TheMSys.atomic_inc(&TheMSys.stats_mi_alloc); } @@ -185,7 +189,8 @@ NRT_MemInfo *NRT_MemInfo_new(void *data, size_t size, NRT_dtor_function dtor, void *dtor_info) { NRT_MemInfo *mi = NRT_Allocate(sizeof(NRT_MemInfo)); - NRT_MemInfo_init(mi, data, size, dtor, dtor_info); + NRT_Debug(nrt_debug_print("NRT_MemInfo_new mi=%p\n", mi)); + NRT_MemInfo_init(mi, data, size, dtor, dtor_info, NULL); return mi; } @@ -206,9 +211,10 @@ void nrt_internal_dtor_safe(void *ptr, size_t size, void *info) { } static -void *nrt_allocate_meminfo_and_data(size_t size, NRT_MemInfo **mi_out) { +void *nrt_allocate_meminfo_and_data(size_t size, NRT_MemInfo **mi_out, NRT_ExternalAllocator *allocator) { NRT_MemInfo *mi; - char *base = NRT_Allocate(sizeof(NRT_MemInfo) + size); + NRT_Debug(nrt_debug_print("nrt_allocate_meminfo_and_data %p\n", allocator)); + char *base = NRT_Allocate_External(sizeof(NRT_MemInfo) + size, allocator); mi = (NRT_MemInfo *) base; *mi_out = mi; return base + sizeof(NRT_MemInfo); @@ -230,9 +236,17 @@ void nrt_internal_custom_dtor_safe(void *ptr, size_t size, void *info) { 
NRT_MemInfo *NRT_MemInfo_alloc(size_t size) { NRT_MemInfo *mi; - void *data = nrt_allocate_meminfo_and_data(size, &mi); + void *data = nrt_allocate_meminfo_and_data(size, &mi, NULL); NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc %p\n", data)); - NRT_MemInfo_init(mi, data, size, NULL, NULL); + NRT_MemInfo_init(mi, data, size, NULL, NULL, NULL); + return mi; +} + +NRT_MemInfo *NRT_MemInfo_alloc_external(size_t size, NRT_ExternalAllocator *allocator) { + NRT_MemInfo *mi; + void *data = nrt_allocate_meminfo_and_data(size, &mi, allocator); + NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc %p\n", data)); + NRT_MemInfo_init(mi, data, size, NULL, NULL, allocator); return mi; } @@ -242,22 +256,23 @@ NRT_MemInfo *NRT_MemInfo_alloc_safe(size_t size) { NRT_MemInfo* NRT_MemInfo_alloc_dtor_safe(size_t size, NRT_dtor_function dtor) { NRT_MemInfo *mi; - void *data = nrt_allocate_meminfo_and_data(size, &mi); + void *data = nrt_allocate_meminfo_and_data(size, &mi, NULL); /* Only fill up a couple cachelines with debug markers, to minimize overhead. 
*/ memset(data, 0xCB, MIN(size, 256)); NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc_dtor_safe %p %zu\n", data, size)); - NRT_MemInfo_init(mi, data, size, nrt_internal_custom_dtor_safe, dtor); + NRT_MemInfo_init(mi, data, size, nrt_internal_custom_dtor_safe, dtor, NULL); return mi; } static void *nrt_allocate_meminfo_and_data_align(size_t size, unsigned align, - NRT_MemInfo **mi) + NRT_MemInfo **mi, NRT_ExternalAllocator *allocator) { size_t offset, intptr, remainder; - char *base = nrt_allocate_meminfo_and_data(size + 2 * align, mi); + NRT_Debug(nrt_debug_print("nrt_allocate_meminfo_and_data_align %p\n", allocator)); + char *base = nrt_allocate_meminfo_and_data(size + 2 * align, mi, allocator); intptr = (size_t) base; /* See if we are aligned */ remainder = intptr % align; @@ -271,26 +286,48 @@ void *nrt_allocate_meminfo_and_data_align(size_t size, unsigned align, NRT_MemInfo *NRT_MemInfo_alloc_aligned(size_t size, unsigned align) { NRT_MemInfo *mi; - void *data = nrt_allocate_meminfo_and_data_align(size, align, &mi); + void *data = nrt_allocate_meminfo_and_data_align(size, align, &mi, NULL); NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc_aligned %p\n", data)); - NRT_MemInfo_init(mi, data, size, NULL, NULL); + NRT_MemInfo_init(mi, data, size, NULL, NULL, NULL); return mi; } NRT_MemInfo *NRT_MemInfo_alloc_safe_aligned(size_t size, unsigned align) { NRT_MemInfo *mi; - void *data = nrt_allocate_meminfo_and_data_align(size, align, &mi); + void *data = nrt_allocate_meminfo_and_data_align(size, align, &mi, NULL); /* Only fill up a couple cachelines with debug markers, to minimize overhead. 
*/ memset(data, 0xCB, MIN(size, 256)); NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc_safe_aligned %p %zu\n", data, size)); - NRT_MemInfo_init(mi, data, size, nrt_internal_dtor_safe, (void*)size); + NRT_MemInfo_init(mi, data, size, nrt_internal_dtor_safe, (void*)size, NULL); return mi; } +NRT_MemInfo *NRT_MemInfo_alloc_safe_aligned_external(size_t size, unsigned align, NRT_ExternalAllocator *allocator) { + NRT_MemInfo *mi; + NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc_safe_aligned_external %p\n", allocator)); + void *data = nrt_allocate_meminfo_and_data_align(size, align, &mi, allocator); + /* Only fill up a couple cachelines with debug markers, to minimize + overhead. */ + memset(data, 0xCB, MIN(size, 256)); + NRT_Debug(nrt_debug_print("NRT_MemInfo_alloc_safe_aligned %p %zu\n", + data, size)); + NRT_MemInfo_init(mi, data, size, nrt_internal_dtor_safe, (void*)size, allocator); + return mi; +} + +void NRT_dealloc(NRT_MemInfo *mi) { + NRT_Debug(nrt_debug_print("NRT_dealloc meminfo: %p external_allocator: %p\n", mi, mi->external_allocator)); + if (mi->external_allocator) { + mi->external_allocator->free(mi, mi->external_allocator->opaque_data); + } else { + NRT_Free(mi); + } +} + void NRT_MemInfo_destroy(NRT_MemInfo *mi) { - NRT_Free(mi); + NRT_dealloc(mi); TheMSys.atomic_inc(&TheMSys.stats_mi_free); } @@ -328,6 +365,14 @@ size_t NRT_MemInfo_size(NRT_MemInfo* mi) { return mi->size; } +void * NRT_MemInfo_external_allocator(NRT_MemInfo *mi) { + NRT_Debug(nrt_debug_print("NRT_MemInfo_external_allocator meminfo: %p external_allocator: %p\n", mi, mi->external_allocator)); + return mi->external_allocator; +} + +void *NRT_MemInfo_parent(NRT_MemInfo *mi) { + return mi->dtor_info; +} void NRT_MemInfo_dump(NRT_MemInfo *mi, FILE *out) { fprintf(out, "MemInfo %p refcount %zu\n", mi, mi->refct); @@ -414,8 +459,18 @@ void NRT_MemInfo_varsize_free(NRT_MemInfo *mi, void *ptr) */ void* NRT_Allocate(size_t size) { - void *ptr = TheMSys.allocator.malloc(size); - 
NRT_Debug(nrt_debug_print("NRT_Allocate bytes=%zu ptr=%p\n", size, ptr)); + return NRT_Allocate_External(size, NULL); +} + +void* NRT_Allocate_External(size_t size, NRT_ExternalAllocator *allocator) { + void *ptr; + if (allocator) { + ptr = allocator->malloc(size, allocator->opaque_data); + NRT_Debug(nrt_debug_print("NRT_Allocate custom bytes=%zu ptr=%p\n", size, ptr)); + } else { + ptr = TheMSys.allocator.malloc(size); + NRT_Debug(nrt_debug_print("NRT_Allocate bytes=%zu ptr=%p\n", size, ptr)); + } TheMSys.atomic_inc(&TheMSys.stats_alloc); return ptr; } @@ -460,6 +515,7 @@ NRT_MemInfo* nrt_manage_memory(void *data, NRT_managed_dtor dtor) { static const NRT_api_functions nrt_functions_table = { NRT_MemInfo_alloc, + NRT_MemInfo_alloc_external, nrt_manage_memory, NRT_MemInfo_acquire, NRT_MemInfo_release, diff --git a/numba/core/runtime/nrt.h b/numba/core/runtime/nrt.h index 3c74dc58f58..9fb23532964 100644 --- a/numba/core/runtime/nrt.h +++ b/numba/core/runtime/nrt.h @@ -15,13 +15,14 @@ All functions described here are threadsafe. 
/* Debugging facilities - enabled at compile-time */ /* #undef NDEBUG */ #if 0 -# define NRT_Debug(X) X +# define NRT_Debug(X) {X; fflush(stdout); } #else # define NRT_Debug(X) if (0) { X; } #endif /* TypeDefs */ typedef void (*NRT_dtor_function)(void *ptr, size_t size, void *info); +typedef void (*NRT_dealloc_func)(void *ptr, void *dealloc_info); typedef size_t (*NRT_atomic_inc_dec_func)(size_t *ptr); typedef int (*NRT_atomic_cas_func)(void * volatile *ptr, void *cmp, void *repl, void **oldptr); @@ -32,7 +33,6 @@ typedef void *(*NRT_malloc_func)(size_t size); typedef void *(*NRT_realloc_func)(void *ptr, size_t new_size); typedef void (*NRT_free_func)(void *ptr); - /* Memory System API */ /* Initialize the memory system */ @@ -101,7 +101,8 @@ NRT_MemInfo* NRT_MemInfo_new(void *data, size_t size, VISIBILITY_HIDDEN void NRT_MemInfo_init(NRT_MemInfo *mi, void *data, size_t size, - NRT_dtor_function dtor, void *dtor_info); + NRT_dtor_function dtor, void *dtor_info, + NRT_ExternalAllocator *external_allocator); /* * Returns the refcount of a MemInfo or (size_t)-1 if error. @@ -116,6 +117,8 @@ size_t NRT_MemInfo_refcount(NRT_MemInfo *mi); VISIBILITY_HIDDEN NRT_MemInfo *NRT_MemInfo_alloc(size_t size); +NRT_MemInfo *NRT_MemInfo_alloc_external(size_t size, NRT_ExternalAllocator *allocator); + /* * The "safe" NRT_MemInfo_alloc performs additional steps to help debug * memory errors. @@ -141,6 +144,8 @@ NRT_MemInfo *NRT_MemInfo_alloc_aligned(size_t size, unsigned align); VISIBILITY_HIDDEN NRT_MemInfo *NRT_MemInfo_alloc_safe_aligned(size_t size, unsigned align); +NRT_MemInfo *NRT_MemInfo_alloc_safe_aligned_external(size_t size, unsigned align, NRT_ExternalAllocator *allocator); + /* * Internal API. * Release a MemInfo. Calls NRT_MemSys_insert_meminfo. 
@@ -179,6 +184,18 @@ void* NRT_MemInfo_data(NRT_MemInfo* mi); VISIBILITY_HIDDEN size_t NRT_MemInfo_size(NRT_MemInfo* mi); +/* + * Returns the external allocator + */ +VISIBILITY_HIDDEN +void* NRT_MemInfo_external_allocator(NRT_MemInfo* mi); + +/* + * Returns the parent object of a MemInfo (its dtor_info pointer) + */ +VISIBILITY_HIDDEN +void* NRT_MemInfo_parent(NRT_MemInfo* mi); + /* * NRT API for resizable buffers. @@ -207,6 +224,7 @@ void NRT_MemInfo_dump(NRT_MemInfo *mi, FILE *out); * Allocate memory of `size` bytes. */ VISIBILITY_HIDDEN void* NRT_Allocate(size_t size); +VISIBILITY_HIDDEN void* NRT_Allocate_External(size_t size, NRT_ExternalAllocator *allocator); /* * Deallocate memory pointed by `ptr`. diff --git a/numba/core/runtime/nrt_external.h b/numba/core/runtime/nrt_external.h index 391b6fa1b0e..a4835c36f67 100644 --- a/numba/core/runtime/nrt_external.h +++ b/numba/core/runtime/nrt_external.h @@ -7,6 +7,18 @@ typedef struct MemInfo NRT_MemInfo; typedef void NRT_managed_dtor(void *data); +typedef void *(*NRT_external_malloc_func)(size_t size, void *opaque_data); +typedef void *(*NRT_external_realloc_func)(void *ptr, size_t new_size, void *opaque_data); +typedef void (*NRT_external_free_func)(void *ptr, void *opaque_data); + +struct ExternalMemAllocator { + NRT_external_malloc_func malloc; + NRT_external_realloc_func realloc; + NRT_external_free_func free; + void *opaque_data; +}; + +typedef struct ExternalMemAllocator NRT_ExternalAllocator; typedef struct { /* Methods to create MemInfos. @@ -21,6 +33,10 @@ typedef struct { Returning a new reference. */ NRT_MemInfo* (*allocate)(size_t nbytes); + /* Allocate memory using an external allocator but still using Numba's MemInfo. + + */ + NRT_MemInfo* (*allocate_external)(size_t nbytes, NRT_ExternalAllocator *allocator); /* Convert externally allocated memory into a MemInfo. 
diff --git a/numba/core/types/npytypes.py b/numba/core/types/npytypes.py index 6f6307c5526..3c2191ca23e 100644 --- a/numba/core/types/npytypes.py +++ b/numba/core/types/npytypes.py @@ -8,6 +8,7 @@ from numba.core import utils from .misc import UnicodeType from .containers import Bytes +import numpy as np class CharSeq(Type): """ @@ -394,8 +395,9 @@ class Array(Buffer): Type class for Numpy arrays. """ - def __init__(self, dtype, ndim, layout, readonly=False, name=None, + def __init__(self, dtype, ndim, layout, py_type=np.ndarray, readonly=False, name=None, aligned=True, addrspace=None): + self.py_type = py_type if readonly: self.mutable = False if (not aligned or diff --git a/numba/core/typing/npydecl.py b/numba/core/typing/npydecl.py index 2dbbed39be9..e7ecf452fe9 100644 --- a/numba/core/typing/npydecl.py +++ b/numba/core/typing/npydecl.py @@ -126,7 +126,21 @@ def generic(self, args, kws): ret_tys = ufunc_loop.outputs[-implicit_output_count:] if ndims > 0: assert layout is not None - ret_tys = [types.Array(dtype=ret_ty, ndim=ndims, layout=layout) + # If either of the types involved in the ufunc operation have a + # __array_ufunc__ method then invoke the first such one to + # determine the output type of the ufunc. + array_ufunc_type = None + for a in args: + if hasattr(a, "__array_ufunc__"): + array_ufunc_type = a + break + output_type = types.Array + if array_ufunc_type is not None: + output_type = array_ufunc_type.__array_ufunc__(ufunc, "__call__", *args, **kws) + # Eventually better error handling! FIX ME! + assert(output_type is not None) + + ret_tys = [output_type(dtype=ret_ty, ndim=ndims, layout=layout) for ret_ty in ret_tys] ret_tys = [resolve_output_type(self.context, args, ret_ty) for ret_ty in ret_tys] @@ -517,6 +531,7 @@ def typer(shape, dtype=None): @infer_global(np.empty_like) @infer_global(np.zeros_like) +@infer_global(np.ones_like) class NdConstructorLike(CallableTemplate): """ Typing template for np.empty_like(), .zeros_like(), .ones_like(). 
@@ -544,9 +559,6 @@ def typer(arg, dtype=None): return typer -infer_global(np.ones_like)(NdConstructorLike) - - @infer_global(np.full) class NdFull(CallableTemplate): @@ -563,6 +575,7 @@ def typer(shape, fill_value, dtype=None): return typer + @infer_global(np.full_like) class NdFullLike(CallableTemplate): diff --git a/numba/np/arrayobj.py b/numba/np/arrayobj.py index 933b1c6565e..5749e7d9b5b 100644 --- a/numba/np/arrayobj.py +++ b/numba/np/arrayobj.py @@ -32,7 +32,7 @@ from numba.misc import quicksort, mergesort from numba.cpython import slicing from numba.cpython.unsafe.tuple import tuple_setitem - +from numba.core.pythonapi import _allocators def set_range_metadata(builder, load, lower_bound, upper_bound): """ @@ -3399,8 +3399,13 @@ def _empty_nd_impl(context, builder, arrtype, shapes): ) align = context.get_preferred_array_alignment(arrtype.dtype) - meminfo = context.nrt.meminfo_alloc_aligned(builder, size=allocsize, - align=align) + def alloc_unsupported(context, builder, size, align): + return context.nrt.meminfo_alloc_aligned(builder, size, align) + + # See if the type has a special allocator; if not, use the default + # alloc_unsupported allocator above. + allocator_impl = _allocators.lookup(arrtype.__class__, alloc_unsupported) + meminfo = allocator_impl(context, builder, size=allocsize, align=align) data = context.nrt.meminfo_data(builder, meminfo)