From 982aa3eced5560491b068cdb8d026535c9727a6b Mon Sep 17 00:00:00 2001 From: Sunita Nadampalli Date: Sat, 11 Feb 2023 20:10:25 -0600 Subject: [PATCH] [aarch64] add support for torchdata wheel building --- build_aarch64_wheel.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/build_aarch64_wheel.py b/build_aarch64_wheel.py index ccb622e01..48ed70bc9 100755 --- a/build_aarch64_wheel.py +++ b/build_aarch64_wheel.py @@ -321,6 +321,40 @@ def build_torchvision(host: RemoteHost, *, return vision_wheel_name +def build_torchdata(host: RemoteHost, *, + branch: str = "master", + use_conda: bool = True, + git_clone_flags: str = "") -> str: + print('Checking out TorchData repo') + git_clone_flags += " --recurse-submodules" + build_version = checkout_repo(host, + branch=branch, + url="https://github.com/pytorch/data", + git_clone_flags=git_clone_flags, + mapping={ + "v1.13.1": ("0.5.1", ""), + }) + print('Building TorchData wheel') + build_vars = "" + if branch == 'nightly': + version = host.check_output(["if [ -f data/version.txt ]; then cat data/version.txt; fi"]).strip() + build_date = host.check_output("cd pytorch ; git log --pretty=format:%s -1").strip().split()[0].replace("-", "") + build_vars += f"BUILD_VERSION={version}.dev{build_date}" + elif build_version is not None: + build_vars += f"BUILD_VERSION={build_version}" + if host.using_docker(): + build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" + + host.run_cmd(f"cd data; {build_vars} python3 setup.py bdist_wheel") + wheel_name = host.list_dir("data/dist")[0] + embed_libgomp(host, use_conda, os.path.join('data', 'dist', wheel_name)) + + print('Copying TorchData wheel') + host.download_wheel(os.path.join('data', 'dist', wheel_name)) + + return wheel_name + + def build_torchtext(host: RemoteHost, *, branch: str = "master", use_conda: bool = True, @@ -512,6 +546,7 @@ def start_build(host: RemoteHost, *, vision_wheel_name = build_torchvision(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags) build_torchaudio(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags) build_torchtext(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags) + build_torchdata(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags) return pytorch_wheel_name, vision_wheel_name