Skip to content

Commit 9c813c3

Browse files
committed
[WIP] Implement windows
Signed-off-by: Cristian Le <[email protected]>
1 parent a9cd9db commit 9c813c3

File tree

2 files changed

+242
-5
lines changed

2 files changed

+242
-5
lines changed

src/scikit_build_core/repair_wheel/windows.py

Lines changed: 239 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING
7+
import dataclasses
8+
import os.path
9+
import textwrap
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, ClassVar
812

9-
from . import WheelRepairer
13+
from .._logging import logger
14+
from . import WheelRepairer, _get_buildenv_platlib
1015

1116
if TYPE_CHECKING:
1217
from ..file_api.model.codemodel import Target
@@ -18,13 +23,243 @@ def __dir__() -> list[str]:
1823
return __all__
1924

2025

26+
@dataclasses.dataclass
2127
class WindowsWheelRepairer(WheelRepairer):
2228
"""
2329
Do some windows specific magic.
2430
"""
2531

32+
# TODO: Currently this installs the dll libraries in the scripts folder
33+
# Maybe it's better to point them to the original paths instead.
34+
2635
_platform = "Windows"
2736

37+
PATCH_PY_FILE: ClassVar[str] = textwrap.dedent("""\
38+
# start scikit-build-core Windows patch
39+
def _skbuild_patch_dll_dir():
40+
import os
41+
import os.path
42+
43+
mod_dir = os.path.abspath(os.path.dirname(__file__))
44+
path_to_platlib = os.path.normpath({path_to_platlib!r})
45+
dll_paths = {dll_paths!r}
46+
for path in dll_paths:
47+
path = os.path.normpath(path)
48+
path = os.path.join(mod_dir, path_to_platlib, path)
49+
os.add_dll_directory(path)
50+
51+
_skbuild_patch_dll_dir()
52+
del _skbuild_patch_dll_dir
53+
# end scikit-build-core Windows patch
54+
""")
55+
dll_dirs: set[Path] = dataclasses.field(default_factory=set, init=False)
56+
"""All dll paths used relative to ``platlib``."""
57+
58+
def get_dll_path_from_lib(self, lib_path: Path) -> Path | None:
59+
"""Guess the dll path from lib path."""
60+
dll_path = None
61+
platlib = Path(_get_buildenv_platlib())
62+
lib_path = lib_path.relative_to(platlib)
63+
dll_name = lib_path.with_suffix(".dll").name
64+
# Try to find the dll in the same package directory
65+
if len(lib_path.parts) > 1:
66+
pkg_dir = lib_path.parts[0]
67+
for root, files, _ in os.walk(platlib / pkg_dir):
68+
if dll_name in files:
69+
dll_path = Path(root) / dll_name
70+
break
71+
else:
72+
logger.debug(
73+
"Did not find the dll file under {pkg_dir}",
74+
pkg_dir=pkg_dir,
75+
)
76+
if not dll_path:
77+
logger.debug(
78+
"Looking for {dll_name} in all platlib path.",
79+
dll_name=dll_name,
80+
)
81+
for root, files, _ in os.walk(platlib):
82+
if dll_name in files:
83+
dll_path = Path(root) / dll_name
84+
break
85+
else:
86+
logger.warning(
87+
"Could not find dll file {dll_name} corresponding to {lib_path}",
88+
dll_name=dll_name,
89+
lib_path=lib_path,
90+
)
91+
return None
92+
logger.debug(
93+
"Found dll file {dll_path}",
94+
dll_path=dll_path,
95+
)
96+
return self.path_relative_site_packages(dll_path)
97+
98+
def get_library_dependencies(self, target: Target) -> list[Target]:
99+
msg = "get_library_dependencies is not generalized for Windows."
100+
raise NotImplementedError(msg)
101+
102+
def get_dependency_dll(self, target: Target) -> list[Path]:
103+
"""Get the dll due to target link dependencies."""
104+
dll_paths = []
105+
for dep in target.dependencies:
106+
dep_target = next(targ for targ in self.targets if targ.id == dep.id)
107+
if dep_target.type != "SHARED_LIBRARY":
108+
logger.debug(
109+
"Skipping dependency {dep_target} of type {type}",
110+
dep_target=dep_target.name,
111+
type=dep_target.type,
112+
)
113+
continue
114+
if not dep_target.install:
115+
logger.warning(
116+
"Dependency {dep_target} is not installed",
117+
dep_target=dep_target.name,
118+
)
119+
continue
120+
dll_artifact = next(
121+
artifact.path
122+
for artifact in dep_target.artifacts
123+
if artifact.path.suffix == ".dll"
124+
)
125+
for install_path in self.get_wheel_install_paths(dep_target):
126+
dep_install_path = self.install_dir / install_path
127+
if (dep_install_path / dll_artifact).exists():
128+
break
129+
else:
130+
logger.warning(
131+
"Could not find installed {dll_artifact} location in install paths: {install_path}",
132+
dll_artifact=dll_artifact,
133+
install_path=[
134+
dest.path for dest in dep_target.install.destinations
135+
],
136+
)
137+
continue
138+
dll_path = self.path_relative_site_packages(dep_install_path)
139+
dll_paths.append(dll_path)
140+
return dll_paths
141+
142+
def get_package_dll(self, target: Target) -> list[Path]:
143+
"""
144+
Get the dll due to external package linkage.
145+
146+
Have to use the guess the dll paths until the package targets are exposed.
147+
https://gitlab.kitware.com/cmake/cmake/-/issues/26755
148+
"""
149+
if not target.link:
150+
return []
151+
dll_paths = []
152+
for link_command in target.link.commandFragments:
153+
if link_command.role == "flags":
154+
if not link_command.fragment:
155+
logger.debug(
156+
"Skipping {target} link-flags: {flags}",
157+
target=target.name,
158+
flags=link_command.fragment,
159+
)
160+
continue
161+
if link_command.role != "libraries":
162+
logger.warning(
163+
"File-api link role {role} is not supported. "
164+
"Target={target}, command={command}",
165+
target=target.name,
166+
role=link_command.role,
167+
command=link_command.fragment,
168+
)
169+
continue
170+
# The remaining case should be a path
171+
try:
172+
# TODO: how to best catch if a string is a valid path?
173+
lib_path = Path(link_command.fragment)
174+
if not lib_path.is_absolute():
175+
# If the link_command is a space-separated list of libraries, this should be skipped
176+
logger.debug(
177+
"Skipping non-absolute-path library: {fragment}",
178+
fragment=link_command.fragment,
179+
)
180+
continue
181+
dll_path = self.get_dll_path_from_lib(lib_path)
182+
if not dll_path:
183+
continue
184+
dll_paths.append(dll_path)
185+
except Exception as exc:
186+
logger.warning(
187+
"Could not parse link-library as a path: {fragment}\nexc = {exc}",
188+
fragment=link_command.fragment,
189+
exc=exc,
190+
)
191+
continue
192+
return dll_paths
193+
28194
def patch_target(self, target: Target) -> None:
29-
# TODO: Implement patching
30-
pass
195+
# Here we just gather all dll paths needed for each target
196+
package_dlls = self.get_package_dll(target)
197+
dependency_dlls = self.get_dependency_dll(target)
198+
if not package_dlls and not dependency_dlls:
199+
logger.warning(
200+
"No dll files found for target {target}",
201+
target=target.name,
202+
)
203+
return
204+
logger.debug(
205+
"Found dlls for target {target}:\n"
206+
"package_dlls={package_dlls}\n"
207+
"dependency_dlls={dependency_dlls}\n",
208+
target=target.name,
209+
package_dlls=package_dlls,
210+
dependency_dlls=dependency_dlls,
211+
)
212+
self.dll_dirs.update(package_dlls)
213+
self.dll_dirs.update(dependency_dlls)
214+
215+
def patch_python_file(self, file: Path) -> None:
216+
"""
217+
Patch python package or top-level module.
218+
219+
Make sure the python files have an appropriate ``os.add_dll_directory``
220+
for the scripts directory.
221+
"""
222+
assert self.dll_dirs
223+
assert all(not path.is_absolute() for path in self.dll_dirs)
224+
logger.debug(
225+
"Patching python file: {file}",
226+
file=file,
227+
)
228+
platlib = Path(self.wheel_dirs["platlib"])
229+
content = file.read_text()
230+
mod_dir = file.parent
231+
path_to_platlib = os.path.relpath(platlib, mod_dir)
232+
patch_script = self.PATCH_PY_FILE.format(
233+
path_to_platlib=path_to_platlib,
234+
dll_paths=[str(path) for path in self.dll_dirs],
235+
)
236+
# TODO: Account for the header comments, __future__.annotations, etc.
237+
with file.open("w") as f:
238+
f.write(f"{patch_script}\n" + content)
239+
240+
def repair_wheel(self) -> None:
241+
super().repair_wheel()
242+
platlib = Path(self.wheel_dirs["platlib"])
243+
if not self.dll_dirs:
244+
logger.debug(
245+
"Skipping wheel repair because no site-package dlls were found."
246+
)
247+
return
248+
logger.debug(
249+
"Patching dll directories: {dll_dirs}",
250+
dll_dirs=self.dll_dirs,
251+
)
252+
# TODO: Not handling namespace packages with this
253+
for path in platlib.iterdir():
254+
assert isinstance(path, Path)
255+
if path.is_dir():
256+
pkg_file = path / "__init__.py"
257+
if not pkg_file.exists():
258+
logger.debug(
259+
"Ignoring non-python package: {pkg_file}",
260+
pkg_file=pkg_file,
261+
)
262+
continue
263+
self.patch_python_file(pkg_file)
264+
elif path.suffix == ".py":
265+
self.patch_python_file(path)

tests/test_repair_wheel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def test_full_build(
7272
wheels = list(dist.glob("*.whl"))
7373
isolated.install(*wheels)
7474

75-
isolated.run("main")
75+
if platform.system() != "Windows":
76+
# For some reason isolated.run cannot run this on windows
77+
isolated.run("main")
7678
isolated.module("repair_wheel")
7779
isolated.execute(
7880
"from repair_wheel._module import hello; hello()",

0 commit comments

Comments
 (0)