diff --git a/numba/core/cpu.py b/numba/core/cpu.py index a54ae7ee339..62eba99638c 100644 --- a/numba/core/cpu.py +++ b/numba/core/cpu.py @@ -62,6 +62,13 @@ def init(self): import numba.typed.dictimpl import numba.experimental.function_type + # Add lower_extension attribute + self.lower_extensions = {} + from numba.parfors.parfor_lowering import _lower_parfor_parallel + from numba.parfors.parfor import Parfor + # Specify how to lower Parfor nodes using the lower_extensions + self.lower_extensions[Parfor] = _lower_parfor_parallel + def load_additional_registries(self): # Add target specific implementations from numba.np import npyimpl diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 44009bde487..57efe28f768 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -267,20 +267,9 @@ def debug_print(self, msg): self.context.debug_print(self.builder, "DEBUGJIT: {0}".format(msg)) -# Dictionary mapping instruction class to its lowering function. -lower_extensions = {} - - class Lower(BaseLower): GeneratorLower = generators.GeneratorLower - def __init__(self, context, library, fndesc, func_ir, metadata=None): - BaseLower.__init__(self, context, library, fndesc, func_ir, metadata) - from numba.parfors.parfor_lowering import _lower_parfor_parallel - from numba.parfors import parfor - if parfor.Parfor not in lower_extensions: - lower_extensions[parfor.Parfor] = [_lower_parfor_parallel] - def pre_block(self, block): from numba.core.unsafe import eh @@ -445,10 +434,11 @@ def lower_inst(self, inst): self.lower_static_try_raise(inst) else: - for _class, func in lower_extensions.items(): - if isinstance(inst, _class): - func[-1](self, inst) - return + if hasattr(self.context, "lower_extensions"): + for _class, func in self.context.lower_extensions.items(): + if isinstance(inst, _class): + func(self, inst) + return raise NotImplementedError(type(inst)) def lower_setitem(self, target_var, index_var, value_var, signature): diff --git a/numba/parfors/parfor_lowering.py b/numba/parfors/parfor_lowering.py index 791b9b79403..765bdb52314 100644 --- a/numba/parfors/parfor_lowering.py +++ b/numba/parfors/parfor_lowering.py @@ -479,9 +479,6 @@ def _lower_parfor_parallel(lowerer, parfor): if config.DEBUG_ARRAY_OPT: print("_lower_parfor_parallel done") -# A work-around to prevent circular imports -#lowering.lower_extensions[parfor.Parfor] = _lower_parfor_parallel - def _create_shape_signature( get_shape_classes,