diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 7682c34..60d38ac 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -1,6 +1,10 @@ name: Conda package -on: push +on: + push: + branches: + - master + pull_request: permissions: read-all diff --git a/mkl_umath/src/_patch.pyx b/mkl_umath/src/_patch.pyx index fd78f8d..5da3038 100644 --- a/mkl_umath/src/_patch.pyx +++ b/mkl_umath/src/_patch.pyx @@ -27,7 +27,6 @@ # cython: language_level=3 import mkl_umath._ufuncs as mu -import numpy.core.umath as nu cimport numpy as cnp import numpy as np @@ -59,7 +58,7 @@ cdef class patch: self.functions_count = 0 for umath in umaths: mkl_umath = getattr(mu, umath) - self.functions_count = self.functions_count + mkl_umath.ntypes + self.functions_count += mkl_umath.ntypes self.functions = malloc(self.functions_count * sizeof(function_info)) @@ -67,7 +66,7 @@ cdef class patch: for umath in umaths: patch_umath = getattr(mu, umath) c_patch_umath = patch_umath - c_orig_umath = getattr(nu, umath) + c_orig_umath = getattr(np, umath) nargs = c_patch_umath.nargs for pi in range(c_patch_umath.ntypes): oi = 0 @@ -103,7 +102,7 @@ cdef class patch: cdef int* signature for func in self.functions_dict: - np_umath = getattr(nu, func[0]) + np_umath = getattr(np, func[0]) index = self.functions_dict[func] function = self.functions[index].patch_function signature = self.functions[index].signature @@ -118,7 +117,7 @@ cdef class patch: cdef int* signature for func in self.functions_dict: - np_umath = getattr(nu, func[0]) + np_umath = getattr(np, func[0]) index = self.functions_dict[func] function = self.functions[index].original_function signature = self.functions[index].signature @@ -143,34 +142,97 @@ def _initialize_tls(): def use_in_numpy(): - ''' + """ Enables using of mkl_umath in Numpy. - ''' + + Examples + -------- + >>> import mkl_umath, numpy as np + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False + + """ if not _is_tls_initialized(): _initialize_tls() _tls.patch.do_patch() def restore(): - ''' + """ Disables using of mkl_umath in Numpy. - ''' + + Examples + -------- + >>> import mkl_umath, numpy as np + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False + + """ if not _is_tls_initialized(): _initialize_tls() _tls.patch.do_unpatch() def is_patched(): - ''' + """ Returns whether Numpy has been patched with mkl_umath. - ''' + + Examples + -------- + >>> import mkl_umath, numpy as np + >>> mkl_umath.is_patched() + # False + + >>> mkl_umath.use_in_numpy() # Enable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # True + + >>> mkl_umath.restore() # Disable mkl_umath in Numpy + >>> mkl_umath.is_patched() + # False + + """ if not _is_tls_initialized(): _initialize_tls() - _tls.patch.is_patched() + return _tls.patch.is_patched() from contextlib import ContextDecorator class mkl_umath(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy ufuncs + with MKL-based implementations. + + Examples + -------- + >>> import mkl_umath, numpy as np + >>> mkl_umath.is_patched() + # False + + >>> with mkl_umath.mkl_umath(): # Enable mkl_umath in Numpy + >>> print(mkl_umath.is_patched()) + # True + + >>> mkl_umath.is_patched() + # False + + """ def __enter__(self): use_in_numpy() return self