Skip to content

Commit 96d6aee

Browse files
committed
[BE] Introduce build_domains function
And call it to rebuild only domains if torch wheel is available
1 parent d5c5b18 commit 96d6aee

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

aarch64_linux/build_aarch64_wheel.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -448,10 +448,10 @@ def build_torchaudio(host: RemoteHost, *,
448448

449449

450450
def configure_system(host: RemoteHost, *,
451-
compiler="gcc-8",
452-
use_conda=True,
453-
python_version="3.8",
454-
enable_mkldnn=False) -> None:
451+
compiler: str = "gcc-8",
452+
use_conda: bool = True,
453+
python_version: str = "3.8",
454+
enable_mkldnn: bool = False) -> None:
455455
if use_conda:
456456
install_condaforge_python(host, python_version)
457457

@@ -478,14 +478,25 @@ def configure_system(host: RemoteHost, *,
478478
host.run_cmd("sudo pip3 install numpy")
479479

480480

481+
def build_domains(host: RemoteHost, *,
482+
branch: str = "master",
483+
use_conda: bool = True,
484+
git_clone_flags: str = "") -> Tuple[str, str, str, str]:
485+
vision_wheel_name = build_torchvision(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
486+
audio_wheel_name = build_torchaudio(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
487+
data_wheel_name = build_torchdata(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
488+
text_wheel_name = build_torchtext(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
489+
return (vision_wheel_name, audio_wheel_name, data_wheel_name, text_wheel_name)
490+
491+
481492
def start_build(host: RemoteHost, *,
482-
branch="master",
483-
compiler="gcc-8",
484-
use_conda=True,
485-
python_version="3.8",
493+
branch: str = "master",
494+
compiler: str = "gcc-8",
495+
use_conda: bool = True,
496+
python_version: str = "3.8",
486497
pytorch_only: bool = False,
487-
shallow_clone=True,
488-
enable_mkldnn=False) -> Tuple[str, str]:
498+
shallow_clone: bool = True,
499+
enable_mkldnn: bool = False) -> Tuple[str, str, str, str, str]:
489500
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
490501
if host.using_docker() and not use_conda:
491502
print("Auto-selecting conda option for docker images")
@@ -553,13 +564,10 @@ def start_build(host: RemoteHost, *,
553564
host.run_cmd(f"pip3 install pytorch/dist/{pytorch_wheel_name}")
554565

555566
if pytorch_only:
556-
return pytorch_wheel_name, None
557-
vision_wheel_name = build_torchvision(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
558-
build_torchaudio(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
559-
build_torchtext(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
560-
build_torchdata(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
567+
return (pytorch_wheel_name, None, None, None, None)
568+
domain_wheels = build_domains(host, branch=branch, use_conda=use_conda, git_clone_flags=git_clone_flags)
561569

562-
return pytorch_wheel_name, vision_wheel_name
570+
return (pytorch_wheel_name, *domain_wheels)
563571

564572

565573
embed_library_script = """
@@ -758,9 +766,9 @@ def parse_arguments():
758766
enable_mkldnn=False)
759767
print("Installing PyTorch wheel")
760768
host.run_cmd("pip3 install torch")
761-
build_torchvision(host,
762-
branch=args.branch,
763-
git_clone_flags=" --depth 1 --shallow-submodules")
769+
build_domains(host,
770+
branch=args.branch,
771+
git_clone_flags=" --depth 1 --shallow-submodules")
764772
else:
765773
start_build(host,
766774
branch=args.branch,

0 commit comments

Comments
 (0)