4
4
"""
5
5
6
6
import atexit
7
+ from contextlib import nullcontext
7
8
import importlib
8
9
import logging
9
10
import os
23
24
from collections .abc import Callable
24
25
from functools import cache
25
26
from io import BytesIO , StringIO
26
- from typing import TYPE_CHECKING , Protocol , cast
27
+ from typing import TYPE_CHECKING , Generator , Protocol , cast
27
28
28
29
import numpy as np
29
30
from setuptools ._distutils .sysconfig import (
@@ -272,14 +273,13 @@ def _get_ext_suffix():
272
273
return dist_suffix
273
274
274
275
275
- @cache # See explanation in docstring.
276
- def add_gcc_dll_directory () -> None :
276
+ def add_gcc_dll_directory () -> Generator [None , None , None ]:
277
277
"""On Windows, detect and add the location of gcc to the DLL search directory.
278
278
279
279
On non-Windows platforms this is a noop.
280
280
281
- The @cache decorator ensures that this function only executes once to avoid
282
- redundant entries . See <https://github.com/pymc-devs/pytensor/pull/678>.
281
+ Returns a context manager to be used with `with`. The entry is removed when the
282
+ context manager is closed . See <https://github.com/pymc-devs/pytensor/pull/678>.
283
283
"""
284
284
if (sys .platform == "win32" ) & (hasattr (os , "add_dll_directory" )):
285
285
gcc_path = shutil .which ("gcc" )
@@ -288,7 +288,8 @@ def add_gcc_dll_directory() -> None:
288
288
# the ignore[attr-defined] on non-Windows platforms.
289
289
# For Windows we need ignore[unused-ignore] since the ignore
290
290
# is unnecessary with that platform.
291
- os .add_dll_directory (os .path .dirname (gcc_path )) # type: ignore[attr-defined,unused-ignore]
291
+ return os .add_dll_directory (os .path .dirname (gcc_path )) # type: ignore[attr-defined,unused-ignore]
292
+ return nullcontext ()
292
293
293
294
294
295
def dlimport (fullpath , suffix = None ):
@@ -340,20 +341,20 @@ def dlimport(fullpath, suffix=None):
340
341
_logger .debug (f"module_name { module_name } " )
341
342
342
343
sys .path [0 :0 ] = [workdir ] # insert workdir at beginning (temporarily)
343
- add_gcc_dll_directory ()
344
- global import_time
345
- try :
346
- importlib .invalidate_caches ()
347
- t0 = time .perf_counter ()
348
- with warnings .catch_warnings ():
349
- warnings .filterwarnings ("ignore" , message = "numpy.ndarray size changed" )
350
- rval = __import__ (module_name , {}, {}, [module_name ])
351
- t1 = time .perf_counter ()
352
- import_time += t1 - t0
353
- if not rval :
354
- raise Exception ("__import__ failed" , fullpath )
355
- finally :
356
- del sys .path [0 ]
344
+ with add_gcc_dll_directory ():
345
+ global import_time
346
+ try :
347
+ importlib .invalidate_caches ()
348
+ t0 = time .perf_counter ()
349
+ with warnings .catch_warnings ():
350
+ warnings .filterwarnings ("ignore" , message = "numpy.ndarray size changed" )
351
+ rval = __import__ (module_name , {}, {}, [module_name ])
352
+ t1 = time .perf_counter ()
353
+ import_time += t1 - t0
354
+ if not rval :
355
+ raise Exception ("__import__ failed" , fullpath )
356
+ finally :
357
+ del sys .path [0 ]
357
358
358
359
assert fullpath .startswith (rval .__file__ )
359
360
return rval
0 commit comments