Skip to content

Commit 4e5570d

Browse files
maresbricardoV94
authored andcommitted
Use add_dll_directory as a context manager
1 parent 5e612ab commit 4e5570d

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

pytensor/link/c/cmodule.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import time
2222
import warnings
2323
from collections.abc import Callable
24-
from functools import cache
24+
from contextlib import AbstractContextManager, nullcontext
2525
from io import BytesIO, StringIO
2626
from typing import TYPE_CHECKING, Protocol, cast
2727

@@ -272,23 +272,24 @@ def _get_ext_suffix():
272272
return dist_suffix
273273

274274

275-
@cache # See explanation in docstring.
276-
def add_gcc_dll_directory() -> None:
275+
def add_gcc_dll_directory() -> AbstractContextManager[None]:
277276
"""On Windows, detect and add the location of gcc to the DLL search directory.
278277
279278
On non-Windows platforms this is a noop.
280279
281-
The @cache decorator ensures that this function only executes once to avoid
282-
redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>.
280+
Returns a context manager to be used with `with`. The entry is removed when the
281+
context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>.
283282
"""
283+
cm: AbstractContextManager[None] = nullcontext()
284284
if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
285285
gcc_path = shutil.which("gcc")
286286
if gcc_path is not None:
287287
# Since add_dll_directory is only defined on windows, we need
288288
# the ignore[attr-defined] on non-Windows platforms.
289289
# For Windows we need ignore[unused-ignore] since the ignore
290290
# is unnecessary with that platform.
291-
os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
291+
cm = os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
292+
return cm
292293

293294

294295
def dlimport(fullpath, suffix=None):
@@ -340,20 +341,20 @@ def dlimport(fullpath, suffix=None):
340341
_logger.debug(f"module_name {module_name}")
341342

342343
sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
343-
add_gcc_dll_directory()
344-
global import_time
345-
try:
346-
importlib.invalidate_caches()
347-
t0 = time.perf_counter()
348-
with warnings.catch_warnings():
349-
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
350-
rval = __import__(module_name, {}, {}, [module_name])
351-
t1 = time.perf_counter()
352-
import_time += t1 - t0
353-
if not rval:
354-
raise Exception("__import__ failed", fullpath)
355-
finally:
356-
del sys.path[0]
344+
with add_gcc_dll_directory():
345+
global import_time
346+
try:
347+
importlib.invalidate_caches()
348+
t0 = time.perf_counter()
349+
with warnings.catch_warnings():
350+
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
351+
rval = __import__(module_name, {}, {}, [module_name])
352+
t1 = time.perf_counter()
353+
import_time += t1 - t0
354+
if not rval:
355+
raise Exception("__import__ failed", fullpath)
356+
finally:
357+
del sys.path[0]
357358

358359
assert fullpath.startswith(rval.__file__)
359360
return rval

0 commit comments

Comments
 (0)