Skip to content

Commit 1214893

Browse files
committed
There were some spots where there was a silent assumption that the class and the Numba integration were in the same file. I changed those to explicitly refer to the usmarray module in dpctl.
1 parent c87a94b commit 1214893

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

numba_dppy/numpy_usm_shared.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from numba.core.typing.templates import CallableTemplate
2424
from numba.np.arrayobj import _array_copy
2525

26-
import dpctl.dptensor.numpy_usm_shared as numpy_usm_shared
27-
from dpctl.dptensor.numpy_usm_shared import ndarray, functions_list
26+
import dpctl.dptensor.numpy_usm_shared as nus
27+
from dpctl.dptensor.numpy_usm_shared import ndarray, functions_list, class_list
2828

2929

3030
debug = config.DEBUG
@@ -233,7 +233,7 @@ def numba_register_lower_builtin():
233233
cur_mod = importlib.import_module(__name__)
234234
for impl, func, types in todo + todo_builtin:
235235
try:
236-
usmarray_func = eval("numpy_usm_shared."+func.__name__)
236+
usmarray_func = eval("dpctl.dptensor.numpy_usm_shared." + func.__name__)
237237
except:
238238
dprint("failed to eval", func.__name__)
239239
continue
@@ -260,28 +260,44 @@ def numba_register_typing():
260260
# For all Numpy identifiers that have been registered for typing in Numba...
261261
for ig in typing_registry.globals:
262262
val, typ = ig
263+
dprint("Numpy registered:", val, type(val), typ, type(typ))
263264
# If it is a Numpy function...
264265
if isinstance(val, (ftype, bftype)):
265266
# If we have overloaded that function in the usmarray module (always True right now)...
266267
if val.__name__ in functions_list:
267268
todo.append(ig)
268269
if isinstance(val, type):
269-
todo_classes.append(ig)
270+
if isinstance(typ, numba.core.types.functions.Function):
271+
todo.append(ig)
272+
elif isinstance(typ, numba.core.types.functions.NumberClass):
273+
pass
274+
#todo_classes.append(ig)
270275

271276
for tgetattr in templates_registry.attributes:
272277
if tgetattr.key == types.Array:
273278
todo_getattr.append(tgetattr)
274279

280+
for val, typ in todo_classes:
281+
dprint("todo_classes:", val, typ, type(typ))
282+
283+
try:
284+
dptype = eval("dpctl.dptensor.numpy_usm_shared." + val.__name__)
285+
except:
286+
dprint("failed to eval", val.__name__)
287+
continue
288+
289+
typing_registry.register_global(dptype, numba.core.types.NumberClass(typ.instance_type))
290+
275291
for val, typ in todo:
276292
assert len(typ.templates) == 1
277293
# template is the typing class to invoke generic() upon.
278294
template = typ.templates[0]
295+
dprint("need to re-register for usmarray", val, typ, typ.typing_key)
279296
try:
280-
dpval = eval("numpy_usm_shared."+val.__name__)
297+
dpval = eval("dpctl.dptensor.numpy_usm_shared." + val.__name__)
281298
except:
282299
dprint("failed to eval", val.__name__)
283300
continue
284-
dprint("need to re-register for usmarray", val, typ, typ.typing_key)
285301
"""
286302
if debug:
287303
print("--------------------------------------------------------------")
@@ -307,9 +323,7 @@ def set_key_original(cls, key, original):
307323
def generic_impl(self):
308324
original_typer = self.__class__.original.generic(self.__class__.original)
309325
ot_argspec = inspect.getfullargspec(original_typer)
310-
# print("ot_argspec:", ot_argspec)
311326
astr = argspec_to_string(ot_argspec)
312-
# print("astr:", astr)
313327

314328
typer_func = """def typer({}):
315329
original_res = original_typer({})
@@ -321,8 +335,6 @@ def generic_impl(self):
321335
astr, ",".join(ot_argspec.args)
322336
)
323337

324-
# print("typer_func:", typer_func)
325-
326338
try:
327339
gs = globals()
328340
ls = locals()
@@ -344,7 +356,6 @@ def generic_impl(self):
344356
print("eval failed!", sys.exc_info()[0])
345357
sys.exit(0)
346358

347-
# print("exec_res:", exec_res)
348359
return exec_res
349360

350361
new_usmarray_template = type(
@@ -370,7 +381,6 @@ def set_key(cls, key):
370381

371382
def getattr_impl(self, attr):
372383
if attr.startswith("resolve_"):
373-
# print("getattr_impl starts with resolve_:", self, type(self), attr)
374384
def wrapper(*args, **kwargs):
375385
attr_res = tgetattr.__getattribute__(self, attr)(*args, **kwargs)
376386
if isinstance(attr_res, types.Array):
@@ -394,15 +404,7 @@ def wrapper(*args, **kwargs):
394404
templates_registry.register_attr(new_usmarray_template)
395405

396406

397-
def from_ndarray(x):
398-
return copy(x)
399-
400-
401-
def as_ndarray(x):
402-
return np.copy(x)
403-
404-
405-
@typing_registry.register_global(as_ndarray)
407+
@typing_registry.register_global(nus.as_ndarray)
406408
class DparrayAsNdarray(CallableTemplate):
407409
def generic(self):
408410
def typer(arg):
@@ -411,7 +413,7 @@ def typer(arg):
411413
return typer
412414

413415

414-
@typing_registry.register_global(from_ndarray)
416+
@typing_registry.register_global(nus.from_ndarray)
415417
class DparrayFromNdarray(CallableTemplate):
416418
def generic(self):
417419
def typer(arg):
@@ -420,11 +422,11 @@ def typer(arg):
420422
return typer
421423

422424

423-
@lower_registry.lower(as_ndarray, UsmSharedArrayType)
425+
@lower_registry.lower(nus.as_ndarray, UsmSharedArrayType)
424426
def usmarray_conversion_as(context, builder, sig, args):
425427
return _array_copy(context, builder, sig, args)
426428

427429

428-
@lower_registry.lower(from_ndarray, types.Array)
430+
@lower_registry.lower(nus.from_ndarray, types.Array)
429431
def usmarray_conversion_from(context, builder, sig, args):
430432
return _array_copy(context, builder, sig, args)

0 commit comments

Comments
 (0)