@@ -448,10 +448,10 @@ def build_torchaudio(host: RemoteHost, *,
448
448
449
449
450
450
def configure_system (host : RemoteHost , * ,
451
- compiler = "gcc-8" ,
452
- use_conda = True ,
453
- python_version = "3.8" ,
454
- enable_mkldnn = False ) -> None :
451
+ compiler : str = "gcc-8" ,
452
+ use_conda : bool = True ,
453
+ python_version : str = "3.8" ,
454
+ enable_mkldnn : bool = False ) -> None :
455
455
if use_conda :
456
456
install_condaforge_python (host , python_version )
457
457
@@ -478,14 +478,25 @@ def configure_system(host: RemoteHost, *,
478
478
host .run_cmd ("sudo pip3 install numpy" )
479
479
480
480
481
+ def build_domains (host : RemoteHost , * ,
482
+ branch : str = "master" ,
483
+ use_conda : bool = True ,
484
+ git_clone_flags : str = "" ) -> Tuple [str , str , str , str ]:
485
+ vision_wheel_name = build_torchvision (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
486
+ audio_wheel_name = build_torchaudio (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
487
+ data_wheel_name = build_torchdata (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
488
+ text_wheel_name = build_torchtext (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
489
+ return (vision_wheel_name , audio_wheel_name , data_wheel_name , text_wheel_name )
490
+
491
+
481
492
def start_build (host : RemoteHost , * ,
482
- branch = "master" ,
483
- compiler = "gcc-8" ,
484
- use_conda = True ,
485
- python_version = "3.8" ,
493
+ branch : str = "master" ,
494
+ compiler : str = "gcc-8" ,
495
+ use_conda : bool = True ,
496
+ python_version : str = "3.8" ,
486
497
pytorch_only : bool = False ,
487
- shallow_clone = True ,
488
- enable_mkldnn = False ) -> Tuple [str , str ]:
498
+ shallow_clone : bool = True ,
499
+ enable_mkldnn : bool = False ) -> Tuple [str , str , str , str , str ]:
489
500
git_clone_flags = " --depth 1 --shallow-submodules" if shallow_clone else ""
490
501
if host .using_docker () and not use_conda :
491
502
print ("Auto-selecting conda option for docker images" )
@@ -553,13 +564,10 @@ def start_build(host: RemoteHost, *,
553
564
host .run_cmd (f"pip3 install pytorch/dist/{ pytorch_wheel_name } " )
554
565
555
566
if pytorch_only :
556
- return pytorch_wheel_name , None
557
- vision_wheel_name = build_torchvision (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
558
- build_torchaudio (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
559
- build_torchtext (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
560
- build_torchdata (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
567
+ return (pytorch_wheel_name , None , None , None , None )
568
+ domain_wheels = build_domains (host , branch = branch , use_conda = use_conda , git_clone_flags = git_clone_flags )
561
569
562
- return pytorch_wheel_name , vision_wheel_name
570
+ return ( pytorch_wheel_name , * domain_wheels )
563
571
564
572
565
573
embed_library_script = """
@@ -758,9 +766,9 @@ def parse_arguments():
758
766
enable_mkldnn = False )
759
767
print ("Installing PyTorch wheel" )
760
768
host .run_cmd ("pip3 install torch" )
761
- build_torchvision (host ,
762
- branch = args .branch ,
763
- git_clone_flags = " --depth 1 --shallow-submodules" )
769
+ build_domains (host ,
770
+ branch = args .branch ,
771
+ git_clone_flags = " --depth 1 --shallow-submodules" )
764
772
else :
765
773
start_build (host ,
766
774
branch = args .branch ,
0 commit comments