2
2
3
3
import os
4
4
import subprocess
5
+ from pygit2 import Repository
5
6
from typing import Dict , List , Optional , Tuple
6
7
7
8
@@ -15,18 +16,20 @@ def list_dir(path: str) -> List[str]:
15
16
'''
16
17
Helper to get repo branches for specific versions
17
18
'''
18
- def checkout_repo (branch : str = "main" ,
19
- url : str = "" ,
20
- git_clone_flags : str = "" ,
21
- mapping : Dict [str , Tuple [str , str ]] = []) -> Optional [str ]:
19
+ def checkout_repo (
20
+ package : str ,
21
+ branch : str = "main" ,
22
+ url : str = "" ,
23
+ git_clone_flags : str = "" ,
24
+ mapping : Dict [str , Tuple [str , str ]] = []) -> Optional [str ]:
22
25
for prefix in mapping :
23
26
if not branch .startswith (prefix ):
24
27
continue
25
28
tag = f"v{ mapping [prefix ][0 ]} -{ mapping [prefix ][1 ]} "
26
- os .system (f"git clone { url } -b { tag } { git_clone_flags } " )
29
+ os .system (f"git clone { url } / { package } -b { tag } { git_clone_flags } " )
27
30
return mapping [prefix ][0 ]
28
31
29
- os .system (f"git clone { url } { git_clone_flags } " )
32
+ os .system (f"git clone { url } / { package } { git_clone_flags } " )
30
33
return None
31
34
32
35
@@ -70,7 +73,8 @@ def embed_libgomp(wheel_name) -> None:
70
73
def build_torchvision (branch : str = "main" ,
71
74
git_clone_flags : str = "" ) -> str :
72
75
print ('Checking out TorchVision repo' )
73
- build_version = checkout_repo (branch = branch ,
76
+ build_version = checkout_repo (package = "vision" ,
77
+ branch = branch ,
74
78
url = "https://github.com/pytorch/vision" ,
75
79
git_clone_flags = git_clone_flags ,
76
80
mapping = {
@@ -118,7 +122,8 @@ def build_torchaudio(branch: str = "main",
118
122
git_clone_flags : str = "" ) -> str :
119
123
print ('Checking out TorchAudio repo' )
120
124
git_clone_flags += " --recurse-submodules"
121
- build_version = checkout_repo (branch = branch ,
125
+ build_version = checkout_repo (package = "audio" ,
126
+ branch = branch ,
122
127
url = "https://github.com/pytorch/audio" ,
123
128
git_clone_flags = git_clone_flags ,
124
129
mapping = {
@@ -161,7 +166,8 @@ def build_torchtext(branch: str = "main",
161
166
print ('Checking out TorchText repo' )
162
167
os .system (f"cd /" )
163
168
git_clone_flags += " --recurse-submodules"
164
- build_version = checkout_repo (branch = branch ,
169
+ build_version = checkout_repo (package = "text" ,
170
+ branch = branch ,
165
171
url = "https://github.com/pytorch/text" ,
166
172
git_clone_flags = git_clone_flags ,
167
173
mapping = {
@@ -187,7 +193,7 @@ def build_torchtext(branch: str = "main",
187
193
elif build_version is not None :
188
194
build_vars += f"BUILD_VERSION={ build_version } "
189
195
190
- os .system (f"cd text; { build_vars } python3 setup.py bdist_wheel" )
196
+ os .system (f"cd / text; { build_vars } python3 setup.py bdist_wheel" )
191
197
wheel_name = list_dir ("/text/dist" )[0 ]
192
198
embed_libgomp (f"/text/dist/{ wheel_name } " )
193
199
@@ -203,7 +209,8 @@ def build_torchdata(branch: str = "main",
203
209
git_clone_flags : str = "" ) -> str :
204
210
print ('Checking out TorchData repo' )
205
211
git_clone_flags += " --recurse-submodules"
206
- build_version = checkout_repo (branch = branch ,
212
+ build_version = checkout_repo (package = "data" ,
213
+ branch = branch ,
207
214
url = "https://github.com/pytorch/data" ,
208
215
git_clone_flags = git_clone_flags ,
209
216
mapping = {
@@ -250,8 +257,10 @@ def parse_arguments():
250
257
251
258
args = parse_arguments ()
252
259
enable_mkldnn = args .enable_mkldnn
253
- os .system ("cd /pytorch" )
254
- branch = subprocess .check_output ("git rev-parse --abbrev-ref HEAD" )
260
+ repo = Repository ('/pytorch' )
261
+ branch = repo .head .name
262
+ if branch == 'HEAD' :
263
+ branch = 'master'
255
264
256
265
git_clone_flags = " --depth 1 --shallow-submodules"
257
266
os .system (f"conda install -y ninja scons" )
@@ -261,31 +270,35 @@ def parse_arguments():
261
270
262
271
print ('Building PyTorch wheel' )
263
272
build_vars = "CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000 "
264
- os .system (f"cd /pytorch; pip install -r requirements.txt" )
265
- os .system (f"pip install auditwheel" )
266
273
os .system (f"python setup.py clean" )
267
274
268
275
if branch == 'nightly' or branch == 'master' :
269
276
build_date = subprocess .check_output (['git' ,'log' ,'--pretty=format:%cs' ,'-1' ], cwd = '/pytorch' ).decode ().replace ('-' ,'' )
270
277
version = subprocess .check_output (['cat' ,'version.txt' ], cwd = '/pytorch' ).decode ().strip ()[:- 2 ]
271
- build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={ version } .dev{ build_date } PYTORCH_BUILD_NUMBER=1"
278
+ build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={ version } .dev{ build_date } PYTORCH_BUILD_NUMBER=1 "
272
279
if branch .startswith ("v1." ) or branch .startswith ("v2." ):
273
- build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={ branch [1 :branch .find ('-' )]} PYTORCH_BUILD_NUMBER=1"
280
+ build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={ branch [1 :branch .find ('-' )]} PYTORCH_BUILD_NUMBER=1 "
274
281
if enable_mkldnn :
275
282
build_ArmComputeLibrary (git_clone_flags )
276
283
print ("build pytorch with mkldnn+acl backend" )
277
- os .system (f"export ACL_ROOT_DIR=/acl; export LD_LIBRARY_PATH=/acl/build; export ACL_LIBRARY=/acl/build" )
278
- build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
284
+ build_vars += "USE_MKLDNN=ON USE_MKLDNN_ACL=ON " \
285
+ "ACL_ROOT_DIR=/acl " \
286
+ "LD_LIBRARY_PATH=/pytorch/build/lib:/acl/build " \
287
+ "ACL_INCLUDE_DIR=/acl/build " \
288
+ "ACL_LIBRARY=/acl/build "
279
289
os .system (f"cd /pytorch; { build_vars } python3 setup.py bdist_wheel" )
290
+
291
+ ## Using AuditWheel on the pip package.
280
292
print ('Repair the wheel' )
281
- pytorch_wheel_name = list_dir ("pytorch/dist" )[0 ]
282
- os .system (f"export LD_LIBRARY_PATH=/pytorch/build/lib:$LD_LIBRARY_PATH; auditwheel repair /pytorch/dist/{ pytorch_wheel_name } " )
293
+ pytorch_wheel_name = list_dir ("/ pytorch/dist" )[0 ]
294
+ os .system (f"LD_LIBRARY_PATH=/pytorch/build/lib:/acl/build auditwheel repair /pytorch/dist/{ pytorch_wheel_name } " )
283
295
print ('replace the original wheel with the repaired one' )
284
296
pytorch_repaired_wheel_name = list_dir ("wheelhouse" )[0 ]
285
297
os .system (f"cp /wheelhouse/{ pytorch_repaired_wheel_name } /pytorch/dist/{ pytorch_wheel_name } " )
286
298
else :
287
299
print ("build pytorch without mkldnn backend" )
288
- os .system (f"cd pytorch ; { build_vars } python3 setup.py bdist_wheel" )
300
+ build_vars += "LD_LIBRARY_PATH=/pytorch/build/lib "
301
+ os .system (f"cd /pytorch; { build_vars } python3 setup.py bdist_wheel" )
289
302
290
303
print ("Deleting build folder" )
291
304
os .system ("cd /pytorch; rm -rf build" )
0 commit comments