Skip to content

Commit 4e99a55

Browse files
committed
Add imports to usmarray module and fixed setup.py extension initialization
1 parent f895e1d commit 4e99a55

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

numba_dppy/numpy_usm_shared.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,56 @@
1+
import numpy as np
2+
from inspect import getmembers, isfunction, isclass, isbuiltin
3+
from numbers import Number
4+
import numba
5+
from types import FunctionType as ftype, BuiltinFunctionType as bftype
6+
from numba import types
7+
from numba.extending import typeof_impl, register_model, type_callable, lower_builtin
8+
from numba.np import numpy_support
9+
from numba.core.pythonapi import box, allocator
10+
from llvmlite import ir
11+
import llvmlite.llvmpy.core as lc
12+
import llvmlite.binding as llb
13+
from numba.core import types, cgutils, config
14+
import builtins
15+
import sys
16+
from ctypes.util import find_library
17+
from numba.core.typing.templates import builtin_registry as templates_registry
18+
from numba.core.typing.npydecl import registry as typing_registry
19+
from numba.core.imputils import builtin_registry as lower_registry
20+
import importlib
21+
import functools
22+
import inspect
23+
from numba.core.typing.templates import CallableTemplate
24+
from numba.np.arrayobj import _array_copy
25+
26+
from dpctl.dptensor.numpy_usm_shared import ndarray, functions_list
27+
28+
29+
debug = config.DEBUG
30+
31+
def dprint(*args):
32+
if debug:
33+
print(*args)
34+
sys.stdout.flush()
35+
36+
# # This code makes it so that Numba can contain calls into the DPPLSyclInterface library.
37+
# sycl_mem_lib = find_library('DPCTLSyclInterface')
38+
# dprint("sycl_mem_lib:", sycl_mem_lib)
39+
# # Load the symbols from the DPPL Sycl library.
40+
# llb.load_library_permanently(sycl_mem_lib)
41+
42+
import dpctl
43+
from dpctl.memory import MemoryUSMShared
44+
import numba_dppy._dppy_rt
45+
46+
# functions_list = [o[0] for o in getmembers(np) if isfunction(o[1]) or isbuiltin(o[1])]
47+
# class_list = [o for o in getmembers(np) if isclass(o[1])]
48+
49+
# Register the helper function in dppl_rt so that we can insert calls to them via llvmlite.
50+
for py_name, c_address in numba_dppy._dppy_rt.c_helpers.items():
51+
llb.add_symbol(py_name, c_address)
52+
53+
154
# This class creates a type in Numba.
255
class UsmSharedArrayType(types.Array):
356
def __init__(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def get_ext_modules():
7474
cmdclass=versioneer.get_cmdclass(),
7575
entry_points={
7676
"numba_extensions": [
77-
"init = numba_dppy.usmarray:numba_register",
77+
"init = numba_dppy.numpy_usm_shared:numba_register",
7878
]},
7979
)
8080

0 commit comments

Comments
 (0)