|
21 | 21 | import time
|
22 | 22 | import warnings
|
23 | 23 | from collections.abc import Callable
|
24 |
| -from functools import cache |
| 24 | +from contextlib import AbstractContextManager, nullcontext |
25 | 25 | from io import BytesIO, StringIO
|
26 | 26 | from typing import TYPE_CHECKING, Protocol, cast
|
27 | 27 |
|
@@ -272,23 +272,24 @@ def _get_ext_suffix():
|
272 | 272 | return dist_suffix
|
273 | 273 |
|
274 | 274 |
|
275 |
| -@cache # See explanation in docstring. |
276 |
| -def add_gcc_dll_directory() -> None: |
| 275 | +def add_gcc_dll_directory() -> AbstractContextManager[None]: |
277 | 276 | """On Windows, detect and add the location of gcc to the DLL search directory.
|
278 | 277 |
|
279 | 278 | On non-Windows platforms this is a noop.
|
280 | 279 |
|
281 |
| - The @cache decorator ensures that this function only executes once to avoid |
282 |
| - redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>. |
| 280 | + Returns a context manager to be used with `with`. The entry is removed when the |
| 281 | + context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>. |
283 | 282 | """
|
| 283 | + cm: AbstractContextManager[None] = nullcontext() |
284 | 284 | if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
|
285 | 285 | gcc_path = shutil.which("gcc")
|
286 | 286 | if gcc_path is not None:
|
287 | 287 | # Since add_dll_directory is only defined on windows, we need
|
288 | 288 | # the ignore[attr-defined] on non-Windows platforms.
|
289 | 289 | # For Windows we need ignore[unused-ignore] since the ignore
|
290 | 290 | # is unnecessary with that platform.
|
291 |
| - os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore] |
| 291 | + cm = os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore[attr-defined,unused-ignore] |
| 292 | + return cm |
292 | 293 |
|
293 | 294 |
|
294 | 295 | def dlimport(fullpath, suffix=None):
|
@@ -340,20 +341,20 @@ def dlimport(fullpath, suffix=None):
|
340 | 341 | _logger.debug(f"module_name {module_name}")
|
341 | 342 |
|
342 | 343 | sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
|
343 |
| - add_gcc_dll_directory() |
344 |
| - global import_time |
345 |
| - try: |
346 |
| - importlib.invalidate_caches() |
347 |
| - t0 = time.perf_counter() |
348 |
| - with warnings.catch_warnings(): |
349 |
| - warnings.filterwarnings("ignore", message="numpy.ndarray size changed") |
350 |
| - rval = __import__(module_name, {}, {}, [module_name]) |
351 |
| - t1 = time.perf_counter() |
352 |
| - import_time += t1 - t0 |
353 |
| - if not rval: |
354 |
| - raise Exception("__import__ failed", fullpath) |
355 |
| - finally: |
356 |
| - del sys.path[0] |
| 344 | + with add_gcc_dll_directory(): |
| 345 | + global import_time |
| 346 | + try: |
| 347 | + importlib.invalidate_caches() |
| 348 | + t0 = time.perf_counter() |
| 349 | + with warnings.catch_warnings(): |
| 350 | + warnings.filterwarnings("ignore", message="numpy.ndarray size changed") |
| 351 | + rval = __import__(module_name, {}, {}, [module_name]) |
| 352 | + t1 = time.perf_counter() |
| 353 | + import_time += t1 - t0 |
| 354 | + if not rval: |
| 355 | + raise Exception("__import__ failed", fullpath) |
| 356 | + finally: |
| 357 | + del sys.path[0] |
357 | 358 |
|
358 | 359 | assert fullpath.startswith(rval.__file__)
|
359 | 360 | return rval
|
|
0 commit comments