23
23
from numba .core .typing .templates import CallableTemplate
24
24
from numba .np .arrayobj import _array_copy
25
25
26
- import dpctl .dptensor .numpy_usm_shared as numpy_usm_shared
27
- from dpctl .dptensor .numpy_usm_shared import ndarray , functions_list
26
+ import dpctl .dptensor .numpy_usm_shared as nus
27
+ from dpctl .dptensor .numpy_usm_shared import ndarray , functions_list , class_list
28
28
29
29
30
30
debug = config .DEBUG
@@ -233,7 +233,7 @@ def numba_register_lower_builtin():
233
233
cur_mod = importlib .import_module (__name__ )
234
234
for impl , func , types in todo + todo_builtin :
235
235
try :
236
- usmarray_func = eval ("numpy_usm_shared." + func .__name__ )
236
+ usmarray_func = eval ("dpctl.dptensor. numpy_usm_shared." + func .__name__ )
237
237
except :
238
238
dprint ("failed to eval" , func .__name__ )
239
239
continue
@@ -260,28 +260,44 @@ def numba_register_typing():
260
260
# For all Numpy identifiers that have been registered for typing in Numba...
261
261
for ig in typing_registry .globals :
262
262
val , typ = ig
263
+ dprint ("Numpy registered:" , val , type (val ), typ , type (typ ))
263
264
# If it is a Numpy function...
264
265
if isinstance (val , (ftype , bftype )):
265
266
# If we have overloaded that function in the usmarray module (always True right now)...
266
267
if val .__name__ in functions_list :
267
268
todo .append (ig )
268
269
if isinstance (val , type ):
269
- todo_classes .append (ig )
270
+ if isinstance (typ , numba .core .types .functions .Function ):
271
+ todo .append (ig )
272
+ elif isinstance (typ , numba .core .types .functions .NumberClass ):
273
+ pass
274
+ #todo_classes.append(ig)
270
275
271
276
for tgetattr in templates_registry .attributes :
272
277
if tgetattr .key == types .Array :
273
278
todo_getattr .append (tgetattr )
274
279
280
+ for val , typ in todo_classes :
281
+ dprint ("todo_classes:" , val , typ , type (typ ))
282
+
283
+ try :
284
+ dptype = eval ("dpctl.dptensor.numpy_usm_shared." + val .__name__ )
285
+ except :
286
+ dprint ("failed to eval" , val .__name__ )
287
+ continue
288
+
289
+ typing_registry .register_global (dptype , numba .core .types .NumberClass (typ .instance_type ))
290
+
275
291
for val , typ in todo :
276
292
assert len (typ .templates ) == 1
277
293
# template is the typing class to invoke generic() upon.
278
294
template = typ .templates [0 ]
295
+ dprint ("need to re-register for usmarray" , val , typ , typ .typing_key )
279
296
try :
280
- dpval = eval ("numpy_usm_shared." + val .__name__ )
297
+ dpval = eval ("dpctl.dptensor. numpy_usm_shared." + val .__name__ )
281
298
except :
282
299
dprint ("failed to eval" , val .__name__ )
283
300
continue
284
- dprint ("need to re-register for usmarray" , val , typ , typ .typing_key )
285
301
"""
286
302
if debug:
287
303
print("--------------------------------------------------------------")
@@ -307,9 +323,7 @@ def set_key_original(cls, key, original):
307
323
def generic_impl (self ):
308
324
original_typer = self .__class__ .original .generic (self .__class__ .original )
309
325
ot_argspec = inspect .getfullargspec (original_typer )
310
- # print("ot_argspec:", ot_argspec)
311
326
astr = argspec_to_string (ot_argspec )
312
- # print("astr:", astr)
313
327
314
328
typer_func = """def typer({}):
315
329
original_res = original_typer({})
@@ -321,8 +335,6 @@ def generic_impl(self):
321
335
astr , "," .join (ot_argspec .args )
322
336
)
323
337
324
- # print("typer_func:", typer_func)
325
-
326
338
try :
327
339
gs = globals ()
328
340
ls = locals ()
@@ -344,7 +356,6 @@ def generic_impl(self):
344
356
print ("eval failed!" , sys .exc_info ()[0 ])
345
357
sys .exit (0 )
346
358
347
- # print("exec_res:", exec_res)
348
359
return exec_res
349
360
350
361
new_usmarray_template = type (
@@ -370,7 +381,6 @@ def set_key(cls, key):
370
381
371
382
def getattr_impl (self , attr ):
372
383
if attr .startswith ("resolve_" ):
373
- # print("getattr_impl starts with resolve_:", self, type(self), attr)
374
384
def wrapper (* args , ** kwargs ):
375
385
attr_res = tgetattr .__getattribute__ (self , attr )(* args , ** kwargs )
376
386
if isinstance (attr_res , types .Array ):
@@ -394,15 +404,7 @@ def wrapper(*args, **kwargs):
394
404
templates_registry .register_attr (new_usmarray_template )
395
405
396
406
397
- def from_ndarray (x ):
398
- return copy (x )
399
-
400
-
401
- def as_ndarray (x ):
402
- return np .copy (x )
403
-
404
-
405
- @typing_registry .register_global (as_ndarray )
407
+ @typing_registry .register_global (nus .as_ndarray )
406
408
class DparrayAsNdarray (CallableTemplate ):
407
409
def generic (self ):
408
410
def typer (arg ):
@@ -411,7 +413,7 @@ def typer(arg):
411
413
return typer
412
414
413
415
414
- @typing_registry .register_global (from_ndarray )
416
+ @typing_registry .register_global (nus . from_ndarray )
415
417
class DparrayFromNdarray (CallableTemplate ):
416
418
def generic (self ):
417
419
def typer (arg ):
@@ -420,11 +422,11 @@ def typer(arg):
420
422
return typer
421
423
422
424
423
- @lower_registry .lower (as_ndarray , UsmSharedArrayType )
425
+ @lower_registry .lower (nus . as_ndarray , UsmSharedArrayType )
424
426
def usmarray_conversion_as (context , builder , sig , args ):
425
427
return _array_copy (context , builder , sig , args )
426
428
427
429
428
- @lower_registry .lower (from_ndarray , types .Array )
430
+ @lower_registry .lower (nus . from_ndarray , types .Array )
429
431
def usmarray_conversion_from (context , builder , sig , args ):
430
432
return _array_copy (context , builder , sig , args )
0 commit comments