From 1ffbae3f54dccb90547a079925c99dbfc7c20173 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Fri, 12 Apr 2024 17:09:29 +0200 Subject: [PATCH] Use add_dll_directory as a context manager --- pytensor/link/c/cmodule.py | 41 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/pytensor/link/c/cmodule.py b/pytensor/link/c/cmodule.py index 99fc410719..0666a191bd 100644 --- a/pytensor/link/c/cmodule.py +++ b/pytensor/link/c/cmodule.py @@ -21,7 +21,7 @@ import time import warnings from collections.abc import Callable -from functools import cache +from contextlib import AbstractContextManager, nullcontext from io import BytesIO, StringIO from typing import TYPE_CHECKING, Protocol, cast @@ -272,15 +272,15 @@ def _get_ext_suffix(): return dist_suffix -@cache # See explanation in docstring. -def add_gcc_dll_directory() -> None: +def add_gcc_dll_directory() -> AbstractContextManager[None]: """On Windows, detect and add the location of gcc to the DLL search directory. On non-Windows platforms this is a noop. - The @cache decorator ensures that this function only executes once to avoid - redundant entries. See . + Returns a context manager to be used with `with`. The entry is removed when the + context manager is closed. See . """ + cm: AbstractContextManager[None] = nullcontext() if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")): gcc_path = shutil.which("gcc") if gcc_path is not None: @@ -288,7 +288,8 @@ def add_gcc_dll_directory() -> None: # the ignore[attr-defined] on non-Windows platforms. # For Windows we need ignore[unused-ignore] since the ignore # is unnecessary with that platform. - os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore] + cm = os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore] + return cm def dlimport(fullpath, suffix=None): @@ -340,20 +341,20 @@ def dlimport(fullpath, suffix=None): _logger.debug(f"module_name {module_name}") sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily) - add_gcc_dll_directory() - global import_time - try: - importlib.invalidate_caches() - t0 = time.perf_counter() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="numpy.ndarray size changed") - rval = __import__(module_name, {}, {}, [module_name]) - t1 = time.perf_counter() - import_time += t1 - t0 - if not rval: - raise Exception("__import__ failed", fullpath) - finally: - del sys.path[0] + with add_gcc_dll_directory(): + global import_time + try: + importlib.invalidate_caches() + t0 = time.perf_counter() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="numpy.ndarray size changed") + rval = __import__(module_name, {}, {}, [module_name]) + t1 = time.perf_counter() + import_time += t1 - t0 + if not rval: + raise Exception("__import__ failed", fullpath) + finally: + del sys.path[0] assert fullpath.startswith(rval.__file__) return rval