Skip to content

Commit 8bdf838

Browse files
committed
Use add_dll_directory as a context manager
1 parent bc3dda0 commit 8bdf838

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

pytensor/link/c/cmodule.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import atexit
7+
from contextlib import nullcontext
78
import importlib
89
import logging
910
import os
@@ -23,7 +24,7 @@
2324
from collections.abc import Callable
2425
from functools import cache
2526
from io import BytesIO, StringIO
26-
from typing import TYPE_CHECKING, Protocol, cast
27+
from typing import TYPE_CHECKING, Generator, Protocol, cast
2728

2829
import numpy as np
2930
from setuptools._distutils.sysconfig import (
@@ -272,19 +273,20 @@ def _get_ext_suffix():
272273
return dist_suffix
273274

274275

275-
@cache # See explanation in docstring.
276-
def add_gcc_dll_directory() -> None:
276+
def add_gcc_dll_directory() -> Generator[None, None, None]:
277277
"""On Windows, detect and add the location of gcc to the DLL search directory.
278278
279279
On non-Windows platforms this is a noop.
280280
281-
The @cache decorator ensures that this function only executes once to avoid
282-
redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>.
281+
Returns a context manager to be used with `with`. The entry is removed when the
282+
context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>.
283283
"""
284284
if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
285285
gcc_path = shutil.which("gcc")
286286
if gcc_path is not None:
287-
os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore
287+
return os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore
288+
else:
289+
return nullcontext()
288290

289291

290292
def dlimport(fullpath, suffix=None):
@@ -336,20 +338,20 @@ def dlimport(fullpath, suffix=None):
336338
_logger.debug(f"module_name {module_name}")
337339

338340
sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
339-
add_gcc_dll_directory()
340-
global import_time
341-
try:
342-
importlib.invalidate_caches()
343-
t0 = time.perf_counter()
344-
with warnings.catch_warnings():
345-
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
346-
rval = __import__(module_name, {}, {}, [module_name])
347-
t1 = time.perf_counter()
348-
import_time += t1 - t0
349-
if not rval:
350-
raise Exception("__import__ failed", fullpath)
351-
finally:
352-
del sys.path[0]
341+
with add_gcc_dll_directory():
342+
global import_time
343+
try:
344+
importlib.invalidate_caches()
345+
t0 = time.perf_counter()
346+
with warnings.catch_warnings():
347+
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
348+
rval = __import__(module_name, {}, {}, [module_name])
349+
t1 = time.perf_counter()
350+
import_time += t1 - t0
351+
if not rval:
352+
raise Exception("__import__ failed", fullpath)
353+
finally:
354+
del sys.path[0]
353355

354356
assert fullpath.startswith(rval.__file__)
355357
return rval

0 commit comments

Comments
 (0)