Skip to content

Commit a9ff5cf

Browse files
committed
Use add_dll_directory as a context manager
1 parent 0a13fbd commit a9ff5cf

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

pytensor/link/c/cmodule.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import atexit
7+
from contextlib import nullcontext
78
import importlib
89
import logging
910
import os
@@ -23,7 +24,7 @@
2324
from collections.abc import Callable
2425
from functools import cache
2526
from io import BytesIO, StringIO
26-
from typing import TYPE_CHECKING, Protocol, cast
27+
from typing import TYPE_CHECKING, Generator, Protocol, cast
2728

2829
import numpy as np
2930
from setuptools._distutils.sysconfig import (
@@ -272,14 +273,13 @@ def _get_ext_suffix():
272273
return dist_suffix
273274

274275

275-
@cache # See explanation in docstring.
276-
def add_gcc_dll_directory() -> None:
276+
def add_gcc_dll_directory() -> Generator[None, None, None]:
277277
"""On Windows, detect and add the location of gcc to the DLL search directory.
278278
279279
On non-Windows platforms this is a noop.
280280
281-
The @cache decorator ensures that this function only executes once to avoid
282-
redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>.
281+
Returns a context manager to be used with `with`. The entry is removed when the
282+
context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>.
283283
"""
284284
if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
285285
gcc_path = shutil.which("gcc")
@@ -288,7 +288,8 @@ def add_gcc_dll_directory() -> None:
288288
# the ignore[attr-defined] on non-Windows platforms.
289289
# For Windows we need ignore[unused-ignore] since the ignore
290290
# is unnecessary with that platform.
291-
os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
291+
return os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore]
292+
return nullcontext()
292293

293294

294295
def dlimport(fullpath, suffix=None):
@@ -340,20 +341,20 @@ def dlimport(fullpath, suffix=None):
340341
_logger.debug(f"module_name {module_name}")
341342

342343
sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
343-
add_gcc_dll_directory()
344-
global import_time
345-
try:
346-
importlib.invalidate_caches()
347-
t0 = time.perf_counter()
348-
with warnings.catch_warnings():
349-
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
350-
rval = __import__(module_name, {}, {}, [module_name])
351-
t1 = time.perf_counter()
352-
import_time += t1 - t0
353-
if not rval:
354-
raise Exception("__import__ failed", fullpath)
355-
finally:
356-
del sys.path[0]
344+
with add_gcc_dll_directory():
345+
global import_time
346+
try:
347+
importlib.invalidate_caches()
348+
t0 = time.perf_counter()
349+
with warnings.catch_warnings():
350+
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
351+
rval = __import__(module_name, {}, {}, [module_name])
352+
t1 = time.perf_counter()
353+
import_time += t1 - t0
354+
if not rval:
355+
raise Exception("__import__ failed", fullpath)
356+
finally:
357+
del sys.path[0]
357358

358359
assert fullpath.startswith(rval.__file__)
359360
return rval

0 commit comments

Comments
 (0)