|
4 | 4 | """
|
5 | 5 |
|
6 | 6 | import atexit
|
| 7 | +from contextlib import nullcontext |
7 | 8 | import importlib
|
8 | 9 | import logging
|
9 | 10 | import os
|
|
23 | 24 | from collections.abc import Callable
|
24 | 25 | from functools import cache
|
25 | 26 | from io import BytesIO, StringIO
|
26 |
| -from typing import TYPE_CHECKING, Protocol, cast |
| 27 | +from typing import TYPE_CHECKING, Generator, Protocol, cast |
27 | 28 |
|
28 | 29 | import numpy as np
|
29 | 30 | from setuptools._distutils.sysconfig import (
|
@@ -272,19 +273,20 @@ def _get_ext_suffix():
|
272 | 273 | return dist_suffix
|
273 | 274 |
|
274 | 275 |
|
275 |
| -@cache # See explanation in docstring. |
276 |
| -def add_gcc_dll_directory() -> None: |
| 276 | +def add_gcc_dll_directory() -> Generator[None, None, None]: |
277 | 277 | """On Windows, detect and add the location of gcc to the DLL search directory.
|
278 | 278 |
|
279 | 279 | On non-Windows platforms this is a noop.
|
280 | 280 |
|
281 |
| - The @cache decorator ensures that this function only executes once to avoid |
282 |
| - redundant entries. See <https://github.com/pymc-devs/pytensor/pull/678>. |
| 281 | + Returns a context manager to be used with `with`. The entry is removed when the |
| 282 | + context manager is closed. See <https://github.com/pymc-devs/pytensor/pull/678>. |
283 | 283 | """
|
284 | 284 | if (sys.platform == "win32") & (hasattr(os, "add_dll_directory")):
|
285 | 285 | gcc_path = shutil.which("gcc")
|
286 | 286 | if gcc_path is not None:
|
287 |
| - os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore |
| 287 | + return os.add_dll_directory(os.path.dirname(gcc_path)) # type: ignore |
| 288 | + else: |
| 289 | + return nullcontext() |
288 | 290 |
|
289 | 291 |
|
290 | 292 | def dlimport(fullpath, suffix=None):
|
@@ -336,20 +338,20 @@ def dlimport(fullpath, suffix=None):
|
336 | 338 | _logger.debug(f"module_name {module_name}")
|
337 | 339 |
|
338 | 340 | sys.path[0:0] = [workdir] # insert workdir at beginning (temporarily)
|
339 |
| - add_gcc_dll_directory() |
340 |
| - global import_time |
341 |
| - try: |
342 |
| - importlib.invalidate_caches() |
343 |
| - t0 = time.perf_counter() |
344 |
| - with warnings.catch_warnings(): |
345 |
| - warnings.filterwarnings("ignore", message="numpy.ndarray size changed") |
346 |
| - rval = __import__(module_name, {}, {}, [module_name]) |
347 |
| - t1 = time.perf_counter() |
348 |
| - import_time += t1 - t0 |
349 |
| - if not rval: |
350 |
| - raise Exception("__import__ failed", fullpath) |
351 |
| - finally: |
352 |
| - del sys.path[0] |
| 341 | + with add_gcc_dll_directory(): |
| 342 | + global import_time |
| 343 | + try: |
| 344 | + importlib.invalidate_caches() |
| 345 | + t0 = time.perf_counter() |
| 346 | + with warnings.catch_warnings(): |
| 347 | + warnings.filterwarnings("ignore", message="numpy.ndarray size changed") |
| 348 | + rval = __import__(module_name, {}, {}, [module_name]) |
| 349 | + t1 = time.perf_counter() |
| 350 | + import_time += t1 - t0 |
| 351 | + if not rval: |
| 352 | + raise Exception("__import__ failed", fullpath) |
| 353 | + finally: |
| 354 | + del sys.path[0] |
353 | 355 |
|
354 | 356 | assert fullpath.startswith(rval.__file__)
|
355 | 357 | return rval
|
|
0 commit comments