From 2a8c2b1fa6ba106b389ace444ccb79d56735f44f Mon Sep 17 00:00:00 2001 From: Sergey Pokhodenko Date: Mon, 9 Nov 2020 21:53:25 +0300 Subject: [PATCH] Patch for with context This modifications make jit() decorator use TargetDispatcher from dppl. Changes made in #57 by @AlexanderKalistratov and @1e-to. --- numba/core/decorators.py | 51 +++++++++++--------------------- numba/core/dispatcher.py | 12 +------- numba/core/registry.py | 6 ---- numba/tests/test_dispatcher.py | 2 -- numba/tests/test_nrt.py | 2 -- numba/tests/test_record_dtype.py | 4 +-- numba/tests/test_serialize.py | 6 ++-- 7 files changed, 23 insertions(+), 60 deletions(-) diff --git a/numba/core/decorators.py b/numba/core/decorators.py index aafef16ed11..cfe91168969 100644 --- a/numba/core/decorators.py +++ b/numba/core/decorators.py @@ -149,7 +149,7 @@ def bar(x, y): target = options.pop('target') warnings.warn("The 'target' keyword argument is deprecated.", NumbaDeprecationWarning) else: - target = options.pop('_target', None) + target = options.pop('_target', 'cpu') options['boundscheck'] = boundscheck @@ -183,16 +183,27 @@ def bar(x, y): def _jit(sigs, locals, target, cache, targetoptions, **dispatcher_args): + dispatcher = registry.dispatcher_registry[target] + + def wrapper(func): + if extending.is_jitted(func): + raise TypeError( + "A jit decorator was called on an already jitted function " + f"{func}. If trying to access the original python " + f"function, use the {func}.py_func attribute." + ) + + if not inspect.isfunction(func): + raise TypeError( + "The decorated object is not a function (got type " + f"{type(func)})." + ) - def wrapper(func, dispatcher): if config.ENABLE_CUDASIM and target == 'cuda': from numba import cuda return cuda.jit(func) if config.DISABLE_JIT and not target == 'npyufunc': return func - if target == 'dppl': - from . import dppl - return dppl.jit(func) disp = dispatcher(py_func=func, locals=locals, targetoptions=targetoptions, **dispatcher_args) @@ -208,35 +219,7 @@ def wrapper(func, dispatcher): disp.disable_compile() return disp - def __wrapper(func): - if extending.is_jitted(func): - raise TypeError( - "A jit decorator was called on an already jitted function " - f"{func}. If trying to access the original python " - f"function, use the {func}.py_func attribute." - ) - - if not inspect.isfunction(func): - raise TypeError( - "The decorated object is not a function (got type " - f"{type(func)})." - ) - - from numba import dppl_config - if (target == 'npyufunc' or targetoptions.get('no_cpython_wrapper') - or sigs or config.DISABLE_JIT or not targetoptions.get('nopython') - or dppl_config.dppl_present is not True): - target_ = target - if target_ is None: - target_ = 'cpu' - disp = registry.dispatcher_registry[target_] - return wrapper(func, disp) - - from numba.dppl.target_dispatcher import TargetDispatcher - disp = TargetDispatcher(func, wrapper, target, targetoptions.get('parallel')) - return disp - - return __wrapper + return wrapper def generated_jit(function=None, target='cpu', cache=False, diff --git a/numba/core/dispatcher.py b/numba/core/dispatcher.py index 42418fe5783..18d9426cd4d 100644 --- a/numba/core/dispatcher.py +++ b/numba/core/dispatcher.py @@ -673,14 +673,7 @@ def _set_uuid(self, u): self._recent.append(self) -import abc - -class DispatcherMeta(abc.ABCMeta): - def __instancecheck__(self, other): - return type(type(other)) == DispatcherMeta - - -class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase, metaclass=DispatcherMeta): +class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase): """ Implementation of user-facing dispatcher objects (i.e. created using the @jit decorator). @@ -906,9 +899,6 @@ def get_function_type(self): cres = tuple(self.overloads.values())[0] return types.FunctionType(cres.signature) - def get_compiled(self): - return self - class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase): """ diff --git a/numba/core/registry.py b/numba/core/registry.py index a3396ecc5a6..a543cd0d789 100644 --- a/numba/core/registry.py +++ b/numba/core/registry.py @@ -2,7 +2,6 @@ from numba.core.descriptors import TargetDescriptor from numba.core import utils, typing, dispatcher, cpu -from numba.core.compiler_lock import global_compiler_lock # ----------------------------------------------------------------------------- # Default CPU target descriptors @@ -27,19 +26,16 @@ class CPUTarget(TargetDescriptor): _nested = _NestedContext() @utils.cached_property - @global_compiler_lock def _toplevel_target_context(self): # Lazily-initialized top-level target context, for all threads return cpu.CPUContext(self.typing_context) @utils.cached_property - @global_compiler_lock def _toplevel_typing_context(self): # Lazily-initialized top-level typing context, for all threads return typing.Context() @property - @global_compiler_lock def target_context(self): """ The target context for CPU targets. @@ -51,7 +47,6 @@ def target_context(self): return self._toplevel_target_context @property - @global_compiler_lock def typing_context(self): """ The typing context for CPU targets. @@ -62,7 +57,6 @@ def typing_context(self): else: return self._toplevel_typing_context - @global_compiler_lock def nested_context(self, typing_context, target_context): """ A context manager temporarily replacing the contexts with the diff --git a/numba/tests/test_dispatcher.py b/numba/tests/test_dispatcher.py index b90d42ede26..30a8e081485 100644 --- a/numba/tests/test_dispatcher.py +++ b/numba/tests/test_dispatcher.py @@ -398,8 +398,6 @@ def test_serialization(self): def foo(x): return x + 1 - foo = foo.get_compiled() - self.assertEqual(foo(1), 2) # get serialization memo diff --git a/numba/tests/test_nrt.py b/numba/tests/test_nrt.py index 602132258e8..e0c94605671 100644 --- a/numba/tests/test_nrt.py +++ b/numba/tests/test_nrt.py @@ -249,8 +249,6 @@ def alloc_nrt_memory(): """ return np.empty(N, dtype) - alloc_nrt_memory = alloc_nrt_memory.get_compiled() - def keep_memory(): return alloc_nrt_memory() diff --git a/numba/tests/test_record_dtype.py b/numba/tests/test_record_dtype.py index e674bacc957..6d479c413fa 100644 --- a/numba/tests/test_record_dtype.py +++ b/numba/tests/test_record_dtype.py @@ -803,8 +803,8 @@ def test_record_arg_transform(self): self.assertIn('Array', transformed) self.assertNotIn('first', transformed) self.assertNotIn('second', transformed) - # Length is usually 60 - 5 chars tolerance as above. - self.assertLess(len(transformed), 60) + # Length is usually 50 - 5 chars tolerance as above. + self.assertLess(len(transformed), 50) def test_record_two_arrays(self): """ diff --git a/numba/tests/test_serialize.py b/numba/tests/test_serialize.py index 90c3db44a16..2bcf843458a 100644 --- a/numba/tests/test_serialize.py +++ b/numba/tests/test_serialize.py @@ -135,9 +135,9 @@ def test_reuse(self): Note that "same function" is intentionally under-specified. """ - func = closure(5).get_compiled() + func = closure(5) pickled = pickle.dumps(func) - func2 = closure(6).get_compiled() + func2 = closure(6) pickled2 = pickle.dumps(func2) f = pickle.loads(pickled) @@ -152,7 +152,7 @@ def test_reuse(self): self.assertEqual(h(2, 3), 11) # Now make sure the original object doesn't exist when deserializing - func = closure(7).get_compiled() + func = closure(7) func(42, 43) pickled = pickle.dumps(func) del func